Merge pull request #31796 from kitaisreal/identifier-resolver

Added Analyzer
This commit is contained in:
Maksim Kita 2022-10-25 12:36:08 +03:00 committed by GitHub
commit 06fe6f3c8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
236 changed files with 30572 additions and 162 deletions

View File

@ -0,0 +1,114 @@
#include <Analyzer/AggregationUtils.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_AGGREGATION;
}
namespace
{
class CollectAggregateFunctionNodesVisitor : public ConstInDepthQueryTreeVisitor<CollectAggregateFunctionNodesVisitor>
{
public:
explicit CollectAggregateFunctionNodesVisitor(QueryTreeNodes * aggregate_function_nodes_)
: aggregate_function_nodes(aggregate_function_nodes_)
{}
explicit CollectAggregateFunctionNodesVisitor(String assert_no_aggregates_place_message_)
: assert_no_aggregates_place_message(std::move(assert_no_aggregates_place_message_))
{}
void visitImpl(const QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
if (!function_node || !function_node->isAggregateFunction())
return;
if (!assert_no_aggregates_place_message.empty())
throw Exception(ErrorCodes::ILLEGAL_AGGREGATION,
"Aggregate function {} is found {} in query",
function_node->formatASTForErrorMessage(),
assert_no_aggregates_place_message);
if (aggregate_function_nodes)
aggregate_function_nodes->push_back(node);
}
static bool needChildVisit(const QueryTreeNodePtr &, const QueryTreeNodePtr & child_node)
{
return !(child_node->getNodeType() == QueryTreeNodeType::QUERY || child_node->getNodeType() == QueryTreeNodeType::UNION);
}
private:
String assert_no_aggregates_place_message;
QueryTreeNodes * aggregate_function_nodes = nullptr;
};
}
QueryTreeNodes collectAggregateFunctionNodes(const QueryTreeNodePtr & node)
{
QueryTreeNodes result;
CollectAggregateFunctionNodesVisitor visitor(&result);
visitor.visit(node);
return result;
}
void collectAggregateFunctionNodes(const QueryTreeNodePtr & node, QueryTreeNodes & result)
{
CollectAggregateFunctionNodesVisitor visitor(&result);
visitor.visit(node);
}
void assertNoAggregateFunctionNodes(const QueryTreeNodePtr & node, const String & assert_no_aggregates_place_message)
{
CollectAggregateFunctionNodesVisitor visitor(assert_no_aggregates_place_message);
visitor.visit(node);
}
namespace
{
class ValidateGroupingFunctionNodesVisitor : public ConstInDepthQueryTreeVisitor<ValidateGroupingFunctionNodesVisitor>
{
public:
explicit ValidateGroupingFunctionNodesVisitor(String assert_no_grouping_function_place_message_)
: assert_no_grouping_function_place_message(std::move(assert_no_grouping_function_place_message_))
{}
void visitImpl(const QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
if (function_node && function_node->getFunctionName() == "grouping")
throw Exception(ErrorCodes::ILLEGAL_AGGREGATION,
"GROUPING function {} is found {} in query",
function_node->formatASTForErrorMessage(),
assert_no_grouping_function_place_message);
}
static bool needChildVisit(const QueryTreeNodePtr &, const QueryTreeNodePtr & child_node)
{
return !(child_node->getNodeType() == QueryTreeNodeType::QUERY || child_node->getNodeType() == QueryTreeNodeType::UNION);
}
private:
String assert_no_grouping_function_place_message;
};
}
void assertNoGroupingFunction(const QueryTreeNodePtr & node, const String & assert_no_grouping_function_place_message)
{
ValidateGroupingFunctionNodesVisitor visitor(assert_no_grouping_function_place_message);
visitor.visit(node);
}
}

View File

@ -0,0 +1,28 @@
#pragma once
#include <Analyzer/IQueryTreeNode.h>
namespace DB
{
/** Collect aggregate function nodes in node children.
* Do not visit subqueries.
*/
QueryTreeNodes collectAggregateFunctionNodes(const QueryTreeNodePtr & node);
/** Collect aggregate function nodes in node children and add them into result.
* Do not visit subqueries.
*/
void collectAggregateFunctionNodes(const QueryTreeNodePtr & node, QueryTreeNodes & result);
/** Assert that there are no aggregate function nodes in node children.
* Do not visit subqueries.
*/
void assertNoAggregateFunctionNodes(const QueryTreeNodePtr & node, const String & assert_no_aggregates_place_message);
/** Assert that there are no GROUPING functions in node children.
* Do not visit subqueries.
*/
void assertNoGroupingFunction(const QueryTreeNodePtr & node, const String & assert_no_grouping_function_place_message);
}

View File

@ -0,0 +1,71 @@
#include <Analyzer/ArrayJoinNode.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Analyzer/Utils.h>
namespace DB
{
ArrayJoinNode::ArrayJoinNode(QueryTreeNodePtr table_expression_, QueryTreeNodePtr join_expressions_, bool is_left_)
: IQueryTreeNode(children_size)
, is_left(is_left_)
{
children[table_expression_child_index] = std::move(table_expression_);
children[join_expressions_child_index] = std::move(join_expressions_);
}
void ArrayJoinNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "ARRAY_JOIN id: " << format_state.getNodeId(this);
buffer << ", is_left: " << is_left;
buffer << '\n' << std::string(indent + 2, ' ') << "TABLE EXPRESSION\n";
getTableExpression()->dumpTreeImpl(buffer, format_state, indent + 4);
buffer << '\n' << std::string(indent + 2, ' ') << "JOIN EXPRESSIONS\n";
getJoinExpressionsNode()->dumpTreeImpl(buffer, format_state, indent + 4);
}
bool ArrayJoinNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const ArrayJoinNode &>(rhs);
return is_left == rhs_typed.is_left;
}
void ArrayJoinNode::updateTreeHashImpl(HashState & state) const
{
state.update(is_left);
}
QueryTreeNodePtr ArrayJoinNode::cloneImpl() const
{
return std::make_shared<ArrayJoinNode>(getTableExpression(), getJoinExpressionsNode(), is_left);
}
ASTPtr ArrayJoinNode::toASTImpl() const
{
auto array_join_ast = std::make_shared<ASTArrayJoin>();
array_join_ast->kind = is_left ? ASTArrayJoin::Kind::Left : ASTArrayJoin::Kind::Inner;
const auto & join_expression_list_node = getJoinExpressionsNode();
array_join_ast->children.push_back(join_expression_list_node->toAST());
array_join_ast->expression_list = array_join_ast->children.back();
ASTPtr tables_in_select_query_ast = std::make_shared<ASTTablesInSelectQuery>();
addTableExpressionOrJoinIntoTablesInSelectQuery(tables_in_select_query_ast, children[table_expression_child_index]);
auto array_join_query_element_ast = std::make_shared<ASTTablesInSelectQueryElement>();
array_join_query_element_ast->children.push_back(std::move(array_join_ast));
array_join_query_element_ast->array_join = array_join_query_element_ast->children.back();
tables_in_select_query_ast->children.push_back(std::move(array_join_query_element_ast));
return tables_in_select_query_ast;
}
}

View File

@ -0,0 +1,113 @@
#pragma once
#include <Storages/IStorage_fwd.h>
#include <Storages/TableLockHolder.h>
#include <Storages/StorageSnapshot.h>
#include <Interpreters/Context_fwd.h>
#include <Interpreters/StorageID.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h>
namespace DB
{
/** Array join node represents array join in query tree.
*
* In query tree array join expressions are represented by list query tree node.
*
* Example: SELECT id FROM test_table ARRAY JOIN [1, 2, 3] as a.
*
* Multiple expressions can be inside single array join.
* Example: SELECT id FROM test_table ARRAY JOIN [1, 2, 3] as a, [4, 5, 6] as b.
* Example: SELECT id FROM test_table ARRAY JOIN array_column_1 AS value_1, array_column_2 AS value_2.
*
* Multiple array joins can be inside JOIN TREE.
* Example: SELECT id FROM test_table ARRAY JOIN array_column_1 ARRAY JOIN array_column_2.
*
* Array join can be used inside JOIN TREE with ordinary JOINS.
* Example: SELECT t1.id FROM test_table_1 AS t1 INNER JOIN test_table_2 AS t2 ON t1.id = t2.id ARRAY JOIN [1,2,3];
* Example: SELECT t1.id FROM test_table_1 AS t1 ARRAY JOIN [1,2,3] INNER JOIN test_table_2 AS t2 ON t1.id = t2.id;
*/
class ArrayJoinNode;
using ArrayJoinNodePtr = std::shared_ptr<ArrayJoinNode>;
class ArrayJoinNode final : public IQueryTreeNode
{
public:
/** Construct array join node with table expression.
* Example: SELECT id FROM test_table ARRAY JOIN [1, 2, 3] as a.
* test_table - table expression.
* join_expression_list - list of array join expressions.
*/
ArrayJoinNode(QueryTreeNodePtr table_expression_, QueryTreeNodePtr join_expressions_, bool is_left_);
/// Get table expression
const QueryTreeNodePtr & getTableExpression() const
{
return children[table_expression_child_index];
}
/// Get table expression
QueryTreeNodePtr & getTableExpression()
{
return children[table_expression_child_index];
}
/// Get join expressions
const ListNode & getJoinExpressions() const
{
return children[join_expressions_child_index]->as<const ListNode &>();
}
/// Get join expressions
ListNode & getJoinExpressions()
{
return children[join_expressions_child_index]->as<ListNode &>();
}
/// Get join expressions node
const QueryTreeNodePtr & getJoinExpressionsNode() const
{
return children[join_expressions_child_index];
}
/// Get join expressions node
QueryTreeNodePtr & getJoinExpressionsNode()
{
return children[join_expressions_child_index];
}
/// Returns true if array join is left, false otherwise
bool isLeft() const
{
return is_left;
}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::ARRAY_JOIN;
}
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
bool is_left = false;
static constexpr size_t table_expression_child_index = 0;
static constexpr size_t join_expressions_child_index = 1;
static constexpr size_t children_size = join_expressions_child_index + 1;
};
}

View File

@ -0,0 +1,7 @@
if (ENABLE_TESTS)
add_subdirectory(tests)
endif()
if (ENABLE_EXAMPLES)
add_subdirectory(examples)
endif()

View File

@ -0,0 +1,97 @@
#include <Analyzer/ColumnNode.h>
#include <Common/SipHash.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Parsers/ASTIdentifier.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
ColumnNode::ColumnNode(NameAndTypePair column_, QueryTreeNodePtr expression_node_, QueryTreeNodeWeakPtr column_source_)
: IQueryTreeNode(children_size, weak_pointers_size)
, column(std::move(column_))
{
children[expression_child_index] = std::move(expression_node_);
getSourceWeakPointer() = std::move(column_source_);
}
ColumnNode::ColumnNode(NameAndTypePair column_, QueryTreeNodeWeakPtr column_source_)
: ColumnNode(std::move(column_), nullptr /*expression_node*/, std::move(column_source_))
{
}
QueryTreeNodePtr ColumnNode::getColumnSource() const
{
auto lock = getSourceWeakPointer().lock();
if (!lock)
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Column {} {} query tree node does not have valid source node",
column.name,
column.type->getName());
return lock;
}
QueryTreeNodePtr ColumnNode::getColumnSourceOrNull() const
{
return getSourceWeakPointer().lock();
}
void ColumnNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & state, size_t indent) const
{
buffer << std::string(indent, ' ') << "COLUMN id: " << state.getNodeId(this);
if (hasAlias())
buffer << ", alias: " << getAlias();
buffer << ", column_name: " << column.name << ", result_type: " << column.type->getName();
auto column_source_ptr = getSourceWeakPointer().lock();
if (column_source_ptr)
buffer << ", source_id: " << state.getNodeId(column_source_ptr.get());
const auto & expression = getExpression();
if (expression)
{
buffer << '\n' << std::string(indent + 2, ' ') << "EXPRESSION\n";
expression->dumpTreeImpl(buffer, state, indent + 4);
}
}
bool ColumnNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const ColumnNode &>(rhs);
return column == rhs_typed.column;
}
void ColumnNode::updateTreeHashImpl(HashState & hash_state) const
{
hash_state.update(column.name.size());
hash_state.update(column.name);
const auto & column_type_name = column.type->getName();
hash_state.update(column_type_name.size());
hash_state.update(column_type_name);
}
QueryTreeNodePtr ColumnNode::cloneImpl() const
{
return std::make_shared<ColumnNode>(column, getColumnSource());
}
ASTPtr ColumnNode::toASTImpl() const
{
return std::make_shared<ASTIdentifier>(column.name);
}
}

156
src/Analyzer/ColumnNode.h Normal file
View File

@ -0,0 +1,156 @@
#pragma once
#include <Core/NamesAndTypes.h>
#include <Analyzer/IQueryTreeNode.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
/** Column node represents column in query tree.
* Column node can have weak pointer to its column source.
* Column source can be table expression, lambda, subquery.
*
* For table ALIAS columns. Column node must contain expression.
* For ARRAY JOIN join expression column. Column node must contain expression.
*
* During query analysis pass identifier node is resolved into column. See IdentifierNode.h.
*
* Examples:
* SELECT id FROM test_table. id is identifier that must be resolved to column node during query analysis pass.
* SELECT lambda(x -> x + 1, [1,2,3]). x is identifier inside lambda that must be resolved to column node during query analysis pass.
*
* Column node is initialized with column name, type and column source weak pointer.
* In case of ALIAS column node is initialized with column name, type, alias expression and column source weak pointer.
*/
class ColumnNode;
using ColumnNodePtr = std::shared_ptr<ColumnNode>;
class ColumnNode final : public IQueryTreeNode
{
public:
/// Construct column node with column name, type, column expression and column source weak pointer
ColumnNode(NameAndTypePair column_, QueryTreeNodePtr expression_node_, QueryTreeNodeWeakPtr column_source_);
/// Construct column node with column name, type and column source weak pointer
ColumnNode(NameAndTypePair column_, QueryTreeNodeWeakPtr column_source_);
/// Get column
const NameAndTypePair & getColumn() const
{
return column;
}
/// Get column name
const String & getColumnName() const
{
return column.name;
}
/// Get column type
const DataTypePtr & getColumnType() const
{
return column.type;
}
/// Set column type
void setColumnType(DataTypePtr column_type)
{
column.type = std::move(column_type);
}
/// Returns true if column node has expression, false otherwise
bool hasExpression() const
{
return children[expression_child_index] != nullptr;
}
/// Get column node expression node
const QueryTreeNodePtr & getExpression() const
{
return children[expression_child_index];
}
/// Get column node expression node
QueryTreeNodePtr & getExpression()
{
return children[expression_child_index];
}
/// Get column node expression node, if there are no expression node exception is thrown
QueryTreeNodePtr & getExpressionOrThrow()
{
if (!children[expression_child_index])
throw Exception(ErrorCodes::LOGICAL_ERROR, "Column expression is not initialized");
return children[expression_child_index];
}
/// Set column node expression node
void setExpression(QueryTreeNodePtr expression_value)
{
children[expression_child_index] = std::move(expression_value);
}
/** Get column source.
* If column source is not valid logical exception is thrown.
*/
QueryTreeNodePtr getColumnSource() const;
/** Get column source.
* If column source is not valid null is returned.
*/
QueryTreeNodePtr getColumnSourceOrNull() const;
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::COLUMN;
}
String getName() const override
{
return column.name;
}
DataTypePtr getResultType() const override
{
return column.type;
}
void dumpTreeImpl(WriteBuffer & buffer, FormatState & state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & hash_state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
const QueryTreeNodeWeakPtr & getSourceWeakPointer() const
{
return weak_pointers[source_weak_pointer_index];
}
QueryTreeNodeWeakPtr & getSourceWeakPointer()
{
return weak_pointers[source_weak_pointer_index];
}
NameAndTypePair column;
static constexpr size_t expression_child_index = 0;
static constexpr size_t children_size = expression_child_index + 1;
static constexpr size_t source_weak_pointer_index = 0;
static constexpr size_t weak_pointers_size = source_weak_pointer_index + 1;
};
}

View File

@ -0,0 +1,357 @@
#include <Analyzer/ColumnTransformers.h>
#include <Common/SipHash.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTAsterisk.h>
#include <Parsers/ASTColumnsTransformers.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/LambdaNode.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/// IColumnTransformerNode implementation
const char * toString(ColumnTransfomerType type)
{
switch (type)
{
case ColumnTransfomerType::APPLY: return "APPLY";
case ColumnTransfomerType::EXCEPT: return "EXCEPT";
case ColumnTransfomerType::REPLACE: return "REPLACE";
}
}
IColumnTransformerNode::IColumnTransformerNode(size_t children_size)
: IQueryTreeNode(children_size)
{}
/// ApplyColumnTransformerNode implementation
const char * toString(ApplyColumnTransformerType type)
{
switch (type)
{
case ApplyColumnTransformerType::LAMBDA: return "LAMBDA";
case ApplyColumnTransformerType::FUNCTION: return "FUNCTION";
}
}
ApplyColumnTransformerNode::ApplyColumnTransformerNode(QueryTreeNodePtr expression_node_)
: IColumnTransformerNode(children_size)
{
if (expression_node_->getNodeType() == QueryTreeNodeType::LAMBDA)
apply_transformer_type = ApplyColumnTransformerType::LAMBDA;
else if (expression_node_->getNodeType() == QueryTreeNodeType::FUNCTION)
apply_transformer_type = ApplyColumnTransformerType::FUNCTION;
else
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Apply column transformer expression must be lambda or function. Actual {}",
expression_node_->getNodeTypeName());
children[expression_child_index] = std::move(expression_node_);
}
void ApplyColumnTransformerNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "APPLY COLUMN TRANSFORMER id: " << format_state.getNodeId(this);
buffer << ", apply_transformer_type: " << toString(apply_transformer_type);
buffer << '\n' << std::string(indent + 2, ' ') << "EXPRESSION" << '\n';
const auto & expression_node = getExpressionNode();
expression_node->dumpTreeImpl(buffer, format_state, indent + 4);
}
bool ApplyColumnTransformerNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const ApplyColumnTransformerNode &>(rhs);
return apply_transformer_type == rhs_typed.apply_transformer_type;
}
void ApplyColumnTransformerNode::updateTreeHashImpl(IQueryTreeNode::HashState & hash_state) const
{
hash_state.update(static_cast<size_t>(getTransformerType()));
hash_state.update(static_cast<size_t>(getApplyTransformerType()));
}
QueryTreeNodePtr ApplyColumnTransformerNode::cloneImpl() const
{
return std::make_shared<ApplyColumnTransformerNode>(getExpressionNode());
}
ASTPtr ApplyColumnTransformerNode::toASTImpl() const
{
auto ast_apply_transformer = std::make_shared<ASTColumnsApplyTransformer>();
const auto & expression_node = getExpressionNode();
if (apply_transformer_type == ApplyColumnTransformerType::FUNCTION)
{
auto & function_expression = expression_node->as<FunctionNode &>();
ast_apply_transformer->func_name = function_expression.getFunctionName();
ast_apply_transformer->parameters = function_expression.getParametersNode()->toAST();
}
else
{
auto & lambda_expression = expression_node->as<LambdaNode &>();
if (!lambda_expression.getArgumentNames().empty())
ast_apply_transformer->lambda_arg = lambda_expression.getArgumentNames()[0];
ast_apply_transformer->lambda = lambda_expression.toAST();
}
return ast_apply_transformer;
}
/// ExceptColumnTransformerNode implementation
ExceptColumnTransformerNode::ExceptColumnTransformerNode(Names except_column_names_, bool is_strict_)
: IColumnTransformerNode(children_size)
, except_transformer_type(ExceptColumnTransformerType::COLUMN_LIST)
, except_column_names(std::move(except_column_names_))
, is_strict(is_strict_)
{
}
ExceptColumnTransformerNode::ExceptColumnTransformerNode(std::shared_ptr<re2::RE2> column_matcher_)
: IColumnTransformerNode(children_size)
, except_transformer_type(ExceptColumnTransformerType::REGEXP)
, column_matcher(std::move(column_matcher_))
{
}
bool ExceptColumnTransformerNode::isColumnMatching(const std::string & column_name) const
{
if (column_matcher)
return RE2::PartialMatch(column_name, *column_matcher);
for (const auto & name : except_column_names)
if (column_name == name)
return true;
return false;
}
const char * toString(ExceptColumnTransformerType type)
{
switch (type)
{
case ExceptColumnTransformerType::REGEXP:
return "REGEXP";
case ExceptColumnTransformerType::COLUMN_LIST:
return "COLUMN_LIST";
}
}
void ExceptColumnTransformerNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "EXCEPT COLUMN TRANSFORMER id: " << format_state.getNodeId(this);
buffer << ", except_transformer_type: " << toString(except_transformer_type);
if (column_matcher)
{
buffer << ", pattern: " << column_matcher->pattern();
return;
}
else
{
buffer << ", identifiers: ";
size_t except_column_names_size = except_column_names.size();
for (size_t i = 0; i < except_column_names_size; ++i)
{
buffer << except_column_names[i];
if (i + 1 != except_column_names_size)
buffer << ", ";
}
}
}
bool ExceptColumnTransformerNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const ExceptColumnTransformerNode &>(rhs);
if (except_transformer_type != rhs_typed.except_transformer_type ||
is_strict != rhs_typed.is_strict ||
except_column_names != rhs_typed.except_column_names)
return false;
const auto & rhs_column_matcher = rhs_typed.column_matcher;
if (!column_matcher && !rhs_column_matcher)
return true;
else if (column_matcher && !rhs_column_matcher)
return false;
else if (!column_matcher && rhs_column_matcher)
return false;
return column_matcher->pattern() == rhs_column_matcher->pattern();
}
void ExceptColumnTransformerNode::updateTreeHashImpl(IQueryTreeNode::HashState & hash_state) const
{
hash_state.update(static_cast<size_t>(getTransformerType()));
hash_state.update(static_cast<size_t>(getExceptTransformerType()));
hash_state.update(except_column_names.size());
for (const auto & column_name : except_column_names)
{
hash_state.update(column_name.size());
hash_state.update(column_name);
}
if (column_matcher)
{
const auto & pattern = column_matcher->pattern();
hash_state.update(pattern.size());
hash_state.update(pattern);
}
}
QueryTreeNodePtr ExceptColumnTransformerNode::cloneImpl() const
{
if (except_transformer_type == ExceptColumnTransformerType::REGEXP)
return std::make_shared<ExceptColumnTransformerNode>(column_matcher);
return std::make_shared<ExceptColumnTransformerNode>(except_column_names, is_strict);
}
ASTPtr ExceptColumnTransformerNode::toASTImpl() const
{
auto ast_except_transformer = std::make_shared<ASTColumnsExceptTransformer>();
if (column_matcher)
{
ast_except_transformer->setPattern(column_matcher->pattern());
return ast_except_transformer;
}
ast_except_transformer->children.reserve(except_column_names.size());
for (const auto & name : except_column_names)
ast_except_transformer->children.push_back(std::make_shared<ASTIdentifier>(name));
return ast_except_transformer;
}
/// ReplaceColumnTransformerNode implementation
ReplaceColumnTransformerNode::ReplaceColumnTransformerNode(const std::vector<Replacement> & replacements_, bool is_strict_)
: IColumnTransformerNode(children_size)
, is_strict(is_strict_)
{
children[replacements_child_index] = std::make_shared<ListNode>();
auto & replacement_expressions_nodes = getReplacements().getNodes();
std::unordered_set<std::string> replacement_names_set;
for (const auto & replacement : replacements_)
{
auto [_, inserted] = replacement_names_set.emplace(replacement.column_name);
if (!inserted)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Expressions in column transformer replace should not contain same replacement {} more than once",
replacement.column_name);
replacements_names.push_back(replacement.column_name);
replacement_expressions_nodes.push_back(replacement.expression_node);
}
}
QueryTreeNodePtr ReplaceColumnTransformerNode::findReplacementExpression(const std::string & expression_name)
{
auto it = std::find(replacements_names.begin(), replacements_names.end(), expression_name);
if (it == replacements_names.end())
return {};
size_t replacement_index = it - replacements_names.begin();
auto & replacement_expressions_nodes = getReplacements().getNodes();
return replacement_expressions_nodes[replacement_index];
}
void ReplaceColumnTransformerNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "REPLACE COLUMN TRANSFORMER id: " << format_state.getNodeId(this);
const auto & replacements_nodes = getReplacements().getNodes();
size_t replacements_size = replacements_nodes.size();
buffer << '\n' << std::string(indent + 2, ' ') << "REPLACEMENTS " << replacements_size << '\n';
for (size_t i = 0; i < replacements_size; ++i)
{
const auto & replacement_name = replacements_names[i];
buffer << std::string(indent + 4, ' ') << "REPLACEMENT NAME " << replacement_name;
buffer << " EXPRESSION" << '\n';
const auto & expression_node = replacements_nodes[i];
expression_node->dumpTreeImpl(buffer, format_state, indent + 6);
if (i + 1 != replacements_size)
buffer << '\n';
}
}
bool ReplaceColumnTransformerNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const ReplaceColumnTransformerNode &>(rhs);
return is_strict == rhs_typed.is_strict && replacements_names == rhs_typed.replacements_names;
}
void ReplaceColumnTransformerNode::updateTreeHashImpl(IQueryTreeNode::HashState & hash_state) const
{
hash_state.update(static_cast<size_t>(getTransformerType()));
const auto & replacement_expressions_nodes = getReplacements().getNodes();
size_t replacements_size = replacement_expressions_nodes.size();
hash_state.update(replacements_size);
for (size_t i = 0; i < replacements_size; ++i)
{
const auto & replacement_name = replacements_names[i];
hash_state.update(replacement_name.size());
hash_state.update(replacement_name);
}
}
QueryTreeNodePtr ReplaceColumnTransformerNode::cloneImpl() const
{
auto result_replace_transformer = std::make_shared<ReplaceColumnTransformerNode>(std::vector<Replacement>{}, false);
result_replace_transformer->is_strict = is_strict;
result_replace_transformer->replacements_names = replacements_names;
return result_replace_transformer;
}
ASTPtr ReplaceColumnTransformerNode::toASTImpl() const
{
auto ast_replace_transformer = std::make_shared<ASTColumnsReplaceTransformer>();
const auto & replacement_expressions_nodes = getReplacements().getNodes();
size_t replacements_size = replacement_expressions_nodes.size();
ast_replace_transformer->children.reserve(replacements_size);
for (size_t i = 0; i < replacements_size; ++i)
{
auto replacement_ast = std::make_shared<ASTColumnsReplaceTransformer::Replacement>();
replacement_ast->name = replacements_names[i];
replacement_ast->expr = replacement_expressions_nodes[i]->toAST();
ast_replace_transformer->children.push_back(replacement_ast);
}
return ast_replace_transformer;
}
}

View File

@ -0,0 +1,316 @@
#pragma once
#include <re2/re2.h>
#include <Analyzer/Identifier.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h>
namespace DB
{
/** Transformers are query tree nodes that handle additional logic that you can apply after MatcherQueryTreeNode is resolved.
* Check MatcherQueryTreeNode.h before reading this documentation.
*
* They main purpose is to apply some logic for expressions after matcher is resolved.
* There are 3 types of transformers:
*
* 1. APPLY transformer:
* APPLY transformer transform matched expression using lambda or function into another expression.
* It has 2 syntax variants:
* 1. lambda variant: SELECT matcher APPLY (x -> expr(x)).
* 2. function variant: SELECT matcher APPLY function_name(optional_parameters).
*
* 2. EXCEPT transformer:
* EXCEPT transformer discard some columns.
* It has 2 syntax variants:
* 1. regexp variant: SELECT matcher EXCEPT ('regexp').
* 2. column names list variant: SELECT matcher EXCEPT (column_name_1, ...).
*
* 3. REPLACE transformer:
* REPLACE transformer applies similar transformation as APPLY transformer, but only for expressions
* that match replacement expression name.
*
* Example:
* CREATE TABLE test_table (id UInt64) ENGINE=TinyLog;
* SELECT * REPLACE (id + 1 AS id) FROM test_table.
* This query is transformed into SELECT id + 1 FROM test_table.
* It is important that AS id is not alias, it is replacement name. id + 1 is replacement expression.
*
* REPLACE transformer cannot contain multiple replacements with same name.
*
* REPLACE transformer expression does not necessary include replacement column name.
* Example:
* SELECT * REPLACE (1 AS id) FROM test_table.
*
* REPLACE transformer expression does not throw exception if there are no columns to apply replacement.
* Example:
* SELECT * REPLACE (1 AS unknown_column) FROM test_table;
*
* REPLACE transform can contain multiple replacements.
* Example:
* SELECT * REPLACE (1 AS id, 2 AS value).
*
* Matchers can be combined together and chained.
* Example:
* SELECT * EXCEPT (id) APPLY (x -> toString(x)) APPLY (x -> length(x)) FROM test_table.
*/
/// Column transformer type
enum class ColumnTransfomerType
{
APPLY,
EXCEPT,
REPLACE
};
/// Get column transformer type name
const char * toString(ColumnTransfomerType type);
class IColumnTransformerNode;
using ColumnTransformerNodePtr = std::shared_ptr<IColumnTransformerNode>;
using ColumnTransformersNodes = std::vector<ColumnTransformerNodePtr>;
/// IColumnTransformer base interface.
class IColumnTransformerNode : public IQueryTreeNode
{
public:
/// Get transformer type
virtual ColumnTransfomerType getTransformerType() const = 0;
/// Get transformer type name
const char * getTransformerTypeName() const
{
return toString(getTransformerType());
}
QueryTreeNodeType getNodeType() const final
{
return QueryTreeNodeType::TRANSFORMER;
}
protected:
/// Construct column transformer node and resize children to children size
explicit IColumnTransformerNode(size_t children_size);
};
enum class ApplyColumnTransformerType
{
LAMBDA,
FUNCTION
};
/// Get apply column transformer type name
const char * toString(ApplyColumnTransformerType type);
class ApplyColumnTransformerNode;
using ApplyColumnTransformerNodePtr = std::shared_ptr<ApplyColumnTransformerNode>;
/// Apply column transformer
class ApplyColumnTransformerNode final : public IColumnTransformerNode
{
public:
/** Initialize apply column transformer with expression node.
* Expression node must be lambda or function otherwise exception is thrown.
*/
explicit ApplyColumnTransformerNode(QueryTreeNodePtr expression_node_);
/// Get apply transformer type
ApplyColumnTransformerType getApplyTransformerType() const
{
return apply_transformer_type;
}
/// Get apply transformer expression node
const QueryTreeNodePtr & getExpressionNode() const
{
return children[expression_child_index];
}
ColumnTransfomerType getTransformerType() const override
{
return ColumnTransfomerType::APPLY;
}
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(IQueryTreeNode::HashState & hash_state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
ApplyColumnTransformerType apply_transformer_type = ApplyColumnTransformerType::LAMBDA;
static constexpr size_t expression_child_index = 0;
static constexpr size_t children_size = expression_child_index + 1;
};
/// Except column transformer type
enum class ExceptColumnTransformerType
{
REGEXP,
COLUMN_LIST,
};
const char * toString(ExceptColumnTransformerType type);
class ExceptColumnTransformerNode;
using ExceptColumnTransformerNodePtr = std::shared_ptr<ExceptColumnTransformerNode>;
/** Except column transformer.
* Strict EXCEPT column transformer must use all column names during matched nodes transformation.
*
* Example:
* CREATE TABLE test_table (id UInt64, value String) ENGINE=TinyLog;
* SELECT * EXCEPT STRICT (id, value1) FROM test_table;
* Such query will throw exception because column with name `value1` was not matched by strict EXCEPT transformer.
*
* Strict is valid only for EXCEPT COLUMN_LIST transformer.
*/
class ExceptColumnTransformerNode final : public IColumnTransformerNode
{
public:
/// Initialize except column transformer with column names
explicit ExceptColumnTransformerNode(Names except_column_names_, bool is_strict_);
/// Initialize except column transformer with regexp column matcher
explicit ExceptColumnTransformerNode(std::shared_ptr<re2::RE2> column_matcher_);
/// Get except transformer type
ExceptColumnTransformerType getExceptTransformerType() const
{
return except_transformer_type;
}
/** Returns true if except column transformer is strict, false otherwise.
* Valid only for EXCEPT COLUMN_LIST transformer.
*/
bool isStrict() const
{
return is_strict;
}
/// Returns true if except transformer match column name, false otherwise.
bool isColumnMatching(const std::string & column_name) const;
/** Get except column names.
* Valid only for column list except transformer.
*/
const Names & getExceptColumnNames() const
{
return except_column_names;
}
ColumnTransfomerType getTransformerType() const override
{
return ColumnTransfomerType::EXCEPT;
}
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(IQueryTreeNode::HashState & hash_state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
ExceptColumnTransformerType except_transformer_type;
Names except_column_names;
std::shared_ptr<re2::RE2> column_matcher;
bool is_strict = false;
static constexpr size_t children_size = 0;
};
class ReplaceColumnTransformerNode;
using ReplaceColumnTransformerNodePtr = std::shared_ptr<ReplaceColumnTransformerNode>;
/** Replace column transformer.
* Strict replace column transformer must use all replacements during matched nodes transformation.
*
* Example:
* CREATE TABLE test_table (id UInt64, value String) ENGINE=TinyLog;
* SELECT * REPLACE STRICT (1 AS id, 2 AS value_1) FROM test_table;
* Such query will throw exception because column with name `value1` was not matched by strict REPLACE transformer.
*/
class ReplaceColumnTransformerNode final : public IColumnTransformerNode
{
public:
/// Replacement is column name and replace expression
struct Replacement
{
std::string column_name;
QueryTreeNodePtr expression_node;
};
/// Initialize replace column transformer with replacements
explicit ReplaceColumnTransformerNode(const std::vector<Replacement> & replacements_, bool is_strict);
ColumnTransfomerType getTransformerType() const override
{
return ColumnTransfomerType::REPLACE;
}
/// Get replacements
const ListNode & getReplacements() const
{
return children[replacements_child_index]->as<ListNode &>();
}
/// Get replacements node
const QueryTreeNodePtr & getReplacementsNode() const
{
return children[replacements_child_index];
}
/// Get replacements names
const Names & getReplacementsNames() const
{
return replacements_names;
}
/// Returns true if replace column transformer is strict, false otherwise
bool isStrict() const
{
return is_strict;
}
/** Returns replacement expression if replacement is registered for expression name, null otherwise.
* Returned replacement expression must be cloned by caller.
*/
QueryTreeNodePtr findReplacementExpression(const std::string & expression_name);
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(IQueryTreeNode::HashState & hash_state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
ListNode & getReplacements()
{
return children[replacements_child_index]->as<ListNode &>();
}
Names replacements_names;
bool is_strict = false;
static constexpr size_t replacements_child_index = 0;
static constexpr size_t children_size = replacements_child_index + 1;
};
}

View File

@ -0,0 +1,71 @@
#include <Analyzer/ConstantNode.h>
#include <Common/FieldVisitorToString.h>
#include <Common/SipHash.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <DataTypes/FieldToDataType.h>
#include <Parsers/ASTLiteral.h>
#include <Interpreters/convertFieldToType.h>
namespace DB
{
ConstantNode::ConstantNode(ConstantValuePtr constant_value_)
: IQueryTreeNode(children_size)
, constant_value(std::move(constant_value_))
, value_string(applyVisitor(FieldVisitorToString(), constant_value->getValue()))
{
}
ConstantNode::ConstantNode(Field value_, DataTypePtr value_data_type_)
: ConstantNode(std::make_shared<ConstantValue>(convertFieldToTypeOrThrow(value_, *value_data_type_), value_data_type_))
{}
ConstantNode::ConstantNode(Field value_)
: ConstantNode(value_, applyVisitor(FieldToDataType(), value_))
{}
void ConstantNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "CONSTANT id: " << format_state.getNodeId(this);
if (hasAlias())
buffer << ", alias: " << getAlias();
buffer << ", constant_value: " << constant_value->getValue().dump();
buffer << ", constant_value_type: " << constant_value->getType()->getName();
}
bool ConstantNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const ConstantNode &>(rhs);
return *constant_value == *rhs_typed.constant_value && value_string == rhs_typed.value_string;
}
void ConstantNode::updateTreeHashImpl(HashState & hash_state) const
{
auto type_name = constant_value->getType()->getName();
hash_state.update(type_name.size());
hash_state.update(type_name);
hash_state.update(value_string.size());
hash_state.update(value_string);
}
QueryTreeNodePtr ConstantNode::cloneImpl() const
{
return std::make_shared<ConstantNode>(constant_value);
}
ASTPtr ConstantNode::toASTImpl() const
{
return std::make_shared<ASTLiteral>(constant_value->getValue());
}
}

View File

@ -0,0 +1,82 @@
#pragma once
#include <Core/Field.h>
#include <Analyzer/IQueryTreeNode.h>
namespace DB
{
/** Constant node represents constant value in query tree.
* Constant value must be representable by Field.
* Examples: 1, 'constant_string', [1,2,3].
*/
class ConstantNode;
using ConstantNodePtr = std::shared_ptr<ConstantNode>;
class ConstantNode final : public IQueryTreeNode
{
public:
/// Construct constant query tree node from constant value
explicit ConstantNode(ConstantValuePtr constant_value_);
/** Construct constant query tree node from field and data type.
*
* Throws exception if value cannot be converted to value data type.
*/
explicit ConstantNode(Field value_, DataTypePtr value_data_type_);
/// Construct constant query tree node from field, data type will be derived from field value
explicit ConstantNode(Field value_);
/// Get constant value
const Field & getValue() const
{
return constant_value->getValue();
}
/// Get constant value string representation
const String & getValueStringRepresentation() const
{
return value_string;
}
ConstantValuePtr getConstantValueOrNull() const override
{
return constant_value;
}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::CONSTANT;
}
String getName() const override
{
return value_string;
}
DataTypePtr getResultType() const override
{
return constant_value->getType();
}
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & hash_state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
ConstantValuePtr constant_value;
String value_string;
static constexpr size_t children_size = 0;
};
}

View File

@ -0,0 +1,47 @@
#pragma once
#include <Core/Field.h>
#include <DataTypes/IDataType.h>
namespace DB
{
/** Immutable constant value representation during analysis stage.
* Some query nodes can be represented by constant (scalar subqueries, functions with constant arguments).
*/
class ConstantValue;
using ConstantValuePtr = std::shared_ptr<ConstantValue>;
class ConstantValue
{
public:
ConstantValue(Field value_, DataTypePtr data_type_)
: value(std::move(value_))
, data_type(std::move(data_type_))
{}
const Field & getValue() const
{
return value;
}
const DataTypePtr & getType() const
{
return data_type;
}
private:
Field value;
DataTypePtr data_type;
};
inline bool operator==(const ConstantValue & lhs, const ConstantValue & rhs)
{
return lhs.getValue() == rhs.getValue() && lhs.getType()->equals(*rhs.getType());
}
inline bool operator!=(const ConstantValue & lhs, const ConstantValue & rhs)
{
return !(lhs == rhs);
}
}

View File

@ -0,0 +1,215 @@
#include <Analyzer/FunctionNode.h>
#include <Common/SipHash.h>
#include <Common/FieldVisitorToString.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Parsers/ASTFunction.h>
#include <Functions/IFunction.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Analyzer/IdentifierNode.h>
namespace DB
{
FunctionNode::FunctionNode(String function_name_)
: IQueryTreeNode(children_size)
, function_name(function_name_)
{
children[parameters_child_index] = std::make_shared<ListNode>();
children[arguments_child_index] = std::make_shared<ListNode>();
}
void FunctionNode::resolveAsFunction(FunctionOverloadResolverPtr function_value, DataTypePtr result_type_value)
{
aggregate_function = nullptr;
function = std::move(function_value);
result_type = std::move(result_type_value);
function_name = function->getName();
}
void FunctionNode::resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value, DataTypePtr result_type_value)
{
function = nullptr;
aggregate_function = std::move(aggregate_function_value);
result_type = std::move(result_type_value);
function_name = aggregate_function->getName();
}
void FunctionNode::resolveAsWindowFunction(AggregateFunctionPtr window_function_value, DataTypePtr result_type_value)
{
resolveAsAggregateFunction(window_function_value, result_type_value);
}
void FunctionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "FUNCTION id: " << format_state.getNodeId(this);
if (hasAlias())
buffer << ", alias: " << getAlias();
buffer << ", function_name: " << function_name;
std::string function_type = "ordinary";
if (isAggregateFunction())
function_type = "aggregate";
else if (isWindowFunction())
function_type = "window";
buffer << ", function_type: " << function_type;
if (result_type)
buffer << ", result_type: " + result_type->getName();
if (constant_value)
{
buffer << ", constant_value: " << constant_value->getValue().dump();
buffer << ", constant_value_type: " << constant_value->getType()->getName();
}
const auto & parameters = getParameters();
if (!parameters.getNodes().empty())
{
buffer << '\n' << std::string(indent + 2, ' ') << "PARAMETERS\n";
parameters.dumpTreeImpl(buffer, format_state, indent + 4);
}
const auto & arguments = getArguments();
if (!arguments.getNodes().empty())
{
buffer << '\n' << std::string(indent + 2, ' ') << "ARGUMENTS\n";
arguments.dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasWindow())
{
buffer << '\n' << std::string(indent + 2, ' ') << "WINDOW\n";
getWindowNode()->dumpTreeImpl(buffer, format_state, indent + 4);
}
}
String FunctionNode::getName() const
{
String name = function_name;
const auto & parameters = getParameters();
const auto & parameters_nodes = parameters.getNodes();
if (!parameters_nodes.empty())
{
name += '(';
name += parameters.getName();
name += ')';
}
const auto & arguments = getArguments();
name += '(';
name += arguments.getName();
name += ')';
return name;
}
bool FunctionNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const FunctionNode &>(rhs);
if (function_name != rhs_typed.function_name ||
isAggregateFunction() != rhs_typed.isAggregateFunction() ||
isOrdinaryFunction() != rhs_typed.isOrdinaryFunction() ||
isWindowFunction() != rhs_typed.isWindowFunction())
return false;
if (result_type && rhs_typed.result_type && !result_type->equals(*rhs_typed.getResultType()))
return false;
else if (result_type && !rhs_typed.result_type)
return false;
else if (!result_type && rhs_typed.result_type)
return false;
if (constant_value && rhs_typed.constant_value && *constant_value != *rhs_typed.constant_value)
return false;
else if (constant_value && !rhs_typed.constant_value)
return false;
else if (!constant_value && rhs_typed.constant_value)
return false;
return true;
}
void FunctionNode::updateTreeHashImpl(HashState & hash_state) const
{
hash_state.update(function_name.size());
hash_state.update(function_name);
hash_state.update(isOrdinaryFunction());
hash_state.update(isAggregateFunction());
hash_state.update(isWindowFunction());
if (result_type)
{
auto result_type_name = result_type->getName();
hash_state.update(result_type_name.size());
hash_state.update(result_type_name);
}
if (constant_value)
{
auto constant_dump = applyVisitor(FieldVisitorToString(), constant_value->getValue());
hash_state.update(constant_dump.size());
hash_state.update(constant_dump);
auto constant_value_type_name = constant_value->getType()->getName();
hash_state.update(constant_value_type_name.size());
hash_state.update(constant_value_type_name);
}
}
QueryTreeNodePtr FunctionNode::cloneImpl() const
{
auto result_function = std::make_shared<FunctionNode>(function_name);
/** This is valid for clone method to reuse same function pointers
* because ordinary functions or aggregate functions must be stateless.
*/
result_function->function = function;
result_function->aggregate_function = aggregate_function;
result_function->result_type = result_type;
result_function->constant_value = constant_value;
return result_function;
}
ASTPtr FunctionNode::toASTImpl() const
{
auto function_ast = std::make_shared<ASTFunction>();
function_ast->name = function_name;
function_ast->is_window_function = isWindowFunction();
const auto & parameters = getParameters();
if (!parameters.getNodes().empty())
{
function_ast->children.push_back(parameters.toAST());
function_ast->parameters = function_ast->children.back();
}
const auto & arguments = getArguments();
function_ast->children.push_back(arguments.toAST());
function_ast->arguments = function_ast->children.back();
auto window_node = getWindowNode();
if (window_node)
{
if (auto * identifier_node = window_node->as<IdentifierNode>())
function_ast->window_name = identifier_node->getIdentifier().getFullName();
else
function_ast->window_definition = window_node->toAST();
}
return function_ast;
}
}

232
src/Analyzer/FunctionNode.h Normal file
View File

@ -0,0 +1,232 @@
#pragma once
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h>
#include <Analyzer/ConstantValue.h>
namespace DB
{
class IFunctionOverloadResolver;
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;
class IAggregateFunction;
using AggregateFunctionPtr = std::shared_ptr<const IAggregateFunction>;
/** Function node represents function in query tree.
* Function syntax: function_name(parameter_1, ...)(argument_1, ...).
* If function does not have parameters its syntax is function_name(argument_1, ...).
* If function does not have arguments its syntax is function_name().
*
* In query tree function parameters and arguments are represented by ListNode.
*
* Function can be:
* 1. Aggregate function. Example: quantile(0.5)(x), sum(x).
* 2. Non aggregate function. Example: plus(x, x).
* 3. Window function. Example: sum(x) OVER (PARTITION BY expr ORDER BY expr).
*
* Initially function node is initialized with function name.
* For window function client must initialize function window node.
*
* During query analysis pass function must be resolved using `resolveAsFunction`, `resolveAsAggregateFunction`, `resolveAsWindowFunction` methods.
* Resolved function is function that has result type and is initialized with concrete aggregate or non aggregate function.
*/
class FunctionNode;
using FunctionNodePtr = std::shared_ptr<FunctionNode>;
class FunctionNode final : public IQueryTreeNode
{
public:
/** Initialize function node with function name.
* Later during query analysis pass function must be resolved.
*/
explicit FunctionNode(String function_name_);
/// Get function name
const String & getFunctionName() const
{
return function_name;
}
/// Get parameters
const ListNode & getParameters() const
{
return children[parameters_child_index]->as<const ListNode &>();
}
/// Get parameters
ListNode & getParameters()
{
return children[parameters_child_index]->as<ListNode &>();
}
/// Get parameters node
const QueryTreeNodePtr & getParametersNode() const
{
return children[parameters_child_index];
}
/// Get parameters node
QueryTreeNodePtr & getParametersNode()
{
return children[parameters_child_index];
}
/// Get arguments
const ListNode & getArguments() const
{
return children[arguments_child_index]->as<const ListNode &>();
}
/// Get arguments
ListNode & getArguments()
{
return children[arguments_child_index]->as<ListNode &>();
}
/// Get arguments node
const QueryTreeNodePtr & getArgumentsNode() const
{
return children[arguments_child_index];
}
/// Get arguments node
QueryTreeNodePtr & getArgumentsNode()
{
return children[arguments_child_index];
}
/// Returns true if function node has window, false otherwise
bool hasWindow() const
{
return children[window_child_index] != nullptr;
}
/** Get window node.
* Valid only for window function node.
* Result window node can be identifier node or window node.
* 1. It can be identifier node if window function is defined as expr OVER window_name.
* 2. It can be window node if window function is defined as expr OVER (window_name ...).
*/
const QueryTreeNodePtr & getWindowNode() const
{
return children[window_child_index];
}
/** Get window node.
* Valid only for window function node.
*/
QueryTreeNodePtr & getWindowNode()
{
return children[window_child_index];
}
/** Get non aggregate function.
* If function is not resolved nullptr returned.
*/
const FunctionOverloadResolverPtr & getFunction() const
{
return function;
}
/** Get aggregate function.
* If function is not resolved nullptr returned.
* If function is resolved as non aggregate function nullptr returned.
*/
const AggregateFunctionPtr & getAggregateFunction() const
{
return aggregate_function;
}
/// Is function node resolved
bool isResolved() const
{
return result_type != nullptr && (function != nullptr || aggregate_function != nullptr);
}
/// Is function node window function
bool isWindowFunction() const
{
return getWindowNode() != nullptr;
}
/// Is function node aggregate function
bool isAggregateFunction() const
{
return aggregate_function != nullptr && !isWindowFunction();
}
/// Is function node ordinary function
bool isOrdinaryFunction() const
{
return function != nullptr;
}
/** Resolve function node as non aggregate function.
* It is important that function name is updated with resolved function name.
* Main motivation for this is query tree optimizations.
* Assume we have `multiIf` function with single condition, it can be converted to `if` function.
* Function name must be updated accordingly.
*/
void resolveAsFunction(FunctionOverloadResolverPtr function_value, DataTypePtr result_type_value);
/** Resolve function node as aggregate function.
* It is important that function name is updated with resolved function name.
* Main motivation for this is query tree optimizations.
*/
void resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value, DataTypePtr result_type_value);
/** Resolve function node as window function.
* It is important that function name is updated with resolved function name.
* Main motivation for this is query tree optimizations.
*/
void resolveAsWindowFunction(AggregateFunctionPtr window_function_value, DataTypePtr result_type_value);
/// Perform constant folding for function node
void performConstantFolding(ConstantValuePtr constant_folded_value)
{
constant_value = std::move(constant_folded_value);
}
ConstantValuePtr getConstantValueOrNull() const override
{
return constant_value;
}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::FUNCTION;
}
DataTypePtr getResultType() const override
{
return result_type;
}
String getName() const override;
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & hash_state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
String function_name;
FunctionOverloadResolverPtr function;
AggregateFunctionPtr aggregate_function;
DataTypePtr result_type;
ConstantValuePtr constant_value;
static constexpr size_t parameters_child_index = 0;
static constexpr size_t arguments_child_index = 1;
static constexpr size_t window_child_index = 2;
static constexpr size_t children_size = window_child_index + 1;
};
}

View File

@ -0,0 +1,332 @@
#include <Analyzer/IQueryTreeNode.h>
#include <unordered_map>
#include <Common/SipHash.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Parsers/ASTWithAlias.h>
namespace DB
{
namespace ErrorCodes
{
extern const int UNSUPPORTED_METHOD;
}
const char * toString(QueryTreeNodeType type)
{
switch (type)
{
case QueryTreeNodeType::IDENTIFIER: return "IDENTIFIER";
case QueryTreeNodeType::MATCHER: return "MATCHER";
case QueryTreeNodeType::TRANSFORMER: return "TRANSFORMER";
case QueryTreeNodeType::LIST: return "LIST";
case QueryTreeNodeType::CONSTANT: return "CONSTANT";
case QueryTreeNodeType::FUNCTION: return "FUNCTION";
case QueryTreeNodeType::COLUMN: return "COLUMN";
case QueryTreeNodeType::LAMBDA: return "LAMBDA";
case QueryTreeNodeType::SORT: return "SORT";
case QueryTreeNodeType::INTERPOLATE: return "INTERPOLATE";
case QueryTreeNodeType::WINDOW: return "WINDOW";
case QueryTreeNodeType::TABLE: return "TABLE";
case QueryTreeNodeType::TABLE_FUNCTION: return "TABLE_FUNCTION";
case QueryTreeNodeType::QUERY: return "QUERY";
case QueryTreeNodeType::ARRAY_JOIN: return "ARRAY_JOIN";
case QueryTreeNodeType::JOIN: return "JOIN";
case QueryTreeNodeType::UNION: return "UNION";
}
}
IQueryTreeNode::IQueryTreeNode(size_t children_size, size_t weak_pointers_size)
{
children.resize(children_size);
weak_pointers.resize(weak_pointers_size);
}
IQueryTreeNode::IQueryTreeNode(size_t children_size)
{
children.resize(children_size);
}
namespace
{
using NodePair = std::pair<const IQueryTreeNode *, const IQueryTreeNode *>;
struct NodePairHash
{
size_t operator()(const NodePair & node_pair) const
{
auto hash = std::hash<const IQueryTreeNode *>();
size_t result = 0;
boost::hash_combine(result, hash(node_pair.first));
boost::hash_combine(result, hash(node_pair.second));
return result;
}
};
}
bool IQueryTreeNode::isEqual(const IQueryTreeNode & rhs) const
{
std::vector<NodePair> nodes_to_process;
std::unordered_set<NodePair, NodePairHash> equals_pairs;
nodes_to_process.emplace_back(this, &rhs);
while (!nodes_to_process.empty())
{
auto nodes_to_compare = nodes_to_process.back();
nodes_to_process.pop_back();
const auto * lhs_node_to_compare = nodes_to_compare.first;
const auto * rhs_node_to_compare = nodes_to_compare.second;
if (equals_pairs.contains(std::make_pair(lhs_node_to_compare, rhs_node_to_compare)))
continue;
assert(lhs_node_to_compare);
assert(rhs_node_to_compare);
if (lhs_node_to_compare->getNodeType() != rhs_node_to_compare->getNodeType() ||
lhs_node_to_compare->alias != rhs_node_to_compare->alias ||
!lhs_node_to_compare->isEqualImpl(*rhs_node_to_compare))
{
return false;
}
const auto & lhs_children = lhs_node_to_compare->children;
const auto & rhs_children = rhs_node_to_compare->children;
size_t lhs_children_size = lhs_children.size();
if (lhs_children_size != rhs_children.size())
return false;
for (size_t i = 0; i < lhs_children_size; ++i)
{
const auto & lhs_child = lhs_children[i];
const auto & rhs_child = rhs_children[i];
if (!lhs_child && !rhs_child)
continue;
else if (lhs_child && !rhs_child)
return false;
else if (!lhs_child && rhs_child)
return false;
nodes_to_process.emplace_back(lhs_child.get(), rhs_child.get());
}
const auto & lhs_weak_pointers = lhs_node_to_compare->weak_pointers;
const auto & rhs_weak_pointers = rhs_node_to_compare->weak_pointers;
size_t lhs_weak_pointers_size = lhs_weak_pointers.size();
if (lhs_weak_pointers_size != rhs_weak_pointers.size())
return false;
for (size_t i = 0; i < lhs_weak_pointers_size; ++i)
{
auto lhs_strong_pointer = lhs_weak_pointers[i].lock();
auto rhs_strong_pointer = rhs_weak_pointers[i].lock();
if (!lhs_strong_pointer && !rhs_strong_pointer)
continue;
else if (lhs_strong_pointer && !rhs_strong_pointer)
return false;
else if (!lhs_strong_pointer && rhs_strong_pointer)
return false;
nodes_to_process.emplace_back(lhs_strong_pointer.get(), rhs_strong_pointer.get());
}
equals_pairs.emplace(lhs_node_to_compare, rhs_node_to_compare);
}
return true;
}
IQueryTreeNode::Hash IQueryTreeNode::getTreeHash() const
{
HashState hash_state;
std::unordered_map<const IQueryTreeNode *, size_t> node_to_identifier;
std::vector<const IQueryTreeNode *> nodes_to_process;
nodes_to_process.push_back(this);
while (!nodes_to_process.empty())
{
const auto * node_to_process = nodes_to_process.back();
nodes_to_process.pop_back();
auto node_identifier_it = node_to_identifier.find(node_to_process);
if (node_identifier_it != node_to_identifier.end())
{
hash_state.update(node_identifier_it->second);
continue;
}
node_to_identifier.emplace(node_to_process, node_to_identifier.size());
hash_state.update(static_cast<size_t>(node_to_process->getNodeType()));
if (!node_to_process->alias.empty())
{
hash_state.update(node_to_process->alias.size());
hash_state.update(node_to_process->alias);
}
node_to_process->updateTreeHashImpl(hash_state);
hash_state.update(node_to_process->children.size());
for (const auto & node_to_process_child : node_to_process->children)
{
if (!node_to_process_child)
continue;
nodes_to_process.push_back(node_to_process_child.get());
}
hash_state.update(node_to_process->weak_pointers.size());
for (const auto & weak_pointer : node_to_process->weak_pointers)
{
auto strong_pointer = weak_pointer.lock();
if (!strong_pointer)
continue;
nodes_to_process.push_back(strong_pointer.get());
}
}
Hash result;
hash_state.get128(result);
return result;
}
QueryTreeNodePtr IQueryTreeNode::clone() const
{
/** Clone tree with this node as root.
*
* Algorithm
* For each node we clone state and also create mapping old pointer to new pointer.
* For each cloned node we update weak pointers array.
*
* After that we can update pointer in weak pointers array using old pointer to new pointer mapping.
*/
std::unordered_map<const IQueryTreeNode *, QueryTreeNodePtr> old_pointer_to_new_pointer;
std::vector<QueryTreeNodeWeakPtr *> weak_pointers_to_update_after_clone;
QueryTreeNodePtr result_cloned_node_place;
std::vector<std::pair<const IQueryTreeNode *, QueryTreeNodePtr *>> nodes_to_clone;
nodes_to_clone.emplace_back(this, &result_cloned_node_place);
while (!nodes_to_clone.empty())
{
const auto [node_to_clone, place_for_cloned_node] = nodes_to_clone.back();
nodes_to_clone.pop_back();
auto node_clone = node_to_clone->cloneImpl();
*place_for_cloned_node = node_clone;
node_clone->setAlias(node_to_clone->alias);
node_clone->setOriginalAST(node_to_clone->original_ast);
node_clone->children = node_to_clone->children;
node_clone->weak_pointers = node_to_clone->weak_pointers;
old_pointer_to_new_pointer.emplace(node_to_clone, node_clone);
for (auto & child : node_clone->children)
{
if (!child)
continue;
nodes_to_clone.emplace_back(child.get(), &child);
}
for (auto & weak_pointer : node_clone->weak_pointers)
{
weak_pointers_to_update_after_clone.push_back(&weak_pointer);
}
}
/** Update weak pointers to new pointers if they were changed during clone.
* To do this we check old pointer to new pointer map, if weak pointer
* strong pointer exists as old pointer in map, reinitialize weak pointer with new pointer.
*/
for (auto & weak_pointer_ptr : weak_pointers_to_update_after_clone)
{
assert(weak_pointer_ptr);
auto strong_pointer = weak_pointer_ptr->lock();
auto it = old_pointer_to_new_pointer.find(strong_pointer.get());
/** If node had weak pointer to some other node and this node is not part of cloned subtree do not update weak pointer.
* It will continue to point to previous location and it is expected.
*
* Example: SELECT id FROM test_table;
* During analysis `id` is resolved as column node and `test_table` is column source.
* If we clone `id` column, result column node weak source pointer will point to the same `test_table` column source.
*/
if (it == old_pointer_to_new_pointer.end())
continue;
*weak_pointer_ptr = it->second;
}
return result_cloned_node_place;
}
ASTPtr IQueryTreeNode::toAST() const
{
auto converted_node = toASTImpl();
if (auto * ast_with_alias = dynamic_cast<ASTWithAlias *>(converted_node.get()))
converted_node->setAlias(alias);
return converted_node;
}
String IQueryTreeNode::formatOriginalASTForErrorMessage() const
{
if (!original_ast)
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Original AST was not set");
return original_ast->formatForErrorMessage();
}
String IQueryTreeNode::formatConvertedASTForErrorMessage() const
{
return toAST()->formatForErrorMessage();
}
String IQueryTreeNode::dumpTree() const
{
WriteBufferFromOwnString buffer;
dumpTree(buffer);
return buffer.str();
}
size_t IQueryTreeNode::FormatState::getNodeId(const IQueryTreeNode * node)
{
auto [it, _] = node_to_id.emplace(node, node_to_id.size());
return it->second;
}
void IQueryTreeNode::dumpTree(WriteBuffer & buffer) const
{
FormatState state;
dumpTreeImpl(buffer, state, 0);
}
}

View File

@ -0,0 +1,282 @@
#pragma once
#include <memory>
#include <string>
#include <vector>
#include <Common/TypePromotion.h>
#include <DataTypes/IDataType.h>
#include <Parsers/IAST_fwd.h>
#include <Analyzer/Identifier.h>
#include <Analyzer/ConstantValue.h>
class SipHash;
namespace DB
{
namespace ErrorCodes
{
extern const int UNSUPPORTED_METHOD;
extern const int LOGICAL_ERROR;
}
class WriteBuffer;
/// Query tree node type
enum class QueryTreeNodeType
{
IDENTIFIER,
MATCHER,
TRANSFORMER,
LIST,
CONSTANT,
FUNCTION,
COLUMN,
LAMBDA,
SORT,
INTERPOLATE,
WINDOW,
TABLE,
TABLE_FUNCTION,
QUERY,
ARRAY_JOIN,
JOIN,
UNION
};
/// Convert query tree node type to string
const char * toString(QueryTreeNodeType type);
/** Query tree is semantical representation of query.
* Query tree node represent node in query tree.
* IQueryTreeNode is base class for all query tree nodes.
*
* Important property of query tree is that each query tree node can contain weak pointers to other
* query tree nodes. Keeping weak pointer to other query tree nodes can be useful for example for column
* to keep weak pointer to column source, column source can be table, lambda, subquery and preserving of
* such information can significantly simplify query planning.
*
* Another important property of query tree it must be convertible to AST without losing information.
*/
class IQueryTreeNode;
using QueryTreeNodePtr = std::shared_ptr<IQueryTreeNode>;
using QueryTreeNodes = std::vector<QueryTreeNodePtr>;
using QueryTreeNodeWeakPtr = std::weak_ptr<IQueryTreeNode>;
using QueryTreeWeakNodes = std::vector<QueryTreeNodeWeakPtr>;
class IQueryTreeNode : public TypePromotion<IQueryTreeNode>
{
public:
virtual ~IQueryTreeNode() = default;
/// Get query tree node type
virtual QueryTreeNodeType getNodeType() const = 0;
/// Get query tree node type name
const char * getNodeTypeName() const
{
return toString(getNodeType());
}
/** Get name of query tree node that can be used as part of expression.
* TODO: Projection name, expression name must be refactored in better interface.
*/
virtual String getName() const
{
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Method getName is not supported for {} query node", getNodeTypeName());
}
/** Get result type of query tree node that can be used as part of expression.
* If node does not support this method exception is thrown.
* TODO: Maybe this can be a part of ExpressionQueryTreeNode.
*/
virtual DataTypePtr getResultType() const
{
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Method getResultType is not supported for {} query node", getNodeTypeName());
}
/// Returns true if node has constant value
bool hasConstantValue() const
{
return getConstantValueOrNull() != nullptr;
}
/** Returns constant value with type if node has constant value, and can be replaced with it.
* Examples: scalar subquery, function with constant arguments.
*/
virtual const ConstantValue & getConstantValue() const
{
auto constant_value = getConstantValueOrNull();
if (!constant_value)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Node does not have constant value");
return *constant_value;
}
/// Returns constant value with type if node has constant value or null otherwise
virtual ConstantValuePtr getConstantValueOrNull() const
{
return {};
}
/** Is tree equal to other tree with node root.
*
* Aliases of query tree nodes are compared during isEqual call.
* Original ASTs of query tree nodes are not compared during isEqual call.
*/
bool isEqual(const IQueryTreeNode & rhs) const;
using Hash = std::pair<UInt64, UInt64>;
using HashState = SipHash;
/** Get tree hash identifying current tree
*
* Alias of query tree node is part of query tree hash.
* Original AST is not part of query tree hash.
*/
Hash getTreeHash() const;
/// Get a deep copy of the query tree
QueryTreeNodePtr clone() const;
/// Returns true if node has alias, false otherwise
bool hasAlias() const
{
return !alias.empty();
}
/// Get node alias
const String & getAlias() const
{
return alias;
}
/// Set node alias
void setAlias(String alias_value)
{
alias = std::move(alias_value);
}
/// Remove node alias
void removeAlias()
{
alias = {};
}
/// Returns true if query tree node has original AST, false otherwise
bool hasOriginalAST() const
{
return original_ast != nullptr;
}
/// Get query tree node original AST
const ASTPtr & getOriginalAST() const
{
return original_ast;
}
/** Set query tree node original AST.
* This AST will not be modified later.
*/
void setOriginalAST(ASTPtr original_ast_value)
{
original_ast = std::move(original_ast_value);
}
/** If query tree has original AST format it for error message.
* Otherwise exception is thrown.
*/
String formatOriginalASTForErrorMessage() const;
/// Convert query tree to AST
ASTPtr toAST() const;
/// Convert query tree to AST and then format it for error message.
String formatConvertedASTForErrorMessage() const;
/** Format AST for error message.
* If original AST exists use `formatOriginalASTForErrorMessage`.
* Otherwise use `formatConvertedASTForErrorMessage`.
*/
String formatASTForErrorMessage() const
{
if (original_ast)
return formatOriginalASTForErrorMessage();
return formatConvertedASTForErrorMessage();
}
/// Dump query tree to string
String dumpTree() const;
/// Dump query tree to buffer
void dumpTree(WriteBuffer & buffer) const;
class FormatState
{
public:
size_t getNodeId(const IQueryTreeNode * node);
private:
std::unordered_map<const IQueryTreeNode *, size_t> node_to_id;
};
/** Dump query tree to buffer starting with indent.
*
* Node must also dump its children.
*/
virtual void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const = 0;
/// Get query tree node children
QueryTreeNodes & getChildren()
{
return children;
}
/// Get query tree node children
const QueryTreeNodes & getChildren() const
{
return children;
}
protected:
/** Construct query tree node.
* Resize children to children size.
* Resize weak pointers to weak pointers size.
*/
explicit IQueryTreeNode(size_t children_size, size_t weak_pointers_size);
/// Construct query tree node and resize children to children size
explicit IQueryTreeNode(size_t children_size);
/** Subclass must compare its internal state with rhs node internal state and do not compare children or weak pointers to other
* query tree nodes.
*/
virtual bool isEqualImpl(const IQueryTreeNode & rhs) const = 0;
/** Subclass must update tree hash with its internal state and do not update tree hash for children or weak pointers to other
* query tree nodes.
*/
virtual void updateTreeHashImpl(HashState & hash_state) const = 0;
/** Subclass must clone its internal state and do not clone children or weak pointers to other
* query tree nodes.
*/
virtual QueryTreeNodePtr cloneImpl() const = 0;
/// Subclass must convert its internal state and its children to AST
virtual ASTPtr toASTImpl() const = 0;
QueryTreeNodes children;
QueryTreeWeakNodes weak_pointers;
private:
String alias;
ASTPtr original_ast;
};
}

View File

@ -0,0 +1,38 @@
#pragma once
#include <Interpreters/Context_fwd.h>
#include <Analyzer/IQueryTreeNode.h>
namespace DB
{
/** After query tree is build it can be later processed by query tree passes.
* This is abstract base class for all query tree passes.
*
* Query tree pass can make query tree modifications, after each pass query tree must be valid.
* Query tree pass must be isolated and perform only necessary query tree modifications for doing its job.
* Dependencies between passes must be avoided.
*/
class IQueryTreePass;
using QueryTreePassPtr = std::shared_ptr<IQueryTreePass>;
using QueryTreePasses = std::vector<QueryTreePassPtr>;
class IQueryTreePass
{
public:
virtual ~IQueryTreePass() = default;
/// Get query tree pass name
virtual String getName() = 0;
/// Get query tree pass description
virtual String getDescription() = 0;
/// Run pass over query tree
virtual void run(QueryTreeNodePtr query_tree_node, ContextPtr context) = 0;
};
}

412
src/Analyzer/Identifier.h Normal file
View File

@ -0,0 +1,412 @@
#pragma once
#include <vector>
#include <string>
#include <fmt/core.h>
#include <fmt/format.h>
#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/join.hpp>
namespace DB
{
/** Identifier consists from identifier parts.
* Each identifier part is arbitrary long sequence of digits, underscores, lowercase and uppercase letters.
* Example: a, a.b, a.b.c.
*/
class Identifier
{
public:
Identifier() = default;
/// Create Identifier from parts
explicit Identifier(const std::vector<std::string> & parts_)
: parts(parts_)
, full_name(boost::algorithm::join(parts, "."))
{
}
/// Create Identifier from parts
explicit Identifier(std::vector<std::string> && parts_)
: parts(std::move(parts_))
, full_name(boost::algorithm::join(parts, "."))
{
}
/// Create Identifier from full name, full name is split with '.' as separator.
explicit Identifier(const std::string & full_name_)
: full_name(full_name_)
{
boost::split(parts, full_name, [](char c) { return c == '.'; });
}
/// Create Identifier from full name, full name is split with '.' as separator.
explicit Identifier(std::string && full_name_)
: full_name(std::move(full_name_))
{
boost::split(parts, full_name, [](char c) { return c == '.'; });
}
const std::string & getFullName() const
{
return full_name;
}
const std::vector<std::string> & getParts() const
{
return parts;
}
size_t getPartsSize() const
{
return parts.size();
}
bool empty() const
{
return parts.empty();
}
bool isEmpty() const
{
return parts.empty();
}
bool isShort() const
{
return parts.size() == 1;
}
bool isCompound() const
{
return parts.size() > 1;
}
const std::string & at(size_t index) const
{
if (index >= parts.size())
throw std::out_of_range("identifier access part is out of range");
return parts[index];
}
const std::string & operator[](size_t index) const
{
return parts[index];
}
const std::string & front() const
{
return parts.front();
}
const std::string & back() const
{
return parts.back();
}
/// Returns true, if identifier starts with part, false otherwise
bool startsWith(const std::string_view & part)
{
return !parts.empty() && parts[0] == part;
}
/// Returns true, if identifier ends with part, false otherwise
bool endsWith(const std::string_view & part)
{
return !parts.empty() && parts.back() == part;
}
using const_iterator = std::vector<std::string>::const_iterator;
const_iterator begin() const
{
return parts.begin();
}
const_iterator end() const
{
return parts.end();
}
void popFirst(size_t parts_to_remove_size)
{
assert(parts_to_remove_size <= parts.size());
size_t parts_size = parts.size();
std::vector<std::string> result_parts;
result_parts.reserve(parts_size - parts_to_remove_size);
for (size_t i = parts_to_remove_size; i < parts_size; ++i)
result_parts.push_back(std::move(parts[i]));
parts = std::move(result_parts);
full_name = boost::algorithm::join(parts, ".");
}
void popFirst()
{
return popFirst(1);
}
void popLast(size_t parts_to_remove_size)
{
assert(parts_to_remove_size <= parts.size());
for (size_t i = 0; i < parts_to_remove_size; ++i)
{
size_t last_part_size = parts.back().size();
parts.pop_back();
bool is_not_last = !parts.empty();
full_name.resize(full_name.size() - (last_part_size + static_cast<size_t>(is_not_last)));
}
}
void popLast()
{
return popLast(1);
}
void pop_back() /// NOLINT
{
popLast();
}
void push_back(std::string && part) /// NOLINT
{
parts.push_back(std::move(part));
full_name += '.';
full_name += parts.back();
}
void push_back(const std::string & part) /// NOLINT
{
parts.push_back(part);
full_name += '.';
full_name += parts.back();
}
template <typename ...Args>
void emplace_back(Args&&... args) /// NOLINT
{
parts.emplace_back(std::forward<Args>(args)...);
full_name += '.';
full_name += parts.back();
}
private:
std::vector<std::string> parts;
std::string full_name;
};
inline bool operator==(const Identifier & lhs, const Identifier & rhs)
{
return lhs.getFullName() == rhs.getFullName();
}
inline bool operator!=(const Identifier & lhs, const Identifier & rhs)
{
return !(lhs == rhs);
}
inline std::ostream & operator<<(std::ostream & stream, const Identifier & identifier)
{
stream << identifier.getFullName();
return stream;
}
using Identifiers = std::vector<Identifier>;
/// View for Identifier
class IdentifierView
{
public:
IdentifierView() = default;
IdentifierView(const Identifier & identifier) /// NOLINT
: full_name_view(identifier.getFullName())
, parts_start_it(identifier.begin())
, parts_end_it(identifier.end())
{}
std::string_view getFullName() const
{
return full_name_view;
}
size_t getPartsSize() const
{
return parts_end_it - parts_start_it;
}
bool empty() const
{
return parts_start_it == parts_end_it;
}
bool isEmpty() const
{
return parts_start_it == parts_end_it;
}
bool isShort() const
{
return getPartsSize() == 1;
}
bool isCompound() const
{
return getPartsSize() > 1;
}
std::string_view at(size_t index) const
{
if (index >= getPartsSize())
throw std::out_of_range("identifier access part is out of range");
return *(parts_start_it + index);
}
std::string_view operator[](size_t index) const
{
return *(parts_start_it + index);
}
std::string_view front() const
{
return *parts_start_it;
}
std::string_view back() const
{
return *(parts_end_it - 1);
}
bool startsWith(std::string_view part) const
{
return !isEmpty() && *parts_start_it == part;
}
bool endsWith(std::string_view part) const
{
return !isEmpty() && *(parts_end_it - 1) == part;
}
void popFirst(size_t parts_to_remove_size)
{
assert(parts_to_remove_size <= getPartsSize());
for (size_t i = 0; i < parts_to_remove_size; ++i)
{
size_t part_size = parts_start_it->size();
++parts_start_it;
bool is_not_last = parts_start_it != parts_end_it;
full_name_view.remove_prefix(part_size + is_not_last);
}
}
void popFirst()
{
popFirst(1);
}
void popLast(size_t parts_to_remove_size)
{
assert(parts_to_remove_size <= getPartsSize());
for (size_t i = 0; i < parts_to_remove_size; ++i)
{
size_t last_part_size = (parts_end_it - 1)->size();
--parts_end_it;
bool is_not_last = parts_start_it != parts_end_it;
full_name_view.remove_suffix(last_part_size + is_not_last);
}
}
void popLast()
{
popLast(1);
}
using const_iterator = Identifier::const_iterator;
const_iterator begin() const
{
return parts_start_it;
}
const_iterator end() const
{
return parts_end_it;
}
private:
std::string_view full_name_view;
const_iterator parts_start_it;
const_iterator parts_end_it;
};
inline bool operator==(const IdentifierView & lhs, const IdentifierView & rhs)
{
return lhs.getFullName() == rhs.getFullName();
}
inline bool operator!=(const IdentifierView & lhs, const IdentifierView & rhs)
{
return !(lhs == rhs);
}
inline std::ostream & operator<<(std::ostream & stream, const IdentifierView & identifier_view)
{
stream << identifier_view.getFullName();
return stream;
}
}
/// See https://fmt.dev/latest/api.html#formatting-user-defined-types
template <>
struct fmt::formatter<DB::Identifier>
{
constexpr static auto parse(format_parse_context & ctx)
{
const auto * it = ctx.begin();
const auto * end = ctx.end();
/// Only support {}.
if (it != end && *it != '}')
throw format_error("invalid format");
return it;
}
template <typename FormatContext>
auto format(const DB::Identifier & identifier, FormatContext & ctx)
{
return format_to(ctx.out(), "{}", identifier.getFullName());
}
};
template <>
struct fmt::formatter<DB::IdentifierView>
{
constexpr static auto parse(format_parse_context & ctx)
{
const auto * it = ctx.begin();
const auto * end = ctx.end();
/// Only support {}.
if (it != end && *it != '}')
throw format_error("invalid format");
return it;
}
template <typename FormatContext>
auto format(const DB::IdentifierView & identifier_view, FormatContext & ctx)
{
return format_to(ctx.out(), "{}", identifier_view.getFullName());
}
};

View File

@ -0,0 +1,75 @@
#include <Analyzer/IdentifierNode.h>
#include <Common/SipHash.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Parsers/ASTIdentifier.h>
namespace DB
{
IdentifierNode::IdentifierNode(Identifier identifier_)
: IQueryTreeNode(children_size)
, identifier(std::move(identifier_))
{}
IdentifierNode::IdentifierNode(Identifier identifier_, TableExpressionModifiers table_expression_modifiers_)
: IQueryTreeNode(children_size)
, identifier(std::move(identifier_))
, table_expression_modifiers(std::move(table_expression_modifiers_))
{}
void IdentifierNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "IDENTIFIER id: " << format_state.getNodeId(this);
if (hasAlias())
buffer << ", alias: " << getAlias();
buffer << ", identifier: " << identifier.getFullName();
if (table_expression_modifiers)
{
buffer << ", ";
table_expression_modifiers->dump(buffer);
}
}
bool IdentifierNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const IdentifierNode &>(rhs);
if (table_expression_modifiers && rhs_typed.table_expression_modifiers && table_expression_modifiers != rhs_typed.table_expression_modifiers)
return false;
else if (table_expression_modifiers && !rhs_typed.table_expression_modifiers)
return false;
else if (!table_expression_modifiers && rhs_typed.table_expression_modifiers)
return false;
return identifier == rhs_typed.identifier;
}
void IdentifierNode::updateTreeHashImpl(HashState & state) const
{
const auto & identifier_name = identifier.getFullName();
state.update(identifier_name.size());
state.update(identifier_name);
if (table_expression_modifiers)
table_expression_modifiers->updateTreeHash(state);
}
QueryTreeNodePtr IdentifierNode::cloneImpl() const
{
return std::make_shared<IdentifierNode>(identifier);
}
ASTPtr IdentifierNode::toASTImpl() const
{
auto identifier_parts = identifier.getParts();
return std::make_shared<ASTIdentifier>(std::move(identifier_parts));
}
}

View File

@ -0,0 +1,76 @@
#pragma once
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/Identifier.h>
#include <Analyzer/TableExpressionModifiers.h>
namespace DB
{
/** Identifier node represents identifier in query tree.
* Example: SELECT a FROM test_table.
* a - is identifier.
* test_table - is identifier.
*
* Identifier resolution must be done during query analysis pass.
*/
class IdentifierNode final : public IQueryTreeNode
{
public:
/// Construct identifier node with identifier
explicit IdentifierNode(Identifier identifier_);
/** Construct identifier node with identifier and table expression modifiers
* when identifier node is part of JOIN TREE.
*
* Example: SELECT * FROM test_table SAMPLE 0.1 OFFSET 0.1 FINAL
*/
explicit IdentifierNode(Identifier identifier_, TableExpressionModifiers table_expression_modifiers_);
/// Get identifier
const Identifier & getIdentifier() const
{
return identifier;
}
/// Return true if identifier node has table expression modifiers, false otherwise
bool hasTableExpressionModifiers() const
{
return table_expression_modifiers.has_value();
}
/// Get table expression modifiers
const std::optional<TableExpressionModifiers> & getTableExpressionModifiers() const
{
return table_expression_modifiers;
}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::IDENTIFIER;
}
String getName() const override
{
return identifier.getFullName();
}
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
Identifier identifier;
std::optional<TableExpressionModifiers> table_expression_modifiers;
static constexpr size_t children_size = 0;
};
}

View File

@ -0,0 +1,87 @@
#pragma once
#include <Common/Exception.h>
#include <Analyzer/IQueryTreeNode.h>
namespace DB
{
/** Visitor that traverse query tree in depth.
* Derived class must implement `visitImpl` method.
* Additionally subclass can control if child need to be visited using `needChildVisit` method, by
* default all node children are visited.
* By default visitor traverse tree from top to bottom, if bottom to top traverse is required subclass
* can override `shouldTraverseTopToBottom` method.
*
* Usage example:
* class FunctionsVisitor : public InDepthQueryTreeVisitor<FunctionsVisitor>
* {
* void visitImpl(VisitQueryTreeNodeType & query_tree_node)
* {
* if (query_tree_node->getNodeType() == QueryTreeNodeType::FUNCTION)
* processFunctionNode(query_tree_node);
* }
* }
*/
template <typename Derived, bool const_visitor = false>
class InDepthQueryTreeVisitor
{
public:
using VisitQueryTreeNodeType = std::conditional_t<const_visitor, const QueryTreeNodePtr, QueryTreeNodePtr>;
/// Return true if visitor should traverse tree top to bottom, false otherwise
bool shouldTraverseTopToBottom() const
{
return true;
}
/// Return true if visitor should visit child, false otherwise
bool needChildVisit(VisitQueryTreeNodeType & parent [[maybe_unused]], VisitQueryTreeNodeType & child [[maybe_unused]])
{
return true;
}
void visit(VisitQueryTreeNodeType & query_tree_node)
{
bool traverse_top_to_bottom = getDerived().shouldTraverseTopToBottom();
if (!traverse_top_to_bottom)
visitChildren(query_tree_node);
getDerived().visitImpl(query_tree_node);
if (traverse_top_to_bottom)
visitChildren(query_tree_node);
}
private:
Derived & getDerived()
{
return *static_cast<Derived *>(this);
}
const Derived & getDerived() const
{
return *static_cast<Derived *>(this);
}
void visitChildren(VisitQueryTreeNodeType & expression)
{
for (auto & child : expression->getChildren())
{
if (!child)
continue;
bool need_visit_child = getDerived().needChildVisit(expression, child);
if (need_visit_child)
visit(child);
}
}
};
template <typename Derived>
using ConstInDepthQueryTreeVisitor = InDepthQueryTreeVisitor<Derived, true /*const_visitor*/>;
}

View File

@ -0,0 +1,66 @@
#include <Analyzer/InterpolateNode.h>
#include <Common/SipHash.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Parsers/ASTInterpolateElement.h>
namespace DB
{
InterpolateNode::InterpolateNode(QueryTreeNodePtr expression_, QueryTreeNodePtr interpolate_expression_)
: IQueryTreeNode(children_size)
{
children[expression_child_index] = std::move(expression_);
children[interpolate_expression_child_index] = std::move(interpolate_expression_);
}
String InterpolateNode::getName() const
{
String result = getExpression()->getName();
result += " AS ";
result += getInterpolateExpression()->getName();
return result;
}
void InterpolateNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "INTERPOLATE id: " << format_state.getNodeId(this);
buffer << '\n' << std::string(indent + 2, ' ') << "EXPRESSION\n";
getExpression()->dumpTreeImpl(buffer, format_state, indent + 4);
buffer << '\n' << std::string(indent + 2, ' ') << "INTERPOLATE_EXPRESSION\n";
getInterpolateExpression()->dumpTreeImpl(buffer, format_state, indent + 4);
}
bool InterpolateNode::isEqualImpl(const IQueryTreeNode &) const
{
/// No state in interpolate node
return true;
}
void InterpolateNode::updateTreeHashImpl(HashState &) const
{
/// No state in interpolate node
}
QueryTreeNodePtr InterpolateNode::cloneImpl() const
{
return std::make_shared<InterpolateNode>(nullptr /*expression*/, nullptr /*interpolate_expression*/);
}
ASTPtr InterpolateNode::toASTImpl() const
{
auto result = std::make_shared<ASTInterpolateElement>();
result->column = getExpression()->toAST()->getColumnName();
result->children.push_back(getInterpolateExpression()->toAST());
result->expr = result->children.back();
return result;
}
}

View File

@ -0,0 +1,72 @@
#pragma once
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h>
namespace DB
{
/** Interpolate node represents expression interpolation in INTERPOLATE section that is part of ORDER BY section in query tree.
*
* Example: SELECT * FROM test_table ORDER BY id WITH FILL INTERPOLATE (value AS value + 1);
* value - expression to interpolate.
* value + 1 - interpolate expression.
*/
class InterpolateNode;
using InterpolateNodePtr = std::shared_ptr<InterpolateNode>;
class InterpolateNode final : public IQueryTreeNode
{
public:
/// Initialize interpolate node with expression and interpolate expression
explicit InterpolateNode(QueryTreeNodePtr expression_, QueryTreeNodePtr interpolate_expression_);
/// Get expression to interpolate
const QueryTreeNodePtr & getExpression() const
{
return children[expression_child_index];
}
/// Get expression to interpolate
QueryTreeNodePtr & getExpression()
{
return children[expression_child_index];
}
/// Get interpolate expression
const QueryTreeNodePtr & getInterpolateExpression() const
{
return children[interpolate_expression_child_index];
}
/// Get interpolate expression
QueryTreeNodePtr & getInterpolateExpression()
{
return children[interpolate_expression_child_index];
}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::INTERPOLATE;
}
String getName() const override;
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & hash_state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
static constexpr size_t expression_child_index = 0;
static constexpr size_t interpolate_expression_child_index = 1;
static constexpr size_t children_size = interpolate_expression_child_index + 1;
};
}

116
src/Analyzer/JoinNode.cpp Normal file
View File

@ -0,0 +1,116 @@
#include <Analyzer/JoinNode.h>
#include <Analyzer/ListNode.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Parsers/ASTSubquery.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Analyzer/Utils.h>
namespace DB
{
JoinNode::JoinNode(QueryTreeNodePtr left_table_expression_,
QueryTreeNodePtr right_table_expression_,
QueryTreeNodePtr join_expression_,
JoinLocality locality_,
JoinStrictness strictness_,
JoinKind kind_)
: IQueryTreeNode(children_size)
, locality(locality_)
, strictness(strictness_)
, kind(kind_)
{
children[left_table_expression_child_index] = std::move(left_table_expression_);
children[right_table_expression_child_index] = std::move(right_table_expression_);
children[join_expression_child_index] = std::move(join_expression_);
}
ASTPtr JoinNode::toASTTableJoin() const
{
auto join_ast = std::make_shared<ASTTableJoin>();
join_ast->locality = locality;
join_ast->strictness = strictness;
join_ast->kind = kind;
if (children[join_expression_child_index])
{
auto join_expression_ast = children[join_expression_child_index]->toAST();
if (children[join_expression_child_index]->getNodeType() == QueryTreeNodeType::LIST)
join_ast->using_expression_list = std::move(join_expression_ast);
else
join_ast->on_expression = std::move(join_expression_ast);
}
return join_ast;
}
void JoinNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "JOIN id: " << format_state.getNodeId(this);
if (locality != JoinLocality::Unspecified)
buffer << ", locality: " << toString(locality);
if (strictness != JoinStrictness::Unspecified)
buffer << ", strictness: " << toString(strictness);
buffer << ", kind: " << toString(kind);
buffer << '\n' << std::string(indent + 2, ' ') << "LEFT TABLE EXPRESSION\n";
getLeftTableExpression()->dumpTreeImpl(buffer, format_state, indent + 4);
buffer << '\n' << std::string(indent + 2, ' ') << "RIGHT TABLE EXPRESSION\n";
getRightTableExpression()->dumpTreeImpl(buffer, format_state, indent + 4);
if (getJoinExpression())
{
buffer << '\n' << std::string(indent + 2, ' ') << "JOIN EXPRESSION\n";
getJoinExpression()->dumpTreeImpl(buffer, format_state, indent + 4);
}
}
bool JoinNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const JoinNode &>(rhs);
return locality == rhs_typed.locality && strictness == rhs_typed.strictness && kind == rhs_typed.kind;
}
void JoinNode::updateTreeHashImpl(HashState & state) const
{
state.update(locality);
state.update(strictness);
state.update(kind);
}
QueryTreeNodePtr JoinNode::cloneImpl() const
{
return std::make_shared<JoinNode>(getLeftTableExpression(), getRightTableExpression(), getJoinExpression(), locality, strictness, kind);
}
ASTPtr JoinNode::toASTImpl() const
{
ASTPtr tables_in_select_query_ast = std::make_shared<ASTTablesInSelectQuery>();
addTableExpressionOrJoinIntoTablesInSelectQuery(tables_in_select_query_ast, children[left_table_expression_child_index]);
size_t join_table_index = tables_in_select_query_ast->children.size();
auto join_ast = toASTTableJoin();
addTableExpressionOrJoinIntoTablesInSelectQuery(tables_in_select_query_ast, children[right_table_expression_child_index]);
auto & table_element = tables_in_select_query_ast->children.at(join_table_index)->as<ASTTablesInSelectQueryElement &>();
table_element.children.push_back(std::move(join_ast));
table_element.table_join = table_element.children.back();
return tables_in_select_query_ast;
}
}

152
src/Analyzer/JoinNode.h Normal file
View File

@ -0,0 +1,152 @@
#pragma once
#include <Core/Joins.h>
#include <Storages/IStorage_fwd.h>
#include <Storages/TableLockHolder.h>
#include <Storages/StorageSnapshot.h>
#include <Interpreters/Context_fwd.h>
#include <Interpreters/StorageID.h>
#include <Analyzer/IQueryTreeNode.h>
namespace DB
{
/** Join node represents join in query tree.
*
* For JOIN without join expression, JOIN expression is null.
* Example: SELECT id FROM test_table_1 AS t1, test_table_2 AS t2;
*
* For JOIN with USING, JOIN expression contains list of identifier nodes. These nodes must be resolved
* during query analysis pass.
* Example: SELECT id FROM test_table_1 AS t1 INNER JOIN test_table_2 AS t2 USING (id);
*
* For JOIN with ON, JOIN expression contains single expression.
* Example: SELECT id FROM test_table_1 AS t1 INNER JOIN test_table_2 AS t2 ON t1.id = t2.id;
*/
class JoinNode;
using JoinNodePtr = std::shared_ptr<JoinNode>;
class JoinNode final : public IQueryTreeNode
{
public:
/** Construct join node with left table expression, right table expression and join expression.
* Example: SELECT id FROM test_table_1 INNER JOIN test_table_2 ON expression.
*
* test_table_1 - left table expression.
* test_table_2 - right table expression.
* expression - join expression.
*/
JoinNode(QueryTreeNodePtr left_table_expression_,
QueryTreeNodePtr right_table_expression_,
QueryTreeNodePtr join_expression_,
JoinLocality locality_,
JoinStrictness strictness_,
JoinKind kind_);
/// Get left table expression
const QueryTreeNodePtr & getLeftTableExpression() const
{
return children[left_table_expression_child_index];
}
/// Get left table expression
QueryTreeNodePtr & getLeftTableExpression()
{
return children[left_table_expression_child_index];
}
/// Get right table expression
const QueryTreeNodePtr & getRightTableExpression() const
{
return children[right_table_expression_child_index];
}
/// Get right table expression
QueryTreeNodePtr & getRightTableExpression()
{
return children[right_table_expression_child_index];
}
/// Returns true if join has join expression, false otherwise
bool hasJoinExpression() const
{
return children[join_expression_child_index] != nullptr;
}
/// Get join expression
const QueryTreeNodePtr & getJoinExpression() const
{
return children[join_expression_child_index];
}
/// Get join expression
QueryTreeNodePtr & getJoinExpression()
{
return children[join_expression_child_index];
}
/// Returns true if join has USING join expression, false otherwise
bool isUsingJoinExpression() const
{
return hasJoinExpression() && getJoinExpression()->getNodeType() == QueryTreeNodeType::LIST;
}
/// Returns true if join has ON join expression, false otherwise
bool isOnJoinExpression() const
{
return hasJoinExpression() && getJoinExpression()->getNodeType() != QueryTreeNodeType::LIST;
}
/// Get join locality
JoinLocality getLocality() const
{
return locality;
}
/// Get join strictness
JoinStrictness getStrictness() const
{
return strictness;
}
/// Get join kind
JoinKind getKind() const
{
return kind;
}
/// Convert join node to ASTTableJoin
ASTPtr toASTTableJoin() const;
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::JOIN;
}
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
JoinLocality locality = JoinLocality::Unspecified;
JoinStrictness strictness = JoinStrictness::Unspecified;
JoinKind kind = JoinKind::Inner;
static constexpr size_t left_table_expression_child_index = 0;
static constexpr size_t right_table_expression_child_index = 1;
static constexpr size_t join_expression_child_index = 2;
static constexpr size_t children_size = join_expression_child_index + 1;
};
}

View File

@ -0,0 +1,93 @@
#include <Analyzer/LambdaNode.h>
#include <Common/SipHash.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
namespace DB
{
LambdaNode::LambdaNode(Names argument_names_, QueryTreeNodePtr expression_)
: IQueryTreeNode(children_size)
, argument_names(std::move(argument_names_))
{
auto arguments_list_node = std::make_shared<ListNode>();
auto & nodes = arguments_list_node->getNodes();
size_t argument_names_size = argument_names.size();
nodes.reserve(argument_names_size);
for (size_t i = 0; i < argument_names_size; ++i)
nodes.push_back(std::make_shared<IdentifierNode>(Identifier{argument_names[i]}));
children[arguments_child_index] = std::move(arguments_list_node);
children[expression_child_index] = std::move(expression_);
}
void LambdaNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "LAMBDA id: " << format_state.getNodeId(this);
if (hasAlias())
buffer << ", alias: " << getAlias();
const auto & arguments = getArguments();
if (!arguments.getNodes().empty())
{
buffer << '\n' << std::string(indent + 2, ' ') << "ARGUMENTS " << '\n';
getArguments().dumpTreeImpl(buffer, format_state, indent + 4);
}
buffer << '\n' << std::string(indent + 2, ' ') << "EXPRESSION " << '\n';
getExpression()->dumpTreeImpl(buffer, format_state, indent + 4);
}
String LambdaNode::getName() const
{
return "lambda(" + children[arguments_child_index]->getName() + ") -> " + children[expression_child_index]->getName();
}
bool LambdaNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const LambdaNode &>(rhs);
return argument_names == rhs_typed.argument_names;
}
void LambdaNode::updateTreeHashImpl(HashState & state) const
{
state.update(argument_names.size());
for (const auto & argument_name : argument_names)
{
state.update(argument_name.size());
state.update(argument_name);
}
}
QueryTreeNodePtr LambdaNode::cloneImpl() const
{
return std::make_shared<LambdaNode>(argument_names, getExpression());
}
ASTPtr LambdaNode::toASTImpl() const
{
auto lambda_function_arguments_ast = std::make_shared<ASTExpressionList>();
auto tuple_function = std::make_shared<ASTFunction>();
tuple_function->name = "tuple";
tuple_function->children.push_back(children[arguments_child_index]->toAST());
tuple_function->arguments = tuple_function->children.back();
lambda_function_arguments_ast->children.push_back(std::move(tuple_function));
lambda_function_arguments_ast->children.push_back(children[expression_child_index]->toAST());
auto lambda_function_ast = std::make_shared<ASTFunction>();
lambda_function_ast->name = "lambda";
lambda_function_ast->children.push_back(std::move(lambda_function_arguments_ast));
lambda_function_ast->arguments = lambda_function_ast->children.back();
return lambda_function_ast;
}
}

118
src/Analyzer/LambdaNode.h Normal file
View File

@ -0,0 +1,118 @@
#pragma once
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h>
#include <Analyzer/IdentifierNode.h>
#include <Parsers/ASTFunction.h>
namespace DB
{
/** Lambda node represents lambda expression in query tree.
*
* Lambda consist of argument names and lambda expression body.
* Lambda expression body does not necessary use lambda arguments. Example: SELECT arrayMap(x -> 1, [1, 2, 3])
*
* Initially lambda is initialized with argument names and lambda body expression.
*
* Lambda expression result type can depend on arguments types.
* Example: WITH (x -> x) as lambda SELECT lambda(1), lambda('string_value').
*
* During query analysis pass lambdas must be resolved.
* Lambda resolve must set concrete lambda arguments and resolve lambda expression body.
* In query tree lambda arguments are represented by ListNode.
* If client modified lambda arguments array its size must be equal to initial lambda argument names array.
*
* Examples:
* WITH (x -> x + 1) as lambda SELECT lambda(1);
* SELECT arrayMap(x -> x + 1, [1,2,3]);
*/
class LambdaNode;
using LambdaNodePtr = std::shared_ptr<LambdaNode>;
class LambdaNode final : public IQueryTreeNode
{
public:
/// Initialize lambda with argument names and lambda body expression
explicit LambdaNode(Names argument_names_, QueryTreeNodePtr expression_);
/// Get argument names
const Names & getArgumentNames() const
{
return argument_names;
}
/// Get arguments
const ListNode & getArguments() const
{
return children[arguments_child_index]->as<const ListNode &>();
}
/// Get arguments
ListNode & getArguments()
{
return children[arguments_child_index]->as<ListNode &>();
}
/// Get arguments node
const QueryTreeNodePtr & getArgumentsNode() const
{
return children[arguments_child_index];
}
/// Get arguments node
QueryTreeNodePtr & getArgumentsNode()
{
return children[arguments_child_index];
}
/// Get expression
const QueryTreeNodePtr & getExpression() const
{
return children[expression_child_index];
}
/// Get expression
QueryTreeNodePtr & getExpression()
{
return children[expression_child_index];
}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::LAMBDA;
}
String getName() const override;
DataTypePtr getResultType() const override
{
return getExpression()->getResultType();
}
ConstantValuePtr getConstantValueOrNull() const override
{
return getExpression()->getConstantValueOrNull();
}
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
Names argument_names;
static constexpr size_t arguments_child_index = 0;
static constexpr size_t expression_child_index = 1;
static constexpr size_t children_size = expression_child_index + 1;
};
}

88
src/Analyzer/ListNode.cpp Normal file
View File

@ -0,0 +1,88 @@
#include <Analyzer/ListNode.h>
#include <Common/SipHash.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Parsers/ASTExpressionList.h>
namespace DB
{
ListNode::ListNode()
: IQueryTreeNode(0 /*children_size*/)
{}
ListNode::ListNode(QueryTreeNodes nodes)
: IQueryTreeNode(0 /*children_size*/)
{
children = std::move(nodes);
}
void ListNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "LIST id: " << format_state.getNodeId(this);
size_t children_size = children.size();
buffer << ", nodes: " << children_size << '\n';
for (size_t i = 0; i < children_size; ++i)
{
const auto & node = children[i];
node->dumpTreeImpl(buffer, format_state, indent + 2);
if (i + 1 != children_size)
buffer << '\n';
}
}
String ListNode::getName() const
{
if (children.empty())
return "";
std::string result;
for (const auto & node : children)
{
result += node->getName();
result += ", ";
}
result.pop_back();
result.pop_back();
return result;
}
bool ListNode::isEqualImpl(const IQueryTreeNode &) const
{
/// No state
return true;
}
void ListNode::updateTreeHashImpl(HashState &) const
{
/// No state
}
QueryTreeNodePtr ListNode::cloneImpl() const
{
return std::make_shared<ListNode>();
}
ASTPtr ListNode::toASTImpl() const
{
auto expression_list_ast = std::make_shared<ASTExpressionList>();
size_t children_size = children.size();
expression_list_ast->children.resize(children_size);
for (size_t i = 0; i < children_size; ++i)
expression_list_ast->children[i] = children[i]->toAST();
return expression_list_ast;
}
}

56
src/Analyzer/ListNode.h Normal file
View File

@ -0,0 +1,56 @@
#pragma once
#include <Analyzer/IQueryTreeNode.h>
namespace DB
{
/** List node represents list of query tree nodes in query tree.
*
* Example: SELECT column_1, 1, 'constant_value' FROM table.
* column_1, 1, 'constant_value' is list query tree node.
*/
class ListNode;
using ListNodePtr = std::shared_ptr<ListNode>;
class ListNode final : public IQueryTreeNode
{
public:
/// Initialize list node with empty nodes
ListNode();
/// Initialize list node with nodes
explicit ListNode(QueryTreeNodes nodes);
/// Get list nodes
const QueryTreeNodes & getNodes() const
{
return children;
}
/// Get list nodes
QueryTreeNodes & getNodes()
{
return children;
}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::LIST;
}
String getName() const override;
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState &) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
};
}

View File

@ -0,0 +1,329 @@
#include <Analyzer/MatcherNode.h>
#include <Common/SipHash.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTAsterisk.h>
#include <Parsers/ASTQualifiedAsterisk.h>
#include <Parsers/ASTColumnsMatcher.h>
#include <Parsers/ASTExpressionList.h>
namespace DB
{
const char * toString(MatcherNodeType matcher_node_type)
{
switch (matcher_node_type)
{
case MatcherNodeType::ASTERISK:
return "ASTERISK";
case MatcherNodeType::COLUMNS_LIST:
return "COLUMNS_LIST";
case MatcherNodeType::COLUMNS_REGEXP:
return "COLUMNS_REGEXP";
}
}
MatcherNode::MatcherNode(ColumnTransformersNodes column_transformers_)
: MatcherNode(MatcherNodeType::ASTERISK,
{} /*qualified_identifier*/,
{} /*columns_identifiers*/,
{} /*columns_matcher*/,
std::move(column_transformers_) /*column_transformers*/)
{
}
MatcherNode::MatcherNode(Identifier qualified_identifier_, ColumnTransformersNodes column_transformers_)
: MatcherNode(MatcherNodeType::ASTERISK,
std::move(qualified_identifier_),
{} /*columns_identifiers*/,
{} /*columns_matcher*/,
std::move(column_transformers_))
{
}
MatcherNode::MatcherNode(std::shared_ptr<re2::RE2> columns_matcher_, ColumnTransformersNodes column_transformers_)
: MatcherNode(MatcherNodeType::COLUMNS_REGEXP,
{} /*qualified_identifier*/,
{} /*columns_identifiers*/,
std::move(columns_matcher_),
std::move(column_transformers_))
{
}
MatcherNode::MatcherNode(Identifier qualified_identifier_, std::shared_ptr<re2::RE2> columns_matcher_, ColumnTransformersNodes column_transformers_)
: MatcherNode(MatcherNodeType::COLUMNS_REGEXP,
std::move(qualified_identifier_),
{} /*columns_identifiers*/,
std::move(columns_matcher_),
std::move(column_transformers_))
{
}
MatcherNode::MatcherNode(Identifiers columns_identifiers_, ColumnTransformersNodes column_transformers_)
: MatcherNode(MatcherNodeType::COLUMNS_LIST,
{} /*qualified_identifier*/,
std::move(columns_identifiers_),
{} /*columns_matcher*/,
std::move(column_transformers_))
{
}
MatcherNode::MatcherNode(Identifier qualified_identifier_, Identifiers columns_identifiers_, ColumnTransformersNodes column_transformers_)
: MatcherNode(MatcherNodeType::COLUMNS_LIST,
std::move(qualified_identifier_),
std::move(columns_identifiers_),
{} /*columns_matcher*/,
std::move(column_transformers_))
{
}
MatcherNode::MatcherNode(MatcherNodeType matcher_type_,
Identifier qualified_identifier_,
Identifiers columns_identifiers_,
std::shared_ptr<re2::RE2> columns_matcher_,
ColumnTransformersNodes column_transformers_)
: IQueryTreeNode(children_size)
, matcher_type(matcher_type_)
, qualified_identifier(qualified_identifier_)
, columns_identifiers(columns_identifiers_)
, columns_matcher(columns_matcher_)
{
auto column_transformers_list_node = std::make_shared<ListNode>();
auto & column_transformers_nodes = column_transformers_list_node->getNodes();
column_transformers_nodes.reserve(column_transformers_.size());
for (auto && column_transformer : column_transformers_)
column_transformers_nodes.emplace_back(std::move(column_transformer));
children[column_transformers_child_index] = std::move(column_transformers_list_node);
columns_identifiers_set.reserve(columns_identifiers.size());
for (auto & column_identifier : columns_identifiers)
columns_identifiers_set.insert(column_identifier.getFullName());
}
bool MatcherNode::isMatchingColumn(const std::string & column_name)
{
if (matcher_type == MatcherNodeType::ASTERISK)
return true;
if (columns_matcher)
return RE2::PartialMatch(column_name, *columns_matcher);
return columns_identifiers_set.contains(column_name);
}
void MatcherNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "MATCHER id: " << format_state.getNodeId(this);
buffer << ", matcher_type: " << toString(matcher_type);
if (!qualified_identifier.empty())
buffer << ", qualified_identifier: " << qualified_identifier.getFullName();
if (columns_matcher)
{
buffer << ", columns_pattern: " << columns_matcher->pattern();
}
else if (matcher_type == MatcherNodeType::COLUMNS_LIST)
{
buffer << ", " << fmt::format("column_identifiers: {}", fmt::join(columns_identifiers, ", "));
}
const auto & column_transformers_list = getColumnTransformers();
if (!column_transformers_list.getNodes().empty())
{
buffer << '\n';
column_transformers_list.dumpTreeImpl(buffer, format_state, indent + 2);
}
}
String MatcherNode::getName() const
{
WriteBufferFromOwnString buffer;
if (!qualified_identifier.empty())
buffer << qualified_identifier.getFullName() << '.';
if (matcher_type == MatcherNodeType::ASTERISK)
{
buffer << '*';
}
else
{
buffer << "COLUMNS(";
if (columns_matcher)
{
buffer << ' ' << columns_matcher->pattern();
}
else if (matcher_type == MatcherNodeType::COLUMNS_LIST)
{
size_t columns_identifiers_size = columns_identifiers.size();
for (size_t i = 0; i < columns_identifiers_size; ++i)
{
buffer << columns_identifiers[i].getFullName();
if (i + 1 != columns_identifiers_size)
buffer << ", ";
}
}
}
buffer << ')';
const auto & column_transformers = getColumnTransformers().getNodes();
size_t column_transformers_size = column_transformers.size();
for (size_t i = 0; i < column_transformers_size; ++i)
{
const auto & column_transformer = column_transformers[i];
buffer << column_transformer->getName();
if (i + 1 != column_transformers_size)
buffer << ' ';
}
return buffer.str();
}
bool MatcherNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const MatcherNode &>(rhs);
if (matcher_type != rhs_typed.matcher_type ||
qualified_identifier != rhs_typed.qualified_identifier ||
columns_identifiers != rhs_typed.columns_identifiers ||
columns_identifiers_set != rhs_typed.columns_identifiers_set)
return false;
const auto & rhs_columns_matcher = rhs_typed.columns_matcher;
if (!columns_matcher && !rhs_columns_matcher)
return true;
else if (columns_matcher && !rhs_columns_matcher)
return false;
else if (!columns_matcher && rhs_columns_matcher)
return false;
return columns_matcher->pattern() == rhs_columns_matcher->pattern();
}
void MatcherNode::updateTreeHashImpl(HashState & hash_state) const
{
hash_state.update(static_cast<size_t>(matcher_type));
const auto & qualified_identifier_full_name = qualified_identifier.getFullName();
hash_state.update(qualified_identifier_full_name.size());
hash_state.update(qualified_identifier_full_name);
for (const auto & identifier : columns_identifiers)
{
const auto & identifier_full_name = identifier.getFullName();
hash_state.update(identifier_full_name.size());
hash_state.update(identifier_full_name);
}
if (columns_matcher)
{
const auto & columns_matcher_pattern = columns_matcher->pattern();
hash_state.update(columns_matcher_pattern.size());
hash_state.update(columns_matcher_pattern);
}
}
QueryTreeNodePtr MatcherNode::cloneImpl() const
{
MatcherNodePtr matcher_node = std::make_shared<MatcherNode>();
matcher_node->matcher_type = matcher_type;
matcher_node->qualified_identifier = qualified_identifier;
matcher_node->columns_identifiers = columns_identifiers;
matcher_node->columns_matcher = columns_matcher;
matcher_node->columns_identifiers_set = columns_identifiers_set;
return matcher_node;
}
ASTPtr MatcherNode::toASTImpl() const
{
ASTPtr result;
if (matcher_type == MatcherNodeType::ASTERISK)
{
if (qualified_identifier.empty())
{
result = std::make_shared<ASTAsterisk>();
}
else
{
auto qualified_asterisk = std::make_shared<ASTQualifiedAsterisk>();
auto identifier_parts = qualified_identifier.getParts();
qualified_asterisk->children.push_back(std::make_shared<ASTIdentifier>(std::move(identifier_parts)));
result = qualified_asterisk;
}
}
else if (columns_matcher)
{
if (qualified_identifier.empty())
{
auto regexp_matcher = std::make_shared<ASTColumnsRegexpMatcher>();
regexp_matcher->setPattern(columns_matcher->pattern());
result = regexp_matcher;
}
else
{
auto regexp_matcher = std::make_shared<ASTQualifiedColumnsRegexpMatcher>();
regexp_matcher->setPattern(columns_matcher->pattern());
auto identifier_parts = qualified_identifier.getParts();
regexp_matcher->children.push_back(std::make_shared<ASTIdentifier>(std::move(identifier_parts)));
result = regexp_matcher;
}
}
else
{
auto column_list = std::make_shared<ASTExpressionList>();
column_list->children.reserve(columns_identifiers.size());
for (const auto & identifier : columns_identifiers)
{
auto identifier_parts = identifier.getParts();
column_list->children.push_back(std::make_shared<ASTIdentifier>(std::move(identifier_parts)));
}
if (qualified_identifier.empty())
{
auto columns_list_matcher = std::make_shared<ASTColumnsListMatcher>();
columns_list_matcher->column_list = std::move(column_list);
result = columns_list_matcher;
}
else
{
auto columns_list_matcher = std::make_shared<ASTQualifiedColumnsListMatcher>();
columns_list_matcher->column_list = std::move(column_list);
auto identifier_parts = qualified_identifier.getParts();
columns_list_matcher->children.push_back(std::make_shared<ASTIdentifier>(std::move(identifier_parts)));
result = columns_list_matcher;
}
}
for (const auto & child : children)
result->children.push_back(child->toAST());
return result;
}
}

172
src/Analyzer/MatcherNode.h Normal file
View File

@ -0,0 +1,172 @@
#pragma once
#include <re2/re2.h>
#include <Analyzer/Identifier.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ColumnTransformers.h>
#include <Parsers/ASTAsterisk.h>
namespace DB
{
/** Matcher query tree node.
* Matcher can be unqualified with identifier and qualified with identifier.
* It can be asterisk or COLUMNS('regexp') or COLUMNS(column_name_1, ...).
* In result we have 6 possible options:
* Unqualified
* 1. *
* 2. COLUMNS('regexp')
* 3. COLUMNS(column_name_1, ...)
*
* Qualified:
* 1. identifier.*
* 2. identifier.COLUMNS('regexp')
* 3. identifier.COLUMNS(column_name_1, ...)
*
* Matcher must be resolved during query analysis pass.
*
* Matchers can be applied to compound expressions.
* Example: SELECT compound_column AS a, a.* FROM test_table.
* Example: SELECT compound_column.* FROM test_table.
*
* Example: SELECT * FROM test_table;
* Example: SELECT test_table.* FROM test_table.
* Example: SELECT a.* FROM test_table AS a.
*
* Additionally each matcher can contain transformers, check ColumnTransformers.h.
* In query tree matchers column transformers are represended as ListNode.
*/
enum class MatcherNodeType
{
ASTERISK,
COLUMNS_REGEXP,
COLUMNS_LIST
};
const char * toString(MatcherNodeType matcher_node_type);
class MatcherNode;
using MatcherNodePtr = std::shared_ptr<MatcherNode>;
class MatcherNode final : public IQueryTreeNode
{
public:
/// Variant unqualified asterisk
explicit MatcherNode(ColumnTransformersNodes column_transformers_ = {});
/// Variant qualified asterisk
explicit MatcherNode(Identifier qualified_identifier_, ColumnTransformersNodes column_transformers_ = {});
/// Variant unqualified COLUMNS('regexp')
explicit MatcherNode(std::shared_ptr<re2::RE2> columns_matcher_, ColumnTransformersNodes column_transformers_ = {});
/// Variant qualified COLUMNS('regexp')
explicit MatcherNode(Identifier qualified_identifier_, std::shared_ptr<re2::RE2> columns_matcher_, ColumnTransformersNodes column_transformers_ = {});
/// Variant unqualified COLUMNS(column_name_1, ...)
explicit MatcherNode(Identifiers columns_identifiers_, ColumnTransformersNodes column_transformers_ = {});
/// Variant qualified COLUMNS(column_name_1, ...)
explicit MatcherNode(Identifier qualified_identifier_, Identifiers columns_identifiers_, ColumnTransformersNodes column_transformers_ = {});
/// Get matcher type
MatcherNodeType getMatcherType() const
{
return matcher_type;
}
/// Returns true if matcher is asterisk matcher, false otherwise
bool isAsteriskMatcher() const
{
return matcher_type == MatcherNodeType::ASTERISK;
}
/// Returns true if matcher is columns regexp or columns list matcher, false otherwise
bool isColumnsMatcher() const
{
return matcher_type == MatcherNodeType::COLUMNS_REGEXP || matcher_type == MatcherNodeType::COLUMNS_LIST;
}
/// Returns true if matcher is qualified, false otherwise
bool isQualified() const
{
return !qualified_identifier.empty();
}
/// Returns true if matcher is not qualified, false otherwise
bool isUnqualified() const
{
return qualified_identifier.empty();
}
/// Get qualified identifier
const Identifier & getQualifiedIdentifier() const
{
return qualified_identifier;
}
/// Get columns matcher. Valid only if this matcher has type COLUMNS_REGEXP.
const std::shared_ptr<re2::RE2> & getColumnsMatcher() const
{
return columns_matcher;
}
/// Get columns identifiers. Valid only if this matcher has type COLUMNS_LIST.
const Identifiers & getColumnsIdentifiers() const
{
return columns_identifiers;
}
/// Get column transformers
const ListNode & getColumnTransformers() const
{
return children[column_transformers_child_index]->as<const ListNode &>();
}
/// Get column transformers
const QueryTreeNodePtr & getColumnTransformersNode() const
{
return children[column_transformers_child_index];
}
/// Returns true if matcher match column name, false otherwise
bool isMatchingColumn(const std::string & column_name);
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::MATCHER;
}
String getName() const override;
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & hash_state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
explicit MatcherNode(MatcherNodeType matcher_type_,
Identifier qualified_identifier_,
Identifiers columns_identifiers_,
std::shared_ptr<re2::RE2> columns_matcher_,
ColumnTransformersNodes column_transformers_);
MatcherNodeType matcher_type;
Identifier qualified_identifier;
Identifiers columns_identifiers;
std::shared_ptr<re2::RE2> columns_matcher;
std::unordered_set<std::string> columns_identifiers_set;
static constexpr size_t column_transformers_child_index = 0;
static constexpr size_t children_size = column_transformers_child_index + 1;
};
}

View File

@ -0,0 +1,170 @@
#include <Analyzer/Passes/AggregateFunctionsArithmericOperationsPass.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Functions/IFunction.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_TYPE_OF_FIELD;
}
namespace
{
Field zeroField(const Field & value)
{
switch (value.getType())
{
case Field::Types::UInt64: return static_cast<UInt64>(0);
case Field::Types::Int64: return static_cast<Int64>(0);
case Field::Types::Float64: return static_cast<Float64>(0);
case Field::Types::UInt128: return static_cast<UInt128>(0);
case Field::Types::Int128: return static_cast<Int128>(0);
case Field::Types::UInt256: return static_cast<UInt256>(0);
case Field::Types::Int256: return static_cast<Int256>(0);
default:
break;
}
throw Exception(ErrorCodes::BAD_TYPE_OF_FIELD, "Unexpected literal type in function");
}
/** Rewrites: sum([multiply|divide]) -> [multiply|divide](sum)
* [min|max|avg]([multiply|divide|plus|minus]) -> [multiply|divide|plus|minus]([min|max|avg])
*
* TODO: Support `groupBitAnd`, `groupBitOr`, `groupBitXor` functions.
* TODO: Support rewrite `f((2 * n) * n)` into '2 * f(n * n)'.
*/
class AggregateFunctionsArithmericOperationsVisitor : public InDepthQueryTreeVisitor<AggregateFunctionsArithmericOperationsVisitor>
{
public:
/// Traverse tree bottom to top
static bool shouldTraverseTopToBottom()
{
return false;
}
static void visitImpl(QueryTreeNodePtr & node)
{
auto * aggregate_function_node = node->as<FunctionNode>();
if (!aggregate_function_node || !aggregate_function_node->isAggregateFunction())
return;
static std::unordered_map<std::string_view, std::unordered_set<std::string_view>> supported_functions
= {{"sum", {"multiply", "divide"}},
{"min", {"multiply", "divide", "plus", "minus"}},
{"max", {"multiply", "divide", "plus", "minus"}},
{"avg", {"multiply", "divide", "plus", "minus"}}};
auto & aggregate_function_arguments_nodes = aggregate_function_node->getArguments().getNodes();
if (aggregate_function_arguments_nodes.size() != 1)
return;
auto * inner_function_node = aggregate_function_arguments_nodes[0]->as<FunctionNode>();
if (!inner_function_node)
return;
auto & inner_function_arguments_nodes = inner_function_node->getArguments().getNodes();
if (inner_function_arguments_nodes.size() != 2)
return;
/// Aggregate functions[sum|min|max|avg] is case-insensitive, so we use lower cases name
auto lower_function_name = Poco::toLower(aggregate_function_node->getFunctionName());
auto supported_function_it = supported_functions.find(lower_function_name);
if (supported_function_it == supported_functions.end())
return;
const auto & inner_function_name = inner_function_node->getFunctionName();
if (!supported_function_it->second.contains(inner_function_name))
return;
auto left_argument_constant_value = inner_function_arguments_nodes[0]->getConstantValueOrNull();
auto right_argument_constant_value = inner_function_arguments_nodes[1]->getConstantValueOrNull();
/** If we extract negative constant, aggregate function name must be updated.
*
* Example: SELECT min(-1 * id);
* Result: SELECT -1 * max(id);
*/
std::string function_name_if_constant_is_negative;
if (inner_function_name == "multiply" || inner_function_name == "divide")
{
if (lower_function_name == "min")
function_name_if_constant_is_negative = "max";
else if (lower_function_name == "max")
function_name_if_constant_is_negative = "min";
}
if (left_argument_constant_value && !right_argument_constant_value)
{
/// Do not rewrite `sum(1/n)` with `sum(1) * div(1/n)` because of lose accuracy
if (inner_function_name == "divide")
return;
/// Rewrite `aggregate_function(inner_function(constant, argument))` into `inner_function(constant, aggregate_function(argument))`
const auto & left_argument_constant_value_literal = left_argument_constant_value->getValue();
if (!function_name_if_constant_is_negative.empty() &&
left_argument_constant_value_literal < zeroField(left_argument_constant_value_literal))
{
resolveAggregateFunctionNode(*aggregate_function_node, function_name_if_constant_is_negative);
}
auto inner_function = aggregate_function_arguments_nodes[0];
auto inner_function_right_argument = std::move(inner_function_arguments_nodes[1]);
aggregate_function_arguments_nodes = {inner_function_right_argument};
inner_function_arguments_nodes[1] = node;
node = std::move(inner_function);
}
else if (right_argument_constant_value)
{
/// Rewrite `aggregate_function(inner_function(argument, constant))` into `inner_function(aggregate_function(argument), constant)`
const auto & right_argument_constant_value_literal = right_argument_constant_value->getValue();
if (!function_name_if_constant_is_negative.empty() &&
right_argument_constant_value_literal < zeroField(right_argument_constant_value_literal))
{
resolveAggregateFunctionNode(*aggregate_function_node, function_name_if_constant_is_negative);
}
auto inner_function = aggregate_function_arguments_nodes[0];
auto inner_function_left_argument = std::move(inner_function_arguments_nodes[0]);
aggregate_function_arguments_nodes = {inner_function_left_argument};
inner_function_arguments_nodes[0] = node;
node = std::move(inner_function);
}
}
private:
static inline void resolveAggregateFunctionNode(FunctionNode & function_node, const String & aggregate_function_name)
{
auto function_result_type = function_node.getResultType();
auto function_aggregate_function = function_node.getAggregateFunction();
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name,
function_aggregate_function->getArgumentTypes(),
function_aggregate_function->getParameters(),
properties);
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
}
};
}
void AggregateFunctionsArithmericOperationsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr)
{
AggregateFunctionsArithmericOperationsVisitor visitor;
visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,24 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Extract arithmeric operations from aggregate functions.
*
* Example: SELECT sum(a * 2);
* Result: SELECT sum(a) * 2;
*/
class AggregateFunctionsArithmericOperationsPass final : public IQueryTreePass
{
public:
String getName() override { return "AggregateFunctionsArithmericOperations"; }
String getDescription() override { return "Extract arithmeric operations from aggregate functions."; }
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -0,0 +1,85 @@
#include <Analyzer/Passes/CountDistinctPass.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/ColumnNode.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/QueryNode.h>
namespace DB
{
namespace
{
class CountDistinctVisitor : public InDepthQueryTreeVisitor<CountDistinctVisitor>
{
public:
static void visitImpl(QueryTreeNodePtr & node)
{
auto * query_node = node->as<QueryNode>();
/// Check that query has only SELECT clause
if (!query_node || (query_node->hasWith() || query_node->hasPrewhere() || query_node->hasWhere() || query_node->hasGroupBy() ||
query_node->hasHaving() || query_node->hasWindow() || query_node->hasOrderBy() || query_node->hasLimitByLimit() || query_node->hasLimitByOffset() ||
query_node->hasLimitBy() || query_node->hasLimit() || query_node->hasOffset()))
return;
/// Check that query has only single table expression
auto join_tree_node_type = query_node->getJoinTree()->getNodeType();
if (join_tree_node_type == QueryTreeNodeType::JOIN || join_tree_node_type == QueryTreeNodeType::ARRAY_JOIN)
return;
/// Check that query has only single node in projection
auto & projection_nodes = query_node->getProjection().getNodes();
if (projection_nodes.size() != 1)
return;
/// Check that query single projection node is `countDistinct` function
auto & projection_node = projection_nodes[0];
auto * function_node = projection_node->as<FunctionNode>();
if (!function_node)
return;
auto lower_function_name = Poco::toLower(function_node->getFunctionName());
if (lower_function_name != "countdistinct" && lower_function_name != "uniqexact")
return;
/// Check that `countDistinct` function has single COLUMN argument
auto & count_distinct_arguments_nodes = function_node->getArguments().getNodes();
if (count_distinct_arguments_nodes.size() != 1 && count_distinct_arguments_nodes[0]->getNodeType() != QueryTreeNodeType::COLUMN)
return;
auto & count_distinct_argument_column = count_distinct_arguments_nodes[0];
auto & count_distinct_argument_column_typed = count_distinct_argument_column->as<ColumnNode &>();
/// Build subquery SELECT count_distinct_argument_column FROM table_expression GROUP BY count_distinct_argument_column
auto subquery = std::make_shared<QueryNode>();
subquery->getJoinTree() = query_node->getJoinTree();
subquery->getProjection().getNodes().push_back(count_distinct_argument_column);
subquery->getGroupBy().getNodes().push_back(count_distinct_argument_column);
subquery->resolveProjectionColumns({count_distinct_argument_column_typed.getColumn()});
/// Put subquery into JOIN TREE of initial query
query_node->getJoinTree() = std::move(subquery);
/// Replace `countDistinct` of initial query into `count`
auto result_type = function_node->getResultType();
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get("count", {}, {}, properties);
function_node->resolveAsAggregateFunction(std::move(aggregate_function), std::move(result_type));
function_node->getArguments().getNodes().clear();
}
};
}
void CountDistinctPass::run(QueryTreeNodePtr query_tree_node, ContextPtr)
{
CountDistinctVisitor visitor;
visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,27 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Optimize single `countDistinct` into `count` over subquery.
*
* Example: SELECT countDistinct(column) FROM table;
* Result: SELECT count() FROM (SELECT column FROM table GROUP BY column);
*/
class CountDistinctPass final : public IQueryTreePass
{
public:
String getName() override { return "CountDistinct"; }
String getDescription() override
{
return "Optimize single countDistinct into count over subquery";
}
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -0,0 +1,175 @@
#include <Analyzer/Passes/CustomizeFunctionsPass.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Functions/FunctionFactory.h>
#include <Interpreters/Context.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
namespace DB
{
namespace
{
class CustomizeFunctionsVisitor : public InDepthQueryTreeVisitor<CustomizeFunctionsVisitor>
{
public:
explicit CustomizeFunctionsVisitor(ContextPtr & context_)
: context(context_)
{}
void visitImpl(QueryTreeNodePtr & node) const
{
auto * function_node = node->as<FunctionNode>();
if (!function_node)
return;
const auto & settings = context->getSettingsRef();
/// After successful function replacement function name and function name lowercase must be recalculated
auto function_name = function_node->getFunctionName();
auto function_name_lowercase = Poco::toLower(function_name);
if (function_node->isAggregateFunction() || function_node->isWindowFunction())
{
auto count_distinct_implementation_function_name = String(settings.count_distinct_implementation);
/// Replace countDistinct with countDistinct implementation
if (function_name_lowercase == "countdistinct")
{
resolveAggregateOrWindowFunctionNode(*function_node, count_distinct_implementation_function_name);
function_name = function_node->getFunctionName();
function_name_lowercase = Poco::toLower(function_name);
}
/// Replace countIfDistinct with countDistinctIf implementation
if (function_name_lowercase == "countifdistinct")
{
resolveAggregateOrWindowFunctionNode(*function_node, count_distinct_implementation_function_name + "If");
function_name = function_node->getFunctionName();
function_name_lowercase = Poco::toLower(function_name);
}
/// Replace aggregateFunctionIfDistinct into aggregateFunctionDistinctIf to make execution more optimal
if (function_name_lowercase.ends_with("ifdistinct"))
{
size_t prefix_length = function_name_lowercase.size() - strlen("ifdistinct");
auto updated_function_name = function_name_lowercase.substr(0, prefix_length) + "DistinctIf";
resolveAggregateOrWindowFunctionNode(*function_node, updated_function_name);
function_name = function_node->getFunctionName();
function_name_lowercase = Poco::toLower(function_name);
}
/// Rewrite all aggregate functions to add -OrNull suffix to them
if (settings.aggregate_functions_null_for_empty && !function_name.ends_with("OrNull"))
{
auto function_properies = AggregateFunctionFactory::instance().tryGetProperties(function_name);
if (function_properies && !function_properies->returns_default_when_only_null)
{
auto updated_function_name = function_name + "OrNull";
resolveAggregateOrWindowFunctionNode(*function_node, updated_function_name);
function_name = function_node->getFunctionName();
function_name_lowercase = Poco::toLower(function_name);
}
}
/** Move -OrNull suffix ahead, this should execute after add -OrNull suffix.
* Used to rewrite aggregate functions with -OrNull suffix in some cases.
* Example: sumIfOrNull.
* Result: sumOrNullIf.
*/
if (function_name.ends_with("OrNull"))
{
auto function_properies = AggregateFunctionFactory::instance().tryGetProperties(function_name);
if (function_properies && !function_properies->returns_default_when_only_null)
{
size_t function_name_size = function_name.size();
static constexpr std::array<std::string_view, 4> suffixes_to_replace = {"MergeState", "Merge", "State", "If"};
for (const auto & suffix : suffixes_to_replace)
{
auto suffix_string_value = String(suffix);
auto suffix_to_check = suffix_string_value + "OrNull";
if (!function_name.ends_with(suffix_to_check))
continue;
auto updated_function_name = function_name.substr(0, function_name_size - suffix_to_check.size()) + "OrNull" + suffix_string_value;
resolveAggregateOrWindowFunctionNode(*function_node, updated_function_name);
function_name = function_node->getFunctionName();
function_name_lowercase = Poco::toLower(function_name);
break;
}
}
}
return;
}
if (settings.transform_null_in)
{
auto function_result_type = function_node->getResultType();
static constexpr std::array<std::pair<std::string_view, std::string_view>, 4> in_function_to_replace_null_in_function_map =
{{
{"in", "nullIn"},
{"notin", "notNullIn"},
{"globalin", "globalNullIn"},
{"globalnotin", "globalNotNullIn"},
}};
for (const auto & [in_function_name, in_function_name_to_replace] : in_function_to_replace_null_in_function_map)
{
if (function_name_lowercase == in_function_name)
{
resolveOrdinaryFunctionNode(*function_node, String(in_function_name_to_replace));
function_name = function_node->getFunctionName();
function_name_lowercase = Poco::toLower(function_name);
break;
}
}
}
}
static inline void resolveAggregateOrWindowFunctionNode(FunctionNode & function_node, const String & aggregate_function_name)
{
auto function_result_type = function_node.getResultType();
auto function_aggregate_function = function_node.getAggregateFunction();
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name,
function_aggregate_function->getArgumentTypes(),
function_aggregate_function->getParameters(),
properties);
if (function_node.isAggregateFunction())
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
else if (function_node.isWindowFunction())
function_node.resolveAsWindowFunction(std::move(aggregate_function), std::move(function_result_type));
}
inline void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const
{
auto function_result_type = function_node.getResultType();
auto function = FunctionFactory::instance().get(function_name, context);
function_node.resolveAsFunction(function, std::move(function_result_type));
}
private:
ContextPtr & context;
};
}
void CustomizeFunctionsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{
CustomizeFunctionsVisitor visitor(context);
visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,25 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Customize aggregate functions and `in` functions implementations.
*
* Example: SELECT countDistinct();
* Result: SELECT countDistinctImplementation();
* Function countDistinctImplementation is taken from settings.count_distinct_implementation.
*/
class CustomizeFunctionsPass final : public IQueryTreePass
{
public:
String getName() override { return "CustomizeFunctions"; }
String getDescription() override { return "Customize implementation of aggregate functions, and in functions."; }
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -0,0 +1,211 @@
#include <Analyzer/Passes/FunctionToSubcolumnsPass.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeMap.h>
#include <Storages/IStorage.h>
#include <Functions/FunctionFactory.h>
#include <Interpreters/Context.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/ColumnNode.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/TableNode.h>
namespace DB
{
namespace
{
class FunctionToSubcolumnsVisitor : public InDepthQueryTreeVisitor<FunctionToSubcolumnsVisitor>
{
public:
explicit FunctionToSubcolumnsVisitor(ContextPtr & context_)
: context(context_)
{}
void visitImpl(QueryTreeNodePtr & node) const
{
auto * function_node = node->as<FunctionNode>();
if (!function_node)
return;
auto & function_arguments_nodes = function_node->getArguments().getNodes();
size_t function_arguments_nodes_size = function_arguments_nodes.size();
if (function_arguments_nodes.empty() || function_arguments_nodes_size > 2)
return;
auto * first_argument_column_node = function_arguments_nodes.front()->as<ColumnNode>();
if (!first_argument_column_node)
return;
auto column_source = first_argument_column_node->getColumnSource();
auto * table_node = column_source->as<TableNode>();
if (!table_node)
return;
const auto & storage = table_node->getStorage();
if (!storage->supportsSubcolumns())
return;
auto column = first_argument_column_node->getColumn();
WhichDataType column_type(column.type);
const auto & function_name = function_node->getFunctionName();
if (function_arguments_nodes_size == 1)
{
if (column_type.isArray())
{
if (function_name == "length")
{
/// Replace `length(array_argument)` with `array_argument.size0`
column.name += ".size0";
node = std::make_shared<ColumnNode>(column, column_source);
}
else if (function_name == "empty")
{
/// Replace `empty(array_argument)` with `equals(array_argument.size0, 0)`
column.name += ".size0";
column.type = std::make_shared<DataTypeUInt64>();
resolveOrdinaryFunctionNode(*function_node, "equals");
function_arguments_nodes.clear();
function_arguments_nodes.push_back(std::make_shared<ColumnNode>(column, column_source));
function_arguments_nodes.push_back(std::make_shared<ConstantNode>(static_cast<UInt64>(0)));
}
else if (function_name == "notEmpty")
{
/// Replace `notEmpty(array_argument)` with `notEquals(array_argument.size0, 0)`
column.name += ".size0";
column.type = std::make_shared<DataTypeUInt64>();
resolveOrdinaryFunctionNode(*function_node, "notEquals");
function_arguments_nodes.clear();
function_arguments_nodes.push_back(std::make_shared<ColumnNode>(column, column_source));
function_arguments_nodes.push_back(std::make_shared<ConstantNode>(static_cast<UInt64>(0)));
}
}
else if (column_type.isNullable())
{
if (function_name == "isNull")
{
/// Replace `isNull(nullable_argument)` with `nullable_argument.null`
column.name += ".null";
node = std::make_shared<ColumnNode>(column, column_source);
}
else if (function_name == "isNotNull")
{
/// Replace `isNotNull(nullable_argument)` with `not(nullable_argument.null)`
column.name += ".null";
column.type = std::make_shared<DataTypeUInt8>();
resolveOrdinaryFunctionNode(*function_node, "not");
function_arguments_nodes = {std::make_shared<ColumnNode>(column, column_source)};
}
}
else if (column_type.isMap())
{
if (function_name == "mapKeys")
{
/// Replace `mapKeys(map_argument)` with `map_argument.keys`
column.name += ".keys";
column.type = function_node->getResultType();
node = std::make_shared<ColumnNode>(column, column_source);
}
else if (function_name == "mapValues")
{
/// Replace `mapValues(map_argument)` with `map_argument.values`
column.name += ".values";
column.type = function_node->getResultType();
node = std::make_shared<ColumnNode>(column, column_source);
}
}
}
else
{
auto second_argument_constant_value = function_arguments_nodes[1]->getConstantValueOrNull();
if (function_name == "tupleElement" && column_type.isTuple() && second_argument_constant_value)
{
/** Replace `tupleElement(tuple_argument, string_literal)`, `tupleElement(tuple_argument, integer_literal)`
* with `tuple_argument.column_name`.
*/
const auto & tuple_element_constant_value = second_argument_constant_value->getValue();
const auto & tuple_element_constant_value_type = tuple_element_constant_value.getType();
const auto & data_type_tuple = assert_cast<const DataTypeTuple &>(*column.type);
String subcolumn_name;
if (tuple_element_constant_value_type == Field::Types::String)
{
subcolumn_name = tuple_element_constant_value.get<const String &>();
}
else if (tuple_element_constant_value_type == Field::Types::UInt64)
{
auto tuple_column_index = tuple_element_constant_value.get<UInt64>();
subcolumn_name = data_type_tuple.getNameByPosition(tuple_column_index);
}
else
{
return;
}
column.name += '.';
column.name += subcolumn_name;
column.type = function_node->getResultType();
node = std::make_shared<ColumnNode>(column, column_source);
}
else if (function_name == "mapContains" && column_type.isMap())
{
const auto & data_type_map = assert_cast<const DataTypeMap &>(*column.type);
/// Replace `mapContains(map_argument, argument)` with `has(map_argument.keys, argument)`
column.name += ".keys";
column.type = data_type_map.getKeyType();
auto has_function_argument = std::make_shared<ColumnNode>(column, column_source);
resolveOrdinaryFunctionNode(*function_node, "has");
function_arguments_nodes[0] = std::move(has_function_argument);
}
}
}
private:
inline void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const
{
auto function_result_type = function_node.getResultType();
auto function = FunctionFactory::instance().get(function_name, context);
function_node.resolveAsFunction(function, std::move(function_result_type));
}
ContextPtr & context;
};
}
void FunctionToSubcolumnsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{
FunctionToSubcolumnsVisitor visitor(context);
visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,31 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Transform functions to subcolumns.
* It can help to reduce amount of read data.
*
* Example: SELECT tupleElement(column, subcolumn) FROM test_table;
* Result: SELECT column.subcolumn FROM test_table;
*
* Example: SELECT length(array_column) FROM test_table;
* Result: SELECT array_column.size0 FROM test_table;
*
* Example: SELECT nullable_column IS NULL FROM test_table;
* Result: SELECT nullable_column.null FROM test_table;
*/
class FunctionToSubcolumnsPass final : public IQueryTreePass
{
public:
String getName() override { return "FunctionToSubcolumns"; }
String getDescription() override { return "Rewrite function to subcolumns, for example tupleElement(column, subcolumn) into column.subcolumn"; }
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -0,0 +1,75 @@
#include <Analyzer/Passes/IfChainToMultiIfPass.h>
#include <DataTypes/DataTypesNumber.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
#include <Functions/FunctionFactory.h>
namespace DB
{
namespace
{
class IfChainToMultiIfPassVisitor : public InDepthQueryTreeVisitor<IfChainToMultiIfPassVisitor>
{
public:
explicit IfChainToMultiIfPassVisitor(FunctionOverloadResolverPtr multi_if_function_ptr_)
: multi_if_function_ptr(std::move(multi_if_function_ptr_))
{}
void visitImpl(QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
if (!function_node || function_node->getFunctionName() != "if" || function_node->getArguments().getNodes().size() != 3)
return;
std::vector<QueryTreeNodePtr> multi_if_arguments;
auto & function_node_arguments = function_node->getArguments().getNodes();
multi_if_arguments.insert(multi_if_arguments.end(), function_node_arguments.begin(), function_node_arguments.end());
QueryTreeNodePtr if_chain_node = multi_if_arguments.back();
while (true)
{
/// Check if last `multiIf` argument is `if` function
auto * if_chain_function_node = if_chain_node->as<FunctionNode>();
if (!if_chain_function_node || if_chain_function_node->getFunctionName() != "if" || if_chain_function_node->getArguments().getNodes().size() != 3)
break;
/// Replace last `multiIf` argument with `if` function arguments
multi_if_arguments.pop_back();
auto & if_chain_function_node_arguments = if_chain_function_node->getArguments().getNodes();
multi_if_arguments.insert(multi_if_arguments.end(), if_chain_function_node_arguments.begin(), if_chain_function_node_arguments.end());
/// Use last `multiIf` argument for next check
if_chain_node = multi_if_arguments.back();
}
/// Do not replace `if` with 3 arguments to `multiIf`
if (multi_if_arguments.size() <= 3)
return;
auto multi_if_function = std::make_shared<FunctionNode>("multiIf");
multi_if_function->resolveAsFunction(multi_if_function_ptr, std::make_shared<DataTypeUInt8>());
multi_if_function->getArguments().getNodes() = std::move(multi_if_arguments);
node = std::move(multi_if_function);
}
private:
FunctionOverloadResolverPtr multi_if_function_ptr;
};
}
void IfChainToMultiIfPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{
IfChainToMultiIfPassVisitor visitor(FunctionFactory::instance().get("multiIf", context));
visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,25 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Convert `if` chain into single `multiIf`.
* Replace if(cond_1, then_1_value, if(cond_2, ...)) chains into multiIf(cond_1, then_1_value, cond_2, ...).
*
* Example: SELECT if(cond_1, then_1_value, if(cond_2, then_2_value, else_value));
* Result: SELECT multiIf(cond_1, then_1_value, cond_2, then_2_value, else_value);
*/
class IfChainToMultiIfPass final : public IQueryTreePass
{
public:
String getName() override { return "IfChainToMultiIf"; }
String getDescription() override { return "Optimize if chain to multiIf"; }
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -0,0 +1,56 @@
#include <Analyzer/Passes/IfConstantConditionPass.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
#include <Functions/FunctionFactory.h>
namespace DB
{
namespace
{
class IfConstantConditionVisitor : public InDepthQueryTreeVisitor<IfConstantConditionVisitor>
{
public:
static void visitImpl(QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
if (!function_node || (function_node->getFunctionName() != "if" && function_node->getFunctionName() != "multiIf"))
return;
if (function_node->getArguments().getNodes().size() != 3)
return;
auto & first_argument = function_node->getArguments().getNodes()[0];
auto first_argument_constant_value = first_argument->getConstantValueOrNull();
if (!first_argument_constant_value)
return;
const auto & condition_value = first_argument_constant_value->getValue();
bool condition_boolean_value = false;
if (condition_value.getType() == Field::Types::Int64)
condition_boolean_value = static_cast<bool>(condition_value.safeGet<Int64>());
else if (condition_value.getType() == Field::Types::UInt64)
condition_boolean_value = static_cast<bool>(condition_value.safeGet<UInt64>());
else
return;
if (condition_boolean_value)
node = function_node->getArguments().getNodes()[1];
else
node = function_node->getArguments().getNodes()[2];
}
};
}
void IfConstantConditionPass::run(QueryTreeNodePtr query_tree_node, ContextPtr)
{
IfConstantConditionVisitor visitor;
visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,28 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Convert `if` with constant condition or `multiIf` with single constant condition into true condition argument value
* or false condition argument value.
*
* Example: SELECT if(1, true_value, false_value);
* Result: SELECT true_value;
*
* Example: SELECT if(0, true_value, false_value);
* Result: SELECT false_value;
*/
class IfConstantConditionPass final : public IQueryTreePass
{
public:
String getName() override { return "IfConstantCondition"; }
String getDescription() override { return "Optimize if, multiIf for constant condition."; }
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -0,0 +1,45 @@
#include <Analyzer/Passes/MultiIfToIfPass.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
#include <Functions/FunctionFactory.h>
namespace DB
{
namespace
{
class MultiIfToIfVisitor : public InDepthQueryTreeVisitor<MultiIfToIfVisitor>
{
public:
explicit MultiIfToIfVisitor(FunctionOverloadResolverPtr if_function_ptr_)
: if_function_ptr(if_function_ptr_)
{}
void visitImpl(QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
if (!function_node || function_node->getFunctionName() != "multiIf")
return;
if (function_node->getArguments().getNodes().size() != 3)
return;
auto result_type = function_node->getResultType();
function_node->resolveAsFunction(if_function_ptr, std::move(result_type));
}
private:
FunctionOverloadResolverPtr if_function_ptr;
};
}
void MultiIfToIfPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{
MultiIfToIfVisitor visitor(FunctionFactory::instance().get("if", context));
visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,24 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Convert `multiIf` with single condition into `if`.
*
* Example: SELECT multiIf(x, 1, 0);
* Result: SELECT if(x, 1, 0);
*/
class MultiIfToIfPass final : public IQueryTreePass
{
public:
String getName() override { return "MultiIfToIf"; }
String getDescription() override { return "Optimize multiIf with single condition to if."; }
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -0,0 +1,58 @@
#include <Analyzer/Passes/NormalizeCountVariantsPass.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
namespace DB
{
namespace
{
class NormalizeCountVariantsVisitor : public InDepthQueryTreeVisitor<NormalizeCountVariantsVisitor>
{
public:
static void visitImpl(QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
if (!function_node || !function_node->isAggregateFunction() || (function_node->getFunctionName() != "count" && function_node->getFunctionName() != "sum"))
return;
if (function_node->getArguments().getNodes().size() != 1)
return;
auto & first_argument = function_node->getArguments().getNodes()[0];
auto first_argument_constant_value = first_argument->getConstantValueOrNull();
if (!first_argument_constant_value)
return;
const auto & first_argument_constant_literal = first_argument_constant_value->getValue();
if (function_node->getFunctionName() == "count" && !first_argument_constant_literal.isNull())
{
function_node->getArguments().getNodes().clear();
}
else if (function_node->getFunctionName() == "sum" && first_argument_constant_literal.getType() == Field::Types::UInt64 &&
first_argument_constant_literal.get<UInt64>() == 1)
{
auto result_type = function_node->getResultType();
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get("count", {}, {}, properties);
function_node->resolveAsAggregateFunction(std::move(aggregate_function), std::move(result_type));
function_node->getArguments().getNodes().clear();
}
}
};
}
void NormalizeCountVariantsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr)
{
NormalizeCountVariantsVisitor visitor;
visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,27 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Remove single literal argument from `count`. Convert `sum` with single `1` literal argument into `count`.
*
* Example: SELECT count(1);
* Result: SELECT count();
*
* Example: SELECT sum(1);
* Result: SELECT count();
*/
class NormalizeCountVariantsPass final : public IQueryTreePass
{
public:
String getName() override { return "NormalizeCountVariants"; }
String getDescription() override { return "Optimize count(literal), sum(1) into count()."; }
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -0,0 +1,97 @@
#include <Analyzer/Passes/OrderByLimitByDuplicateEliminationPass.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/QueryNode.h>
#include <Analyzer/SortNode.h>
namespace DB
{
namespace
{
struct QueryTreeNodeHash
{
size_t operator()(const IQueryTreeNode * node) const
{
return node->getTreeHash().first;
}
};
struct QueryTreeNodeEqualTo
{
size_t operator()(const IQueryTreeNode * lhs_node, const IQueryTreeNode * rhs_node) const
{
return lhs_node->isEqual(*rhs_node);
}
};
using QueryTreeNodeSet = std::unordered_set<const IQueryTreeNode *, QueryTreeNodeHash, QueryTreeNodeEqualTo>;
class OrderByLimitByDuplicateEliminationVisitor : public InDepthQueryTreeVisitor<OrderByLimitByDuplicateEliminationVisitor>
{
public:
void visitImpl(QueryTreeNodePtr & node)
{
auto * query_node = node->as<QueryNode>();
if (!query_node)
return;
if (query_node->hasOrderBy())
{
QueryTreeNodes result_nodes;
auto & query_order_by_nodes = query_node->getOrderBy().getNodes();
for (auto & sort_node : query_order_by_nodes)
{
auto & sort_node_typed = sort_node->as<SortNode &>();
/// Skip elements with WITH FILL
if (sort_node_typed.withFill())
{
result_nodes.push_back(sort_node);
continue;
}
auto [_, inserted] = unique_expressions_nodes_set.emplace(sort_node_typed.getExpression().get());
if (inserted)
result_nodes.push_back(sort_node);
}
query_order_by_nodes = std::move(result_nodes);
}
unique_expressions_nodes_set.clear();
if (query_node->hasLimitBy())
{
QueryTreeNodes result_nodes;
auto & query_limit_by_nodes = query_node->getLimitBy().getNodes();
for (auto & limit_by_node : query_limit_by_nodes)
{
auto [_, inserted] = unique_expressions_nodes_set.emplace(limit_by_node.get());
if (inserted)
result_nodes.push_back(limit_by_node);
}
query_limit_by_nodes = std::move(result_nodes);
}
}
private:
QueryTreeNodeSet unique_expressions_nodes_set;
};
}
void OrderByLimitByDuplicateEliminationPass::run(QueryTreeNodePtr query_tree_node, ContextPtr)
{
OrderByLimitByDuplicateEliminationVisitor visitor;
visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,27 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Eliminate duplicate columns from ORDER BY and LIMIT BY.
*
* Example: SELECT * FROM test_table ORDER BY id, id;
* Result: SELECT * FROM test_table ORDER BY id;
*
* Example: SELECT * FROM test_table LIMIT 5 BY id, id;
* Result: SELECT * FROM test_table LIMIT 5 BY id;
*/
class OrderByLimitByDuplicateEliminationPass final : public IQueryTreePass
{
public:
String getName() override { return "OrderByLimitByDuplicateElimination"; }
String getDescription() override { return "Remove duplicate columns from ORDER BY, LIMIT BY."; }
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -0,0 +1,59 @@
#include <Analyzer/Passes/OrderByTupleEliminationPass.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/QueryNode.h>
#include <Analyzer/SortNode.h>
#include <Functions/FunctionFactory.h>
namespace DB
{
namespace
{
class OrderByTupleEliminationVisitor : public InDepthQueryTreeVisitor<OrderByTupleEliminationVisitor>
{
public:
static void visitImpl(QueryTreeNodePtr & node)
{
auto * query_node = node->as<QueryNode>();
if (!query_node || !query_node->hasOrderBy())
return;
QueryTreeNodes result_nodes;
for (auto & sort_node : query_node->getOrderBy().getNodes())
{
auto & sort_node_typed = sort_node->as<SortNode &>();
auto * function_expression = sort_node_typed.getExpression()->as<FunctionNode>();
if (sort_node_typed.withFill() || !function_expression || function_expression->getFunctionName() != "tuple")
{
result_nodes.push_back(sort_node);
continue;
}
auto & tuple_arguments_nodes = function_expression->getArguments().getNodes();
for (auto & argument_node : tuple_arguments_nodes)
{
auto result_sort_node = std::make_shared<SortNode>(argument_node,
sort_node_typed.getSortDirection(),
sort_node_typed.getNullsSortDirection(),
sort_node_typed.getCollator());
result_nodes.push_back(std::move(result_sort_node));
}
}
query_node->getOrderBy().getNodes() = std::move(result_nodes);
}
};
}
void OrderByTupleEliminationPass::run(QueryTreeNodePtr query_tree_node, ContextPtr)
{
OrderByTupleEliminationVisitor visitor;
visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,24 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Eliminate tuples from ORDER BY.
*
* Example: SELECT * FROM test_table ORDER BY (a, b);
* Result: SELECT * FROM test_table ORDER BY a, b;
*/
class OrderByTupleEliminationPass final : public IQueryTreePass
{
public:
String getName() override { return "OrderByTupleElimination"; }
String getDescription() override { return "Remove tuple from ORDER BY."; }
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,96 @@
#pragma once
#include <Parsers/IAST_fwd.h>
#include <Interpreters/Context_fwd.h>
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** This pass make initial query analysis.
*
* 1. All identifiers are resolved. Next passes can expect that there will be no IdentifierNode in query tree.
* 2. All matchers are resolved. Next passes can expect that there will be no MatcherNode in query tree.
* 3. All functions are resolved. Next passes can expect that for each FunctionNode its result type will be set, and it will be resolved
* as aggregate or non aggregate function.
* 4. All lambda expressions that are function arguments are resolved. Next passes can expect that LambaNode expression is resolved, and lambda has concrete arguments.
* 5. All standalone lambda expressions are resolved. Next passes can expect that there will be no standalone LambaNode expressions in query.
* 6. Constants are folded. Example: SELECT plus(1, 1).
* Motivation for this, there are places in query tree that must contain constant:
* Function parameters. Example: SELECT quantile(0.5)(x).
* Functions in which result type depends on constant expression argument. Example: cast(x, 'type_name').
* Expressions that are part of LIMIT BY LIMIT, LIMIT BY OFFSET, LIMIT, OFFSET. Example: SELECT * FROM test_table LIMIT expr.
* Window function window frame OFFSET begin and OFFSET end.
*
* 7. All scalar subqueries are evaluated.
* TODO: Scalar subqueries must be evaluated only if they are part of query tree where we must have constant. This is currently not done
* because execution layer does not support scalar subqueries execution.
*
* 8. For query node.
*
* Projection columns are calculated. Later passes cannot change type, display name of projection column, and cannot add or remove
* columns in projection section.
* WITH and WINDOW sections are removed.
*
* 9. Query is validated. Parts that are validated:
*
* Constness of function parameters.
* Constness of LIMIT and OFFSET.
* Window functions frame. Constness of window functions frame begin OFFSET, end OFFSET.
* In query only columns that are specified in GROUP BY keys after GROUP BY are used.
* GROUPING function arguments are specified in GROUP BY keys.
* No GROUPING function if there is no GROUP BY.
* No aggregate functions in JOIN TREE, WHERE, PREWHERE, GROUP BY and inside another aggregate functions.
* GROUP BY modifiers CUBE, ROLLUP, GROUPING SETS and WITH TOTALS.
* Table expression modifiers are validated for table and table function nodes in JOIN TREE.
* Table expression modifiers are disabled for subqueries in JOIN TREE.
* For JOIN, ARRAY JOIN subqueries and table functions must have alias (Can be changed using joined_subquery_requires_alias setting).
*
* 10. Special functions handling:
* Function `untuple` is handled properly.
* Function `arrayJoin` is handled properly.
* For functions `dictGet` and its variations and for function `joinGet` identifier as first argument is handled properly.
* Function `exists` is converted into `in`.
*
* For function `grouping` arguments are resolved, but it is planner responsibility to initialize it with concrete grouping function
* based on group by kind and group by keys positions.
*
* For function `in` and its variations arguments are resolved, but sets are not build.
* If left and right arguments are constants constant folding is performed.
* If right argument resolved as table, and table is not of type Set, it is replaced with query that read only ordinary columns from underlying
* storage.
* Example: SELECT id FROM test_table WHERE id IN test_table_other;
* Result: SELECT id FROM test_table WHERE id IN (SELECT test_table_column FROM test_table_other);
*/
class QueryAnalysisPass final : public IQueryTreePass
{
public:
/** Construct query analysis pass for query or union analysis.
* Available columns are extracted from query node join tree.
*/
QueryAnalysisPass() = default;
/** Construct query analysis pass for expression or list of expressions analysis.
* Available expression columns are extracted from table expression.
* Table expression node must have query, union, table, table function type.
*/
explicit QueryAnalysisPass(QueryTreeNodePtr table_expression_);
String getName() override
{
return "QueryAnalysis";
}
String getDescription() override
{
return "Resolve type for each query expression. Replace identifiers, matchers with query expressions. Perform constant folding. Evaluate scalar subqueries.";
}
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
private:
QueryTreeNodePtr table_expression;
};
}

View File

@ -0,0 +1,157 @@
#include <Analyzer/Passes/SumIfToCountIfPass.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeNullable.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Functions/FunctionFactory.h>
#include <Interpreters/Context.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
namespace DB
{
namespace
{
class SumIfToCountIfVisitor : public InDepthQueryTreeVisitor<SumIfToCountIfVisitor>
{
public:
explicit SumIfToCountIfVisitor(ContextPtr & context_)
: context(context_)
{}
void visitImpl(QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
if (!function_node || !function_node->isAggregateFunction())
return;
auto function_name = function_node->getFunctionName();
auto lower_function_name = Poco::toLower(function_name);
/// sumIf, SumIf or sUMIf are valid function names, but sumIF or sumiF are not
if (lower_function_name != "sum" && (lower_function_name != "sumif" || !function_name.ends_with("If")))
return;
auto & function_node_arguments_nodes = function_node->getArguments().getNodes();
/// Rewrite `sumIf(1, cond)` into `countIf(cond)`
if (lower_function_name == "sumif")
{
if (function_node_arguments_nodes.size() != 2)
return;
auto constant_value = function_node_arguments_nodes[0]->getConstantValueOrNull();
if (!constant_value)
return;
const auto & constant_value_literal = constant_value->getValue();
if (!isInt64OrUInt64FieldType(constant_value_literal.getType()))
return;
if (constant_value_literal.get<UInt64>() != 1)
return;
function_node_arguments_nodes[0] = std::move(function_node_arguments_nodes[1]);
function_node_arguments_nodes.resize(1);
resolveAggregateFunctionNode(*function_node, "countIf");
return;
}
/** Rewrite `sum(if(cond, 1, 0))` into `countIf(cond)`.
* Rewrite `sum(if(cond, 0, 1))` into `countIf(not(cond))`.
*/
if (function_node_arguments_nodes.size() != 1)
return;
auto & nested_argument = function_node_arguments_nodes[0];
auto * nested_function = nested_argument->as<FunctionNode>();
if (!nested_function || nested_function->getFunctionName() != "if")
return;
auto & nested_if_function_arguments_nodes = nested_function->getArguments().getNodes();
if (nested_if_function_arguments_nodes.size() != 3)
return;
auto if_true_condition_constant_value = nested_if_function_arguments_nodes[1]->getConstantValueOrNull();
auto if_false_condition_constant_value = nested_if_function_arguments_nodes[2]->getConstantValueOrNull();
if (!if_true_condition_constant_value || !if_false_condition_constant_value)
return;
const auto & if_true_condition_constant_value_literal = if_true_condition_constant_value->getValue();
const auto & if_false_condition_constant_value_literal = if_false_condition_constant_value->getValue();
if (!isInt64OrUInt64FieldType(if_true_condition_constant_value_literal.getType()) ||
!isInt64OrUInt64FieldType(if_false_condition_constant_value_literal.getType()))
return;
auto if_true_condition_value = if_true_condition_constant_value_literal.get<UInt64>();
auto if_false_condition_value = if_false_condition_constant_value_literal.get<UInt64>();
/// Rewrite `sum(if(cond, 1, 0))` into `countIf(cond)`.
if (if_true_condition_value == 1 && if_false_condition_value == 0)
{
function_node_arguments_nodes[0] = std::move(nested_if_function_arguments_nodes[0]);
function_node_arguments_nodes.resize(1);
resolveAggregateFunctionNode(*function_node, "countIf");
return;
}
/// Rewrite `sum(if(cond, 0, 1))` into `countIf(not(cond))`.
if (if_true_condition_value == 0 && if_false_condition_value == 1)
{
auto condition_result_type = nested_if_function_arguments_nodes[0]->getResultType();
DataTypePtr not_function_result_type = std::make_shared<DataTypeUInt8>();
if (condition_result_type->isNullable())
not_function_result_type = makeNullable(not_function_result_type);
auto not_function = std::make_shared<FunctionNode>("not");
not_function->resolveAsFunction(FunctionFactory::instance().get("not", context), std::move(not_function_result_type));
auto & not_function_arguments = not_function->getArguments().getNodes();
not_function_arguments.push_back(std::move(nested_if_function_arguments_nodes[0]));
function_node_arguments_nodes[0] = std::move(not_function);
function_node_arguments_nodes.resize(1);
resolveAggregateFunctionNode(*function_node, "countIf");
return;
}
}
private:
static inline void resolveAggregateFunctionNode(FunctionNode & function_node, const String & aggregate_function_name)
{
auto function_result_type = function_node.getResultType();
auto function_aggregate_function = function_node.getAggregateFunction();
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name,
function_aggregate_function->getArgumentTypes(),
function_aggregate_function->getParameters(),
properties);
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
}
ContextPtr & context;
};
}
void SumIfToCountIfPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{
SumIfToCountIfVisitor visitor(context);
visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,30 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Rewrite `sum(if(cond, value_1, value_2))` and `sumIf` functions to `countIf`.
*
* Example: SELECT sumIf(1, cond);
* Result: SELECT countIf(cond);
*
* Example: SELECT sum(if(cond, 1, 0));
* Result: SELECT countIf(cond);
*
* Example: SELECT sum(if(cond, 0, 1));
* Result: SELECT countIf(not(cond));
*/
class SumIfToCountIfPass final : public IQueryTreePass
{
public:
String getName() override { return "SumIfToCountIf"; }
String getDescription() override { return "Rewrite sum(if) and sumIf into countIf"; }
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -0,0 +1,64 @@
#include <Analyzer/Passes/UniqInjectiveFunctionsEliminationPass.h>
#include <Functions/IFunction.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
namespace DB
{
namespace
{
bool isUniqFunction(const String & function_name)
{
return function_name == "uniq" ||
function_name == "uniqExact" ||
function_name == "uniqHLL12" ||
function_name == "uniqCombined" ||
function_name == "uniqCombined64" ||
function_name == "uniqTheta";
}
class UniqInjectiveFunctionsEliminationVisitor : public InDepthQueryTreeVisitor<UniqInjectiveFunctionsEliminationVisitor>
{
public:
static void visitImpl(QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
if (!function_node || !function_node->isAggregateFunction() || !isUniqFunction(function_node->getFunctionName()))
return;
auto & uniq_function_arguments_nodes = function_node->getArguments().getNodes();
for (auto & uniq_function_argument_node : uniq_function_arguments_nodes)
{
auto * uniq_function_argument_node_typed = uniq_function_argument_node->as<FunctionNode>();
if (!uniq_function_argument_node_typed || !uniq_function_argument_node_typed->isOrdinaryFunction())
continue;
auto & uniq_function_argument_node_argument_nodes = uniq_function_argument_node_typed->getArguments().getNodes();
/// Do not apply optimization if injective function contains multiple arguments
if (uniq_function_argument_node_argument_nodes.size() != 1)
continue;
const auto & uniq_function_argument_node_function = uniq_function_argument_node_typed->getFunction();
if (!uniq_function_argument_node_function->isInjective({}))
continue;
/// Replace injective function with its single argument
uniq_function_argument_node = uniq_function_argument_node_argument_nodes[0];
}
}
};
}
void UniqInjectiveFunctionsEliminationPass::run(QueryTreeNodePtr query_tree_node, ContextPtr)
{
UniqInjectiveFunctionsEliminationVisitor visitor;
visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,24 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Remove injective functions from `uniq*` functions arguments.
*
* Example: SELECT uniq(injectiveFunction(argument));
* Result: SELECT uniq(argument);
*/
class UniqInjectiveFunctionsEliminationPass final : public IQueryTreePass
{
public:
String getName() override { return "UniqInjectiveFunctionsElimination"; }
String getDescription() override { return "Remove injective functions from uniq functions arguments."; }
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

456
src/Analyzer/QueryNode.cpp Normal file
View File

@ -0,0 +1,456 @@
#include <Analyzer/QueryNode.h>
#include <Common/SipHash.h>
#include <Common/FieldVisitorToString.h>
#include <Core/NamesAndTypes.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/ASTSubquery.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ASTSetQuery.h>
#include <Analyzer/Utils.h>
namespace DB
{
QueryNode::QueryNode()
: IQueryTreeNode(children_size)
{
children[with_child_index] = std::make_shared<ListNode>();
children[projection_child_index] = std::make_shared<ListNode>();
children[group_by_child_index] = std::make_shared<ListNode>();
children[window_child_index] = std::make_shared<ListNode>();
children[order_by_child_index] = std::make_shared<ListNode>();
children[limit_by_child_index] = std::make_shared<ListNode>();
}
String QueryNode::getName() const
{
WriteBufferFromOwnString buffer;
if (hasWith())
{
buffer << getWith().getName();
buffer << ' ';
}
buffer << "SELECT ";
buffer << getProjection().getName();
if (getJoinTree())
{
buffer << " FROM ";
buffer << getJoinTree()->getName();
}
if (getPrewhere())
{
buffer << " PREWHERE ";
buffer << getPrewhere()->getName();
}
if (getWhere())
{
buffer << " WHERE ";
buffer << getWhere()->getName();
}
if (hasGroupBy())
{
buffer << " GROUP BY ";
buffer << getGroupBy().getName();
}
if (hasHaving())
{
buffer << " HAVING ";
buffer << getHaving()->getName();
}
if (hasWindow())
{
buffer << " WINDOW ";
buffer << getWindow().getName();
}
if (hasOrderBy())
{
buffer << " ORDER BY ";
buffer << getOrderByNode()->getName();
}
if (hasInterpolate())
{
buffer << " INTERPOLATE ";
buffer << getInterpolate()->getName();
}
if (hasLimitByLimit())
{
buffer << "LIMIT ";
buffer << getLimitByLimit()->getName();
}
if (hasLimitByOffset())
{
buffer << "OFFSET ";
buffer << getLimitByOffset()->getName();
}
if (hasLimitBy())
{
buffer << " BY ";
buffer << getLimitBy().getName();
}
if (hasLimit())
{
buffer << " LIMIT ";
buffer << getLimit()->getName();
}
if (hasOffset())
{
buffer << " OFFSET ";
buffer << getOffset()->getName();
}
return buffer.str();
}
void QueryNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "QUERY id: " << format_state.getNodeId(this);
if (hasAlias())
buffer << ", alias: " << getAlias();
if (is_subquery)
buffer << ", is_subquery: " << is_subquery;
if (is_cte)
buffer << ", is_cte: " << is_cte;
if (is_distinct)
buffer << ", is_distinct: " << is_distinct;
if (is_limit_with_ties)
buffer << ", is_limit_with_ties: " << is_limit_with_ties;
if (is_group_by_with_totals)
buffer << ", is_group_by_with_totals: " << is_group_by_with_totals;
std::string group_by_type;
if (is_group_by_with_rollup)
group_by_type = "rollup";
else if (is_group_by_with_cube)
group_by_type = "cube";
else if (is_group_by_with_grouping_sets)
group_by_type = "grouping_sets";
if (!group_by_type.empty())
buffer << ", group_by_type: " << group_by_type;
if (!cte_name.empty())
buffer << ", cte_name: " << cte_name;
if (constant_value)
{
buffer << ", constant_value: " << constant_value->getValue().dump();
buffer << ", constant_value_type: " << constant_value->getType()->getName();
}
if (table_expression_modifiers)
{
buffer << ", ";
table_expression_modifiers->dump(buffer);
}
if (hasWith())
{
buffer << '\n' << std::string(indent + 2, ' ') << "WITH\n";
getWith().dumpTreeImpl(buffer, format_state, indent + 4);
}
if (!projection_columns.empty())
{
buffer << '\n';
buffer << std::string(indent + 2, ' ') << "PROJECTION COLUMNS\n";
size_t projection_columns_size = projection_columns.size();
for (size_t i = 0; i < projection_columns_size; ++i)
{
const auto & projection_column = projection_columns[i];
buffer << std::string(indent + 4, ' ') << projection_column.name << " " << projection_column.type->getName();
if (i + 1 != projection_columns_size)
buffer << '\n';
}
}
buffer << '\n';
buffer << std::string(indent + 2, ' ') << "PROJECTION\n";
getProjection().dumpTreeImpl(buffer, format_state, indent + 4);
if (getJoinTree())
{
buffer << '\n' << std::string(indent + 2, ' ') << "JOIN TREE\n";
getJoinTree()->dumpTreeImpl(buffer, format_state, indent + 4);
}
if (getPrewhere())
{
buffer << '\n' << std::string(indent + 2, ' ') << "PREWHERE\n";
getPrewhere()->dumpTreeImpl(buffer, format_state, indent + 4);
}
if (getWhere())
{
buffer << '\n' << std::string(indent + 2, ' ') << "WHERE\n";
getWhere()->dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasGroupBy())
{
buffer << '\n' << std::string(indent + 2, ' ') << "GROUP BY\n";
getGroupBy().dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasHaving())
{
buffer << '\n' << std::string(indent + 2, ' ') << "HAVING\n";
getHaving()->dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasWindow())
{
buffer << '\n' << std::string(indent + 2, ' ') << "WINDOW\n";
getWindow().dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasOrderBy())
{
buffer << '\n' << std::string(indent + 2, ' ') << "ORDER BY\n";
getOrderBy().dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasInterpolate())
{
buffer << '\n' << std::string(indent + 2, ' ') << "INTERPOLATE\n";
getInterpolate()->dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasLimitByLimit())
{
buffer << '\n' << std::string(indent + 2, ' ') << "LIMIT BY LIMIT\n";
getLimitByLimit()->dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasLimitByOffset())
{
buffer << '\n' << std::string(indent + 2, ' ') << "LIMIT BY OFFSET\n";
getLimitByOffset()->dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasLimitBy())
{
buffer << '\n' << std::string(indent + 2, ' ') << "LIMIT BY\n";
getLimitBy().dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasLimit())
{
buffer << '\n' << std::string(indent + 2, ' ') << "LIMIT\n";
getLimit()->dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasOffset())
{
buffer << '\n' << std::string(indent + 2, ' ') << "OFFSET\n";
getOffset()->dumpTreeImpl(buffer, format_state, indent + 4);
}
}
bool QueryNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const QueryNode &>(rhs);
if (constant_value && rhs_typed.constant_value && *constant_value != *rhs_typed.constant_value)
return false;
else if (constant_value && !rhs_typed.constant_value)
return false;
else if (!constant_value && rhs_typed.constant_value)
return false;
if (table_expression_modifiers && rhs_typed.table_expression_modifiers && table_expression_modifiers != rhs_typed.table_expression_modifiers)
return false;
else if (table_expression_modifiers && !rhs_typed.table_expression_modifiers)
return false;
else if (!table_expression_modifiers && rhs_typed.table_expression_modifiers)
return false;
return is_subquery == rhs_typed.is_subquery &&
is_cte == rhs_typed.is_cte &&
cte_name == rhs_typed.cte_name &&
projection_columns == rhs_typed.projection_columns &&
is_distinct == rhs_typed.is_distinct &&
is_limit_with_ties == rhs_typed.is_limit_with_ties &&
is_group_by_with_totals == rhs_typed.is_group_by_with_totals &&
is_group_by_with_rollup == rhs_typed.is_group_by_with_rollup &&
is_group_by_with_cube == rhs_typed.is_group_by_with_cube &&
is_group_by_with_grouping_sets == rhs_typed.is_group_by_with_grouping_sets;
}
void QueryNode::updateTreeHashImpl(HashState & state) const
{
state.update(is_subquery);
state.update(is_cte);
state.update(cte_name.size());
state.update(cte_name);
state.update(projection_columns.size());
for (const auto & projection_column : projection_columns)
{
state.update(projection_column.name.size());
state.update(projection_column.name);
auto projection_column_type_name = projection_column.type->getName();
state.update(projection_column_type_name.size());
state.update(projection_column_type_name);
}
state.update(is_distinct);
state.update(is_limit_with_ties);
state.update(is_group_by_with_totals);
state.update(is_group_by_with_rollup);
state.update(is_group_by_with_cube);
state.update(is_group_by_with_grouping_sets);
if (constant_value)
{
auto constant_dump = applyVisitor(FieldVisitorToString(), constant_value->getValue());
state.update(constant_dump.size());
state.update(constant_dump);
auto constant_value_type_name = constant_value->getType()->getName();
state.update(constant_value_type_name.size());
state.update(constant_value_type_name);
}
if (table_expression_modifiers)
table_expression_modifiers->updateTreeHash(state);
}
QueryTreeNodePtr QueryNode::cloneImpl() const
{
auto result_query_node = std::make_shared<QueryNode>();
result_query_node->is_subquery = is_subquery;
result_query_node->is_cte = is_cte;
result_query_node->is_distinct = is_distinct;
result_query_node->is_limit_with_ties = is_limit_with_ties;
result_query_node->is_group_by_with_totals = is_group_by_with_totals;
result_query_node->is_group_by_with_rollup = is_group_by_with_rollup;
result_query_node->is_group_by_with_cube = is_group_by_with_cube;
result_query_node->is_group_by_with_grouping_sets = is_group_by_with_grouping_sets;
result_query_node->cte_name = cte_name;
result_query_node->projection_columns = projection_columns;
result_query_node->constant_value = constant_value;
result_query_node->table_expression_modifiers = table_expression_modifiers;
return result_query_node;
}
ASTPtr QueryNode::toASTImpl() const
{
auto select_query = std::make_shared<ASTSelectQuery>();
select_query->distinct = is_distinct;
select_query->limit_with_ties = is_limit_with_ties;
select_query->group_by_with_totals = is_group_by_with_totals;
select_query->group_by_with_rollup = is_group_by_with_rollup;
select_query->group_by_with_cube = is_group_by_with_cube;
select_query->group_by_with_grouping_sets = is_group_by_with_grouping_sets;
if (hasWith())
select_query->setExpression(ASTSelectQuery::Expression::WITH, getWith().toAST());
select_query->setExpression(ASTSelectQuery::Expression::SELECT, getProjection().toAST());
ASTPtr tables_in_select_query_ast = std::make_shared<ASTTablesInSelectQuery>();
addTableExpressionOrJoinIntoTablesInSelectQuery(tables_in_select_query_ast, getJoinTree());
select_query->setExpression(ASTSelectQuery::Expression::TABLES, std::move(tables_in_select_query_ast));
if (getPrewhere())
select_query->setExpression(ASTSelectQuery::Expression::PREWHERE, getPrewhere()->toAST());
if (getWhere())
select_query->setExpression(ASTSelectQuery::Expression::WHERE, getWhere()->toAST());
if (hasGroupBy())
select_query->setExpression(ASTSelectQuery::Expression::GROUP_BY, getGroupBy().toAST());
if (hasHaving())
select_query->setExpression(ASTSelectQuery::Expression::HAVING, getHaving()->toAST());
if (hasWindow())
select_query->setExpression(ASTSelectQuery::Expression::WINDOW, getWindow().toAST());
if (hasOrderBy())
select_query->setExpression(ASTSelectQuery::Expression::ORDER_BY, getOrderBy().toAST());
if (hasInterpolate())
select_query->setExpression(ASTSelectQuery::Expression::INTERPOLATE, getInterpolate()->toAST());
if (hasLimitByLimit())
select_query->setExpression(ASTSelectQuery::Expression::LIMIT_BY_LENGTH, getLimitByLimit()->toAST());
if (hasLimitByOffset())
select_query->setExpression(ASTSelectQuery::Expression::LIMIT_BY_OFFSET, getLimitByOffset()->toAST());
if (hasLimitBy())
select_query->setExpression(ASTSelectQuery::Expression::LIMIT_BY, getLimitBy().toAST());
if (hasLimit())
select_query->setExpression(ASTSelectQuery::Expression::LIMIT_LENGTH, getLimit()->toAST());
if (hasOffset())
select_query->setExpression(ASTSelectQuery::Expression::LIMIT_OFFSET, getOffset()->toAST());
if (hasSettingsChanges())
{
auto settings_query = std::make_shared<ASTSetQuery>();
settings_query->changes = settings_changes;
select_query->setExpression(ASTSelectQuery::Expression::SETTINGS, std::move(settings_query));
}
auto result_select_query = std::make_shared<ASTSelectWithUnionQuery>();
result_select_query->union_mode = SelectUnionMode::UNION_DEFAULT;
auto list_of_selects = std::make_shared<ASTExpressionList>();
list_of_selects->children.push_back(std::move(select_query));
result_select_query->children.push_back(std::move(list_of_selects));
result_select_query->list_of_selects = result_select_query->children.back();
if (is_subquery)
{
auto subquery = std::make_shared<ASTSubquery>();
subquery->cte_name = cte_name;
subquery->children.push_back(std::move(result_select_query));
return subquery;
}
return result_select_query;
}
}

628
src/Analyzer/QueryNode.h Normal file
View File

@ -0,0 +1,628 @@
#pragma once
#include <Common/SettingsChanges.h>
#include <Core/NamesAndTypes.h>
#include <Core/Field.h>
#include <Analyzer/Identifier.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h>
#include <Analyzer/TableExpressionModifiers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int UNSUPPORTED_METHOD;
}
/** Query node represents query in query tree.
*
* Example: SELECT * FROM test_table WHERE id == 0;
* Example: SELECT * FROM test_table_1 AS t1 INNER JOIN test_table_2 AS t2 ON t1.id = t2.id;
*
* Query node consists of following sections.
* 1. WITH section.
* 2. PROJECTION section.
* 3. JOIN TREE section.
* Example: SELECT * FROM test_table_1 AS t1 INNER JOIN test_table_2 AS t2 ON t1.id = t2.id;
* test_table_1 AS t1 INNER JOIN test_table_2 AS t2 ON t1.id = t2.id - JOIN TREE section.
* 4. PREWHERE section.
* 5. WHERE section.
* 6. GROUP BY section.
* 7. HAVING section.
* 8. WINDOW section.
* Example: SELECT * FROM test_table WINDOW window AS (PARTITION BY id);
* 9. ORDER BY section.
* 10. INTERPOLATE section.
* Example: SELECT * FROM test_table ORDER BY id WITH FILL INTERPOLATE (value AS value + 1);
* value AS value + 1 - INTERPOLATE section.
* 11. LIMIT BY limit section.
* 12. LIMIT BY offset section.
* 13. LIMIT BY section.
* Example: SELECT * FROM test_table LIMIT 1 AS a OFFSET 5 AS b BY id, value;
* 1 AS a - LIMIT BY limit section.
* 5 AS b - LIMIT BY offset section.
* id, value - LIMIT BY section.
* 14. LIMIT section.
* 15. OFFSET section.
*
* Query node contains settings changes that must be applied before query analysis or execution.
* Example: SELECT * FROM test_table SETTINGS prefer_column_name_to_alias = 1, join_use_nulls = 1;
*
* Query node can be used as CTE.
* Example: WITH cte_subquery AS (SELECT 1) SELECT * FROM cte_subquery;
*
* Query node can be used as scalar subquery.
* Example: SELECT (SELECT 1) AS scalar_subquery.
*
* During query analysis pass query node must be resolved with projection columns.
*/
class QueryNode;
using QueryNodePtr = std::shared_ptr<QueryNode>;
class QueryNode final : public IQueryTreeNode
{
public:
explicit QueryNode();
/// Returns true if query node is subquery, false otherwise
bool isSubquery() const
{
return is_subquery;
}
/// Set query node is subquery value
void setIsSubquery(bool is_subquery_value)
{
is_subquery = is_subquery_value;
}
/// Returns true if query node is CTE, false otherwise
bool isCTE() const
{
return is_cte;
}
/// Set query node is CTE
void setIsCTE(bool is_cte_value)
{
is_cte = is_cte_value;
}
/// Get query node CTE name
const std::string & getCTEName() const
{
return cte_name;
}
/// Set query node CTE name
void setCTEName(std::string cte_name_value)
{
cte_name = std::move(cte_name_value);
}
/// Returns true if query node has DISTINCT, false otherwise
bool isDistinct() const
{
return is_distinct;
}
/// Set query node DISTINCT value
void setIsDistinct(bool is_distinct_value)
{
is_distinct = is_distinct_value;
}
/// Returns true if query node has LIMIT WITH TIES, false otherwise
bool isLimitWithTies() const
{
return is_limit_with_ties;
}
/// Set query node LIMIT WITH TIES value
void setIsLimitWithTies(bool is_limit_with_ties_value)
{
is_limit_with_ties = is_limit_with_ties_value;
}
/// Returns true, if query node has GROUP BY WITH TOTALS, false otherwise
bool isGroupByWithTotals() const
{
return is_group_by_with_totals;
}
/// Set query node GROUP BY WITH TOTALS value
void setIsGroupByWithTotals(bool is_group_by_with_totals_value)
{
is_group_by_with_totals = is_group_by_with_totals_value;
}
/// Returns true, if query node has GROUP BY with ROLLUP modifier, false otherwise
bool isGroupByWithRollup() const
{
return is_group_by_with_rollup;
}
/// Set query node GROUP BY with ROLLUP modifier value
void setIsGroupByWithRollup(bool is_group_by_with_rollup_value)
{
is_group_by_with_rollup = is_group_by_with_rollup_value;
}
/// Returns true, if query node has GROUP BY with CUBE modifier, false otherwise
bool isGroupByWithCube() const
{
return is_group_by_with_cube;
}
/// Set query node GROUP BY with CUBE modifier value
void setIsGroupByWithCube(bool is_group_by_with_cube_value)
{
is_group_by_with_cube = is_group_by_with_cube_value;
}
/// Returns true, if query node has GROUP BY with GROUPING SETS modifier, false otherwise
bool isGroupByWithGroupingSets() const
{
return is_group_by_with_grouping_sets;
}
/// Set query node GROUP BY with GROUPING SETS modifier value
void setIsGroupByWithGroupingSets(bool is_group_by_with_grouping_sets_value)
{
is_group_by_with_grouping_sets = is_group_by_with_grouping_sets_value;
}
/// Return true if query node has table expression modifiers, false otherwise
bool hasTableExpressionModifiers() const
{
return table_expression_modifiers.has_value();
}
/// Get table expression modifiers
const std::optional<TableExpressionModifiers> & getTableExpressionModifiers() const
{
return table_expression_modifiers;
}
/// Set table expression modifiers
void setTableExpressionModifiers(TableExpressionModifiers table_expression_modifiers_value)
{
table_expression_modifiers = std::move(table_expression_modifiers_value);
}
/// Returns true if query node WITH section is not empty, false otherwise
bool hasWith() const
{
return !getWith().getNodes().empty();
}
/// Get WITH section
const ListNode & getWith() const
{
return children[with_child_index]->as<const ListNode &>();
}
/// Get WITH section
ListNode & getWith()
{
return children[with_child_index]->as<ListNode &>();
}
/// Get WITH section node
const QueryTreeNodePtr & getWithNode() const
{
return children[with_child_index];
}
/// Get WITH section node
QueryTreeNodePtr & getWithNode()
{
return children[with_child_index];
}
/// Get PROJECTION section
const ListNode & getProjection() const
{
return children[projection_child_index]->as<const ListNode &>();
}
/// Get PROJECTION section
ListNode & getProjection()
{
return children[projection_child_index]->as<ListNode &>();
}
/// Get PROJECTION section node
const QueryTreeNodePtr & getProjectionNode() const
{
return children[projection_child_index];
}
/// Get PROJECTION section node
QueryTreeNodePtr & getProjectionNode()
{
return children[projection_child_index];
}
/// Get JOIN TREE section node
const QueryTreeNodePtr & getJoinTree() const
{
return children[join_tree_child_index];
}
/// Get JOIN TREE section node
QueryTreeNodePtr & getJoinTree()
{
return children[join_tree_child_index];
}
/// Returns true if query node PREWHERE section is not empty, false otherwise
bool hasPrewhere() const
{
return children[prewhere_child_index] != nullptr;
}
/// Get PREWHERE section node
const QueryTreeNodePtr & getPrewhere() const
{
return children[prewhere_child_index];
}
/// Get PREWHERE section node
QueryTreeNodePtr & getPrewhere()
{
return children[prewhere_child_index];
}
/// Returns true if query node WHERE section is not empty, false otherwise
bool hasWhere() const
{
return children[where_child_index] != nullptr;
}
/// Get WHERE section node
const QueryTreeNodePtr & getWhere() const
{
return children[where_child_index];
}
/// Get WHERE section node
QueryTreeNodePtr & getWhere()
{
return children[where_child_index];
}
/// Returns true if query node GROUP BY section is not empty, false otherwise
bool hasGroupBy() const
{
return !getGroupBy().getNodes().empty();
}
/// Get GROUP BY section
const ListNode & getGroupBy() const
{
return children[group_by_child_index]->as<const ListNode &>();
}
/// Get GROUP BY section
ListNode & getGroupBy()
{
return children[group_by_child_index]->as<ListNode &>();
}
/// Get GROUP BY section node
const QueryTreeNodePtr & getGroupByNode() const
{
return children[group_by_child_index];
}
/// Get GROUP BY section node
QueryTreeNodePtr & getGroupByNode()
{
return children[group_by_child_index];
}
/// Returns true if query node HAVING section is not empty, false otherwise
bool hasHaving() const
{
return getHaving() != nullptr;
}
/// Get HAVING section node
const QueryTreeNodePtr & getHaving() const
{
return children[having_child_index];
}
/// Get HAVING section node
QueryTreeNodePtr & getHaving()
{
return children[having_child_index];
}
/// Returns true if query node WINDOW section is not empty, false otherwise
bool hasWindow() const
{
return !getWindow().getNodes().empty();
}
/// Get WINDOW section
const ListNode & getWindow() const
{
return children[window_child_index]->as<const ListNode &>();
}
/// Get WINDOW section
ListNode & getWindow()
{
return children[window_child_index]->as<ListNode &>();
}
/// Get WINDOW section node
const QueryTreeNodePtr & getWindowNode() const
{
return children[window_child_index];
}
/// Get WINDOW section node
QueryTreeNodePtr & getWindowNode()
{
return children[window_child_index];
}
/// Returns true if query node ORDER BY section is not empty, false otherwise
bool hasOrderBy() const
{
return !getOrderBy().getNodes().empty();
}
/// Get ORDER BY section
const ListNode & getOrderBy() const
{
return children[order_by_child_index]->as<const ListNode &>();
}
/// Get ORDER BY section
ListNode & getOrderBy()
{
return children[order_by_child_index]->as<ListNode &>();
}
/// Get ORDER BY section node
const QueryTreeNodePtr & getOrderByNode() const
{
return children[order_by_child_index];
}
/// Get ORDER BY section node
QueryTreeNodePtr & getOrderByNode()
{
return children[order_by_child_index];
}
/// Returns true if query node INTERPOLATE section is not empty, false otherwise
bool hasInterpolate() const
{
return getInterpolate() != nullptr;
}
/// Get INTERPOLATE section node
const QueryTreeNodePtr & getInterpolate() const
{
return children[interpolate_child_index];
}
/// Get INTERPOLATE section node
QueryTreeNodePtr & getInterpolate()
{
return children[interpolate_child_index];
}
/// Returns true if query node LIMIT BY LIMIT section is not empty, false otherwise
bool hasLimitByLimit() const
{
return children[limit_by_limit_child_index] != nullptr;
}
/// Get LIMIT BY LIMIT section node
const QueryTreeNodePtr & getLimitByLimit() const
{
return children[limit_by_limit_child_index];
}
/// Get LIMIT BY LIMIT section node
QueryTreeNodePtr & getLimitByLimit()
{
return children[limit_by_limit_child_index];
}
/// Returns true if query node LIMIT BY OFFSET section is not empty, false otherwise
bool hasLimitByOffset() const
{
return children[limit_by_offset_child_index] != nullptr;
}
/// Get LIMIT BY OFFSET section node
const QueryTreeNodePtr & getLimitByOffset() const
{
return children[limit_by_offset_child_index];
}
/// Get LIMIT BY OFFSET section node
QueryTreeNodePtr & getLimitByOffset()
{
return children[limit_by_offset_child_index];
}
/// Returns true if query node LIMIT BY section is not empty, false otherwise
bool hasLimitBy() const
{
return !getLimitBy().getNodes().empty();
}
/// Get LIMIT BY section
const ListNode & getLimitBy() const
{
return children[limit_by_child_index]->as<const ListNode &>();
}
/// Get LIMIT BY section
ListNode & getLimitBy()
{
return children[limit_by_child_index]->as<ListNode &>();
}
/// Get LIMIT BY section node
const QueryTreeNodePtr & getLimitByNode() const
{
return children[limit_by_child_index];
}
/// Get LIMIT BY section node
QueryTreeNodePtr & getLimitByNode()
{
return children[limit_by_child_index];
}
/// Returns true if query node LIMIT section is not empty, false otherwise
bool hasLimit() const
{
return children[limit_child_index] != nullptr;
}
/// Get LIMIT section node
const QueryTreeNodePtr & getLimit() const
{
return children[limit_child_index];
}
/// Get LIMIT section node
QueryTreeNodePtr & getLimit()
{
return children[limit_child_index];
}
/// Returns true if query node OFFSET section is not empty, false otherwise
bool hasOffset() const
{
return children[offset_child_index] != nullptr;
}
/// Get OFFSET section node
const QueryTreeNodePtr & getOffset() const
{
return children[offset_child_index];
}
/// Get OFFSET section node
QueryTreeNodePtr & getOffset()
{
return children[offset_child_index];
}
/// Returns true if query node has settings changes specified, false otherwise
bool hasSettingsChanges() const
{
return !settings_changes.empty();
}
/// Get query node settings changes
const SettingsChanges & getSettingsChanges() const
{
return settings_changes;
}
/// Set query node settings changes value
void setSettingsChanges(SettingsChanges settings_changes_value)
{
settings_changes = std::move(settings_changes_value);
}
/// Get query node projection columns
const NamesAndTypes & getProjectionColumns() const
{
return projection_columns;
}
/// Resolve query node projection columns
void resolveProjectionColumns(NamesAndTypes projection_columns_value)
{
projection_columns = std::move(projection_columns_value);
}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::QUERY;
}
String getName() const override;
DataTypePtr getResultType() const override
{
if (constant_value)
return constant_value->getType();
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Method getResultType is not supported for non scalar query node");
}
/// Perform constant folding for scalar subquery node
void performConstantFolding(ConstantValuePtr constant_folded_value)
{
constant_value = std::move(constant_folded_value);
}
ConstantValuePtr getConstantValueOrNull() const override
{
return constant_value;
}
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState &) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
bool is_subquery = false;
bool is_cte = false;
bool is_distinct = false;
bool is_limit_with_ties = false;
bool is_group_by_with_totals = false;
bool is_group_by_with_rollup = false;
bool is_group_by_with_cube = false;
bool is_group_by_with_grouping_sets = false;
std::string cte_name;
NamesAndTypes projection_columns;
ConstantValuePtr constant_value;
std::optional<TableExpressionModifiers> table_expression_modifiers;
SettingsChanges settings_changes;
static constexpr size_t with_child_index = 0;
static constexpr size_t projection_child_index = 1;
static constexpr size_t join_tree_child_index = 2;
static constexpr size_t prewhere_child_index = 3;
static constexpr size_t where_child_index = 4;
static constexpr size_t group_by_child_index = 5;
static constexpr size_t having_child_index = 6;
static constexpr size_t window_child_index = 7;
static constexpr size_t order_by_child_index = 8;
static constexpr size_t interpolate_child_index = 9;
static constexpr size_t limit_by_limit_child_index = 10;
static constexpr size_t limit_by_offset_child_index = 11;
static constexpr size_t limit_by_child_index = 12;
static constexpr size_t limit_child_index = 13;
static constexpr size_t offset_child_index = 14;
static constexpr size_t children_size = offset_child_index + 1;
};
}

View File

@ -0,0 +1,887 @@
#include <Analyzer/QueryTreeBuilder.h>
#include <Common/FieldVisitorToString.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <Parsers/ParserSelectQuery.h>
#include <Parsers/ParserSelectWithUnionQuery.h>
#include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ASTSelectIntersectExceptQuery.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTAsterisk.h>
#include <Parsers/ASTQualifiedAsterisk.h>
#include <Parsers/ASTColumnsMatcher.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTSubquery.h>
#include <Parsers/ASTWithElement.h>
#include <Parsers/ASTColumnsTransformers.h>
#include <Parsers/ASTOrderByElement.h>
#include <Parsers/ASTInterpolateElement.h>
#include <Parsers/ASTSampleRatio.h>
#include <Parsers/ASTWindowDefinition.h>
#include <Parsers/ASTSetQuery.h>
#include <Analyzer/IdentifierNode.h>
#include <Analyzer/MatcherNode.h>
#include <Analyzer/ColumnTransformers.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/ColumnNode.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/LambdaNode.h>
#include <Analyzer/SortNode.h>
#include <Analyzer/InterpolateNode.h>
#include <Analyzer/WindowNode.h>
#include <Analyzer/TableNode.h>
#include <Analyzer/TableFunctionNode.h>
#include <Analyzer/QueryNode.h>
#include <Analyzer/ArrayJoinNode.h>
#include <Analyzer/JoinNode.h>
#include <Analyzer/UnionNode.h>
#include <Databases/IDatabase.h>
#include <Interpreters/StorageID.h>
#include <Interpreters/Context.h>
#include <Functions/FunctionFactory.h>
namespace DB
{
namespace ErrorCodes
{
extern const int UNSUPPORTED_METHOD;
extern const int LOGICAL_ERROR;
extern const int EXPECTED_ALL_OR_ANY;
extern const int NOT_IMPLEMENTED;
extern const int BAD_ARGUMENTS;
}
namespace
{
class QueryTreeBuilder
{
public:
explicit QueryTreeBuilder(ASTPtr query_, ContextPtr context_);
QueryTreeNodePtr getQueryTreeNode()
{
return query_tree_node;
}
private:
QueryTreeNodePtr buildSelectOrUnionExpression(const ASTPtr & select_or_union_query, bool is_subquery, const std::string & cte_name) const;
QueryTreeNodePtr buildSelectWithUnionExpression(const ASTPtr & select_with_union_query, bool is_subquery, const std::string & cte_name) const;
QueryTreeNodePtr buildSelectIntersectExceptQuery(const ASTPtr & select_intersect_except_query, bool is_subquery, const std::string & cte_name) const;
QueryTreeNodePtr buildSelectExpression(const ASTPtr & select_query, bool is_subquery, const std::string & cte_name) const;
QueryTreeNodePtr buildSortList(const ASTPtr & order_by_expression_list) const;
QueryTreeNodePtr buildInterpolateList(const ASTPtr & interpolate_expression_list) const;
QueryTreeNodePtr buildWindowList(const ASTPtr & window_definition_list) const;
QueryTreeNodePtr buildExpressionList(const ASTPtr & expression_list) const;
QueryTreeNodePtr buildExpression(const ASTPtr & expression) const;
QueryTreeNodePtr buildWindow(const ASTPtr & window_definition) const;
QueryTreeNodePtr buildJoinTree(const ASTPtr & tables_in_select_query) const;
ColumnTransformersNodes buildColumnTransformers(const ASTPtr & matcher_expression, size_t start_child_index) const;
ASTPtr query;
ContextPtr context;
QueryTreeNodePtr query_tree_node;
};
QueryTreeBuilder::QueryTreeBuilder(ASTPtr query_, ContextPtr context_)
: query(query_->clone())
, context(std::move(context_))
{
if (query->as<ASTSelectWithUnionQuery>() ||
query->as<ASTSelectIntersectExceptQuery>() ||
query->as<ASTSelectQuery>())
query_tree_node = buildSelectOrUnionExpression(query, false /*is_subquery*/, {} /*cte_name*/);
else if (query->as<ASTExpressionList>())
query_tree_node = buildExpressionList(query);
else
query_tree_node = buildExpression(query);
}
QueryTreeNodePtr QueryTreeBuilder::buildSelectOrUnionExpression(const ASTPtr & select_or_union_query, bool is_subquery, const std::string & cte_name) const
{
QueryTreeNodePtr query_node;
if (select_or_union_query->as<ASTSelectWithUnionQuery>())
query_node = buildSelectWithUnionExpression(select_or_union_query, is_subquery /*is_subquery*/, cte_name /*cte_name*/);
else if (select_or_union_query->as<ASTSelectIntersectExceptQuery>())
query_node = buildSelectIntersectExceptQuery(select_or_union_query, is_subquery /*is_subquery*/, cte_name /*cte_name*/);
else if (select_or_union_query->as<ASTSelectQuery>())
query_node = buildSelectExpression(select_or_union_query, is_subquery /*is_subquery*/, cte_name /*cte_name*/);
else
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "SELECT or UNION query {} is not supported", select_or_union_query->formatForErrorMessage());
return query_node;
}
QueryTreeNodePtr QueryTreeBuilder::buildSelectWithUnionExpression(const ASTPtr & select_with_union_query, bool is_subquery, const std::string & cte_name) const
{
auto & select_with_union_query_typed = select_with_union_query->as<ASTSelectWithUnionQuery &>();
auto & select_lists = select_with_union_query_typed.list_of_selects->as<ASTExpressionList &>();
if (select_lists.children.size() == 1)
return buildSelectOrUnionExpression(select_lists.children[0], is_subquery, cte_name);
auto union_node = std::make_shared<UnionNode>();
union_node->setIsSubquery(is_subquery);
union_node->setIsCTE(!cte_name.empty());
union_node->setCTEName(cte_name);
union_node->setUnionMode(select_with_union_query_typed.union_mode);
union_node->setUnionModes(select_with_union_query_typed.list_of_modes);
union_node->setOriginalAST(select_with_union_query);
size_t select_lists_children_size = select_lists.children.size();
for (size_t i = 0; i < select_lists_children_size; ++i)
{
auto & select_list_node = select_lists.children[i];
QueryTreeNodePtr query_node = buildSelectOrUnionExpression(select_list_node, false /*is_subquery*/, {} /*cte_name*/);
union_node->getQueries().getNodes().push_back(std::move(query_node));
}
return union_node;
}
QueryTreeNodePtr QueryTreeBuilder::buildSelectIntersectExceptQuery(const ASTPtr & select_intersect_except_query, bool is_subquery, const std::string & cte_name) const
{
auto & select_intersect_except_query_typed = select_intersect_except_query->as<ASTSelectIntersectExceptQuery &>();
auto select_lists = select_intersect_except_query_typed.getListOfSelects();
if (select_lists.size() == 1)
return buildSelectExpression(select_lists[0], is_subquery, cte_name);
auto union_node = std::make_shared<UnionNode>();
union_node->setIsSubquery(is_subquery);
union_node->setIsCTE(!cte_name.empty());
union_node->setCTEName(cte_name);
if (select_intersect_except_query_typed.final_operator == ASTSelectIntersectExceptQuery::Operator::INTERSECT_ALL)
union_node->setUnionMode(SelectUnionMode::INTERSECT_ALL);
else if (select_intersect_except_query_typed.final_operator == ASTSelectIntersectExceptQuery::Operator::INTERSECT_DISTINCT)
union_node->setUnionMode(SelectUnionMode::INTERSECT_DISTINCT);
else if (select_intersect_except_query_typed.final_operator == ASTSelectIntersectExceptQuery::Operator::EXCEPT_ALL)
union_node->setUnionMode(SelectUnionMode::EXCEPT_ALL);
else if (select_intersect_except_query_typed.final_operator == ASTSelectIntersectExceptQuery::Operator::EXCEPT_DISTINCT)
union_node->setUnionMode(SelectUnionMode::EXCEPT_DISTINCT);
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "UNION type is not initialized");
union_node->setUnionModes(SelectUnionModes(select_lists.size() - 1, union_node->getUnionMode()));
union_node->setOriginalAST(select_intersect_except_query);
size_t select_lists_size = select_lists.size();
for (size_t i = 0; i < select_lists_size; ++i)
{
auto & select_list_node = select_lists[i];
QueryTreeNodePtr query_node = buildSelectOrUnionExpression(select_list_node, false /*is_subquery*/, {} /*cte_name*/);
union_node->getQueries().getNodes().push_back(std::move(query_node));
}
return union_node;
}
QueryTreeNodePtr QueryTreeBuilder::buildSelectExpression(const ASTPtr & select_query, bool is_subquery, const std::string & cte_name) const
{
const auto & select_query_typed = select_query->as<ASTSelectQuery &>();
auto current_query_tree = std::make_shared<QueryNode>();
current_query_tree->setIsSubquery(is_subquery);
current_query_tree->setIsCTE(!cte_name.empty());
current_query_tree->setCTEName(cte_name);
current_query_tree->setIsDistinct(select_query_typed.distinct);
current_query_tree->setIsLimitWithTies(select_query_typed.limit_with_ties);
current_query_tree->setIsGroupByWithTotals(select_query_typed.group_by_with_totals);
current_query_tree->setIsGroupByWithCube(select_query_typed.group_by_with_cube);
current_query_tree->setIsGroupByWithRollup(select_query_typed.group_by_with_rollup);
current_query_tree->setIsGroupByWithGroupingSets(select_query_typed.group_by_with_grouping_sets);
current_query_tree->setOriginalAST(select_query);
auto select_settings = select_query_typed.settings();
if (select_settings)
{
auto & set_query = select_settings->as<ASTSetQuery &>();
current_query_tree->setSettingsChanges(set_query.changes);
}
current_query_tree->getJoinTree() = buildJoinTree(select_query_typed.tables());
auto select_with_list = select_query_typed.with();
if (select_with_list)
current_query_tree->getWithNode() = buildExpressionList(select_with_list);
auto select_expression_list = select_query_typed.select();
if (select_expression_list)
current_query_tree->getProjectionNode() = buildExpressionList(select_expression_list);
auto prewhere_expression = select_query_typed.prewhere();
if (prewhere_expression)
current_query_tree->getPrewhere() = buildExpression(prewhere_expression);
auto where_expression = select_query_typed.where();
if (where_expression)
current_query_tree->getWhere() = buildExpression(where_expression);
auto group_by_list = select_query_typed.groupBy();
if (group_by_list)
{
auto & group_by_children = group_by_list->children;
if (current_query_tree->isGroupByWithGroupingSets())
{
auto grouping_sets_list_node = std::make_shared<ListNode>();
for (auto & grouping_sets_keys : group_by_children)
{
auto grouping_sets_keys_list_node = buildExpressionList(grouping_sets_keys);
current_query_tree->getGroupBy().getNodes().emplace_back(std::move(grouping_sets_keys_list_node));
}
}
else
{
current_query_tree->getGroupByNode() = buildExpressionList(group_by_list);
}
}
auto having_expression = select_query_typed.having();
if (having_expression)
current_query_tree->getHaving() = buildExpression(having_expression);
auto window_list = select_query_typed.window();
if (window_list)
current_query_tree->getWindowNode() = buildWindowList(window_list);
auto select_order_by_list = select_query_typed.orderBy();
if (select_order_by_list)
current_query_tree->getOrderByNode() = buildSortList(select_order_by_list);
auto interpolate_list = select_query_typed.interpolate();
if (interpolate_list)
current_query_tree->getInterpolate() = buildInterpolateList(interpolate_list);
auto select_limit_by_limit = select_query_typed.limitByLength();
if (select_limit_by_limit)
current_query_tree->getLimitByLimit() = buildExpression(select_limit_by_limit);
auto select_limit_by_offset = select_query_typed.limitOffset();
if (select_limit_by_offset)
current_query_tree->getLimitByOffset() = buildExpression(select_limit_by_offset);
auto select_limit_by = select_query_typed.limitBy();
if (select_limit_by)
current_query_tree->getLimitByNode() = buildExpressionList(select_limit_by);
auto select_limit = select_query_typed.limitLength();
if (select_limit)
current_query_tree->getLimit() = buildExpression(select_limit);
auto select_offset = select_query_typed.limitOffset();
if (select_offset)
current_query_tree->getOffset() = buildExpression(select_offset);
return current_query_tree;
}
QueryTreeNodePtr QueryTreeBuilder::buildSortList(const ASTPtr & order_by_expression_list) const
{
auto list_node = std::make_shared<ListNode>();
auto & expression_list_typed = order_by_expression_list->as<ASTExpressionList &>();
list_node->getNodes().reserve(expression_list_typed.children.size());
for (auto & expression : expression_list_typed.children)
{
const auto & order_by_element = expression->as<const ASTOrderByElement &>();
auto sort_direction = order_by_element.direction == 1 ? SortDirection::ASCENDING : SortDirection::DESCENDING;
std::optional<SortDirection> nulls_sort_direction;
if (order_by_element.nulls_direction_was_explicitly_specified)
nulls_sort_direction = order_by_element.nulls_direction == 1 ? SortDirection::ASCENDING : SortDirection::DESCENDING;
std::shared_ptr<Collator> collator;
if (order_by_element.collation)
collator = std::make_shared<Collator>(order_by_element.collation->as<ASTLiteral &>().value.get<String &>());
const auto & sort_expression_ast = order_by_element.children.at(0);
auto sort_expression = buildExpression(sort_expression_ast);
auto sort_node = std::make_shared<SortNode>(std::move(sort_expression),
sort_direction,
nulls_sort_direction,
std::move(collator),
order_by_element.with_fill);
if (order_by_element.fill_from)
sort_node->getFillFrom() = buildExpression(order_by_element.fill_from);
if (order_by_element.fill_to)
sort_node->getFillTo() = buildExpression(order_by_element.fill_to);
if (order_by_element.fill_step)
sort_node->getFillStep() = buildExpression(order_by_element.fill_step);
list_node->getNodes().push_back(std::move(sort_node));
}
return list_node;
}
QueryTreeNodePtr QueryTreeBuilder::buildInterpolateList(const ASTPtr & interpolate_expression_list) const
{
auto list_node = std::make_shared<ListNode>();
auto & expression_list_typed = interpolate_expression_list->as<ASTExpressionList &>();
list_node->getNodes().reserve(expression_list_typed.children.size());
for (auto & expression : expression_list_typed.children)
{
const auto & interpolate_element = expression->as<const ASTInterpolateElement &>();
auto expression_to_interpolate = std::make_shared<IdentifierNode>(Identifier(interpolate_element.column));
auto interpolate_expression = buildExpression(interpolate_element.expr);
auto interpolate_node = std::make_shared<InterpolateNode>(std::move(expression_to_interpolate), std::move(interpolate_expression));
list_node->getNodes().push_back(std::move(interpolate_node));
}
return list_node;
}
QueryTreeNodePtr QueryTreeBuilder::buildWindowList(const ASTPtr & window_definition_list) const
{
auto list_node = std::make_shared<ListNode>();
auto & expression_list_typed = window_definition_list->as<ASTExpressionList &>();
list_node->getNodes().reserve(expression_list_typed.children.size());
for (auto & window_list_element : expression_list_typed.children)
{
const auto & window_list_element_typed = window_list_element->as<const ASTWindowListElement &>();
auto window_node = buildWindow(window_list_element_typed.definition);
window_node->setAlias(window_list_element_typed.name);
list_node->getNodes().push_back(std::move(window_node));
}
return list_node;
}
QueryTreeNodePtr QueryTreeBuilder::buildExpressionList(const ASTPtr & expression_list) const
{
auto list_node = std::make_shared<ListNode>();
auto & expression_list_typed = expression_list->as<ASTExpressionList &>();
list_node->getNodes().reserve(expression_list_typed.children.size());
for (auto & expression : expression_list_typed.children)
{
auto expression_node = buildExpression(expression);
list_node->getNodes().push_back(std::move(expression_node));
}
return list_node;
}
QueryTreeNodePtr QueryTreeBuilder::buildExpression(const ASTPtr & expression) const
{
QueryTreeNodePtr result;
if (const auto * ast_identifier = expression->as<ASTIdentifier>())
{
auto identifier = Identifier(ast_identifier->name_parts);
result = std::make_shared<IdentifierNode>(std::move(identifier));
}
else if (const auto * asterisk = expression->as<ASTAsterisk>())
{
auto column_transformers = buildColumnTransformers(expression, 0 /*start_child_index*/);
result = std::make_shared<MatcherNode>(std::move(column_transformers));
}
else if (const auto * qualified_asterisk = expression->as<ASTQualifiedAsterisk>())
{
auto & qualified_identifier = qualified_asterisk->children.at(0)->as<ASTTableIdentifier &>();
auto column_transformers = buildColumnTransformers(expression, 1 /*start_child_index*/);
result = std::make_shared<MatcherNode>(Identifier(qualified_identifier.name_parts), std::move(column_transformers));
}
else if (const auto * ast_literal = expression->as<ASTLiteral>())
{
result = std::make_shared<ConstantNode>(ast_literal->value);
}
else if (const auto * function = expression->as<ASTFunction>())
{
if (function->is_lambda_function)
{
const auto & lambda_arguments_and_expression = function->arguments->as<ASTExpressionList &>().children;
auto & lambda_arguments_tuple = lambda_arguments_and_expression.at(0)->as<ASTFunction &>();
auto lambda_arguments_nodes = std::make_shared<ListNode>();
Names lambda_arguments;
NameSet lambda_arguments_set;
if (lambda_arguments_tuple.arguments)
{
const auto & lambda_arguments_list = lambda_arguments_tuple.arguments->as<ASTExpressionList &>().children;
for (const auto & lambda_argument : lambda_arguments_list)
{
const auto * lambda_argument_identifier = lambda_argument->as<ASTIdentifier>();
if (!lambda_argument_identifier)
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Lambda {} argument is not identifier",
function->formatForErrorMessage());
if (lambda_argument_identifier->name_parts.size() > 1)
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Lambda {} argument identifier must contain single part. Actual {}",
function->formatForErrorMessage(),
lambda_argument_identifier->full_name);
const auto & argument_name = lambda_argument_identifier->name_parts[0];
auto [_, inserted] = lambda_arguments_set.insert(argument_name);
if (!inserted)
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Lambda {} multiple arguments with same name {}",
function->formatForErrorMessage(),
argument_name);
lambda_arguments.push_back(argument_name);
}
}
const auto & lambda_expression = lambda_arguments_and_expression.at(1);
auto lambda_expression_node = buildExpression(lambda_expression);
result = std::make_shared<LambdaNode>(std::move(lambda_arguments), std::move(lambda_expression_node));
}
else
{
auto function_node = std::make_shared<FunctionNode>(function->name);
if (function->parameters)
{
const auto & function_parameters_list = function->parameters->as<ASTExpressionList>()->children;
for (const auto & argument : function_parameters_list)
function_node->getParameters().getNodes().push_back(buildExpression(argument));
}
if (function->arguments)
{
const auto & function_arguments_list = function->arguments->as<ASTExpressionList>()->children;
for (const auto & argument : function_arguments_list)
function_node->getArguments().getNodes().push_back(buildExpression(argument));
}
if (function->is_window_function)
{
if (function->window_definition)
function_node->getWindowNode() = buildWindow(function->window_definition);
else
function_node->getWindowNode() = std::make_shared<IdentifierNode>(Identifier(function->window_name));
}
result = std::move(function_node);
}
}
else if (const auto * subquery = expression->as<ASTSubquery>())
{
auto subquery_query = subquery->children[0];
auto query_node = buildSelectWithUnionExpression(subquery_query, true /*is_subquery*/, {} /*cte_name*/);
result = std::move(query_node);
}
else if (const auto * with_element = expression->as<ASTWithElement>())
{
auto with_element_subquery = with_element->subquery->as<ASTSubquery &>().children.at(0);
auto query_node = buildSelectWithUnionExpression(with_element_subquery, true /*is_subquery*/, with_element->name /*cte_name*/);
result = std::move(query_node);
}
else if (const auto * columns_regexp_matcher = expression->as<ASTColumnsRegexpMatcher>())
{
auto column_transformers = buildColumnTransformers(expression, 0 /*start_child_index*/);
result = std::make_shared<MatcherNode>(columns_regexp_matcher->getMatcher(), std::move(column_transformers));
}
else if (const auto * columns_list_matcher = expression->as<ASTColumnsListMatcher>())
{
Identifiers column_list_identifiers;
column_list_identifiers.reserve(columns_list_matcher->column_list->children.size());
for (auto & column_list_child : columns_list_matcher->column_list->children)
{
auto & column_list_identifier = column_list_child->as<ASTIdentifier &>();
column_list_identifiers.emplace_back(Identifier{column_list_identifier.name_parts});
}
auto column_transformers = buildColumnTransformers(expression, 0 /*start_child_index*/);
result = std::make_shared<MatcherNode>(std::move(column_list_identifiers), std::move(column_transformers));
}
else if (const auto * qualified_columns_regexp_matcher = expression->as<ASTQualifiedColumnsRegexpMatcher>())
{
auto & qualified_identifier = qualified_columns_regexp_matcher->children.at(0)->as<ASTTableIdentifier &>();
auto column_transformers = buildColumnTransformers(expression, 1 /*start_child_index*/);
result = std::make_shared<MatcherNode>(Identifier(qualified_identifier.name_parts), qualified_columns_regexp_matcher->getMatcher(), std::move(column_transformers));
}
else if (const auto * qualified_columns_list_matcher = expression->as<ASTQualifiedColumnsListMatcher>())
{
auto & qualified_identifier = qualified_columns_list_matcher->children.at(0)->as<ASTTableIdentifier &>();
Identifiers column_list_identifiers;
column_list_identifiers.reserve(qualified_columns_list_matcher->column_list->children.size());
for (auto & column_list_child : qualified_columns_list_matcher->column_list->children)
{
auto & column_list_identifier = column_list_child->as<ASTIdentifier &>();
column_list_identifiers.emplace_back(Identifier{column_list_identifier.name_parts});
}
auto column_transformers = buildColumnTransformers(expression, 1 /*start_child_index*/);
result = std::make_shared<MatcherNode>(Identifier(qualified_identifier.name_parts), std::move(column_list_identifiers), std::move(column_transformers));
}
else
{
throw Exception(ErrorCodes::UNSUPPORTED_METHOD,
"Invalid expression. Expected identifier, literal, matcher, function, subquery. Actual {}",
expression->formatForErrorMessage());
}
result->setAlias(expression->tryGetAlias());
result->setOriginalAST(expression);
return result;
}
QueryTreeNodePtr QueryTreeBuilder::buildWindow(const ASTPtr & window_definition) const
{
const auto & window_definition_typed = window_definition->as<const ASTWindowDefinition &>();
WindowFrame window_frame;
if (!window_definition_typed.frame_is_default)
{
window_frame.is_default = false;
window_frame.type = window_definition_typed.frame_type;
window_frame.begin_type = window_definition_typed.frame_begin_type;
window_frame.begin_preceding = window_definition_typed.frame_begin_preceding;
window_frame.end_type = window_definition_typed.frame_end_type;
window_frame.end_preceding = window_definition_typed.frame_end_preceding;
}
auto window_node = std::make_shared<WindowNode>(window_frame);
window_node->setParentWindowName(window_definition_typed.parent_window_name);
if (window_definition_typed.partition_by)
window_node->getPartitionByNode() = buildExpressionList(window_definition_typed.partition_by);
if (window_definition_typed.order_by)
window_node->getOrderByNode() = buildSortList(window_definition_typed.order_by);
if (window_definition_typed.frame_begin_offset)
window_node->getFrameBeginOffsetNode() = buildExpression(window_definition_typed.frame_begin_offset);
if (window_definition_typed.frame_end_offset)
window_node->getFrameEndOffsetNode() = buildExpression(window_definition_typed.frame_end_offset);
window_node->setOriginalAST(window_definition);
return window_node;
}
QueryTreeNodePtr QueryTreeBuilder::buildJoinTree(const ASTPtr & tables_in_select_query) const
{
if (!tables_in_select_query)
{
/** If no table is specified in SELECT query we substitute system.one table.
* SELECT * FROM system.one;
*/
Identifier storage_identifier("system.one");
return std::make_shared<IdentifierNode>(storage_identifier);
}
auto & tables = tables_in_select_query->as<ASTTablesInSelectQuery &>();
QueryTreeNodes table_expressions;
for (const auto & table_element_untyped : tables.children)
{
const auto & table_element = table_element_untyped->as<ASTTablesInSelectQueryElement &>();
if (table_element.table_expression)
{
auto & table_expression = table_element.table_expression->as<ASTTableExpression &>();
std::optional<TableExpressionModifiers> table_expression_modifiers;
if (table_expression.final || table_expression.sample_size)
{
bool has_final = table_expression.final;
std::optional<TableExpressionModifiers::Rational> sample_size_ratio;
std::optional<TableExpressionModifiers::Rational> sample_offset_ratio;
if (table_expression.sample_size)
{
auto & ast_sample_size_ratio = table_expression.sample_size->as<ASTSampleRatio &>();
sample_size_ratio = ast_sample_size_ratio.ratio;
if (table_expression.sample_offset)
{
auto & ast_sample_offset_ratio = table_expression.sample_offset->as<ASTSampleRatio &>();
sample_offset_ratio = ast_sample_offset_ratio.ratio;
}
}
table_expression_modifiers = TableExpressionModifiers(has_final, sample_size_ratio, sample_offset_ratio);
}
if (table_expression.database_and_table_name)
{
auto & table_identifier_typed = table_expression.database_and_table_name->as<ASTTableIdentifier &>();
auto storage_identifier = Identifier(table_identifier_typed.name_parts);
QueryTreeNodePtr table_identifier_node;
if (table_expression_modifiers)
table_identifier_node = std::make_shared<IdentifierNode>(storage_identifier, *table_expression_modifiers);
else
table_identifier_node = std::make_shared<IdentifierNode>(storage_identifier);
table_identifier_node->setAlias(table_identifier_typed.tryGetAlias());
table_identifier_node->setOriginalAST(table_element.table_expression);
table_expressions.push_back(std::move(table_identifier_node));
}
else if (table_expression.subquery)
{
auto & subquery_expression = table_expression.subquery->as<ASTSubquery &>();
const auto & select_with_union_query = subquery_expression.children[0];
auto node = buildSelectWithUnionExpression(select_with_union_query, true /*is_subquery*/, {} /*cte_name*/);
node->setAlias(subquery_expression.tryGetAlias());
node->setOriginalAST(select_with_union_query);
if (table_expression_modifiers)
{
if (auto * query_node = node->as<QueryNode>())
query_node->setTableExpressionModifiers(*table_expression_modifiers);
else if (auto * union_node = node->as<UnionNode>())
union_node->setTableExpressionModifiers(*table_expression_modifiers);
else
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Unexpected table expression subquery node. Expected union or query. Actual {}",
node->formatASTForErrorMessage());
}
table_expressions.push_back(std::move(node));
}
else if (table_expression.table_function)
{
auto & table_function_expression = table_expression.table_function->as<ASTFunction &>();
auto node = std::make_shared<TableFunctionNode>(table_function_expression.name);
if (table_function_expression.arguments)
{
const auto & function_arguments_list = table_function_expression.arguments->as<ASTExpressionList &>().children;
for (const auto & argument : function_arguments_list)
{
if (argument->as<ASTSelectQuery>() || argument->as<ASTSelectWithUnionQuery>() || argument->as<ASTSelectIntersectExceptQuery>())
node->getArguments().getNodes().push_back(buildSelectOrUnionExpression(argument, false /*is_subquery*/, {} /*cte_name*/));
else
node->getArguments().getNodes().push_back(buildExpression(argument));
}
}
if (table_expression_modifiers)
node->setTableExpressionModifiers(*table_expression_modifiers);
node->setAlias(table_function_expression.tryGetAlias());
node->setOriginalAST(table_expression.table_function);
table_expressions.push_back(std::move(node));
}
else
{
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Unsupported table expression node {}", table_element.table_expression->formatForErrorMessage());
}
}
if (table_element.table_join)
{
const auto & table_join = table_element.table_join->as<ASTTableJoin &>();
auto right_table_expression = std::move(table_expressions.back());
table_expressions.pop_back();
auto left_table_expression = std::move(table_expressions.back());
table_expressions.pop_back();
QueryTreeNodePtr join_expression;
if (table_join.using_expression_list)
join_expression = buildExpressionList(table_join.using_expression_list);
else if (table_join.on_expression)
join_expression = buildExpression(table_join.on_expression);
const auto & settings = context->getSettingsRef();
auto join_default_strictness = settings.join_default_strictness;
auto any_join_distinct_right_table_keys = settings.any_join_distinct_right_table_keys;
JoinStrictness result_join_strictness = table_join.strictness;
JoinKind result_join_kind = table_join.kind;
if (result_join_strictness == JoinStrictness::Unspecified && (result_join_kind != JoinKind::Cross && result_join_kind != JoinKind::Comma))
{
if (join_default_strictness == JoinStrictness::Any)
result_join_strictness = JoinStrictness::Any;
else if (join_default_strictness == JoinStrictness::All)
result_join_strictness = JoinStrictness::All;
else
throw Exception(ErrorCodes::EXPECTED_ALL_OR_ANY,
"Expected ANY or ALL in JOIN section, because setting (join_default_strictness) is empty");
}
if (any_join_distinct_right_table_keys)
{
if (result_join_strictness == JoinStrictness::Any && result_join_kind == JoinKind::Inner)
{
result_join_strictness = JoinStrictness::Semi;
result_join_kind = JoinKind::Left;
}
if (result_join_strictness == JoinStrictness::Any)
result_join_strictness = JoinStrictness::RightAny;
}
else if (result_join_strictness == JoinStrictness::Any && result_join_kind == JoinKind::Full)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "ANY FULL JOINs are not implemented");
}
auto join_node = std::make_shared<JoinNode>(std::move(left_table_expression),
std::move(right_table_expression),
std::move(join_expression),
table_join.locality,
result_join_strictness,
result_join_kind);
/** Original AST is not set because it will contain only join part and does
* not include left table expression.
*/
table_expressions.emplace_back(std::move(join_node));
}
if (table_element.array_join)
{
auto & array_join_expression = table_element.array_join->as<ASTArrayJoin &>();
bool is_left_array_join = array_join_expression.kind == ASTArrayJoin::Kind::Left;
auto last_table_expression = std::move(table_expressions.back());
table_expressions.pop_back();
auto array_join_expressions_list = buildExpressionList(array_join_expression.expression_list);
auto array_join_node = std::make_shared<ArrayJoinNode>(std::move(last_table_expression), std::move(array_join_expressions_list), is_left_array_join);
/** Original AST is not set because it will contain only array join part and does
* not include left table expression.
*/
table_expressions.push_back(std::move(array_join_node));
}
}
if (table_expressions.empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Query FROM section cannot be empty");
if (table_expressions.size() > 1)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Query FROM section cannot have more than 1 root table expression");
return table_expressions.back();
}
ColumnTransformersNodes QueryTreeBuilder::buildColumnTransformers(const ASTPtr & matcher_expression, size_t start_child_index) const
{
ColumnTransformersNodes column_transformers;
size_t children_size = matcher_expression->children.size();
for (; start_child_index < children_size; ++start_child_index)
{
const auto & child = matcher_expression->children[start_child_index];
if (auto * apply_transformer = child->as<ASTColumnsApplyTransformer>())
{
if (apply_transformer->lambda)
{
auto lambda_query_tree_node = buildExpression(apply_transformer->lambda);
column_transformers.emplace_back(std::make_shared<ApplyColumnTransformerNode>(std::move(lambda_query_tree_node)));
}
else
{
auto function_node = std::make_shared<FunctionNode>(apply_transformer->func_name);
if (apply_transformer->parameters)
function_node->getParametersNode() = buildExpressionList(apply_transformer->parameters);
column_transformers.emplace_back(std::make_shared<ApplyColumnTransformerNode>(std::move(function_node)));
}
}
else if (auto * except_transformer = child->as<ASTColumnsExceptTransformer>())
{
auto matcher = except_transformer->getMatcher();
if (matcher)
{
column_transformers.emplace_back(std::make_shared<ExceptColumnTransformerNode>(std::move(matcher)));
}
else
{
Names except_column_names;
except_column_names.reserve(except_transformer->children.size());
for (auto & except_transformer_child : except_transformer->children)
except_column_names.push_back(except_transformer_child->as<ASTIdentifier &>().full_name);
column_transformers.emplace_back(std::make_shared<ExceptColumnTransformerNode>(std::move(except_column_names), except_transformer->is_strict));
}
}
else if (auto * replace_transformer = child->as<ASTColumnsReplaceTransformer>())
{
std::vector<ReplaceColumnTransformerNode::Replacement> replacements;
replacements.reserve(replace_transformer->children.size());
for (const auto & replace_transformer_child : replace_transformer->children)
{
auto & replacement = replace_transformer_child->as<ASTColumnsReplaceTransformer::Replacement &>();
replacements.emplace_back(ReplaceColumnTransformerNode::Replacement{replacement.name, buildExpression(replacement.expr)});
}
column_transformers.emplace_back(std::make_shared<ReplaceColumnTransformerNode>(replacements, replace_transformer->is_strict));
}
else
{
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Unsupported column matcher {}", child->formatForErrorMessage());
}
}
return column_transformers;
}
}
QueryTreeNodePtr buildQueryTree(ASTPtr query, ContextPtr context)
{
QueryTreeBuilder builder(std::move(query), context);
return builder.getQueryTreeNode();
}
}

View File

@ -0,0 +1,19 @@
#pragma once
#include <Parsers/IAST_fwd.h>
#include <Storages/IStorage_fwd.h>
#include <Interpreters/Context_fwd.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ColumnTransformers.h>
namespace DB
{
/** Build query tree from AST.
* AST that represent query ASTSelectWithUnionQuery, ASTSelectIntersectExceptQuery, ASTSelectQuery.
* AST that represent a list of expressions ASTExpressionList.
* AST that represent expression ASTIdentifier, ASTAsterisk, ASTLiteral, ASTFunction.
*/
QueryTreeNodePtr buildQueryTree(ASTPtr query, ContextPtr context);
}

View File

@ -0,0 +1,151 @@
#include <Analyzer/QueryTreePassManager.h>
#include <Analyzer/Passes/QueryAnalysisPass.h>
#include <Analyzer/Passes/CountDistinctPass.h>
#include <Analyzer/Passes/FunctionToSubcolumnsPass.h>
#include <Analyzer/Passes/SumIfToCountIfPass.h>
#include <Analyzer/Passes/MultiIfToIfPass.h>
#include <Analyzer/Passes/IfConstantConditionPass.h>
#include <Analyzer/Passes/IfChainToMultiIfPass.h>
#include <Analyzer/Passes/OrderByTupleEliminationPass.h>
#include <Analyzer/Passes/NormalizeCountVariantsPass.h>
#include <Analyzer/Passes/CustomizeFunctionsPass.h>
#include <Analyzer/Passes/AggregateFunctionsArithmericOperationsPass.h>
#include <Analyzer/Passes/UniqInjectiveFunctionsEliminationPass.h>
#include <Analyzer/Passes/OrderByLimitByDuplicateEliminationPass.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Interpreters/Context.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
/** ClickHouse query tree pass manager.
*
* TODO: Support _shard_num into shardNum() rewriting.
* TODO: Support logical expressions optimizer.
* TODO: Support fuse sum count optimize_fuse_sum_count_avg, optimize_syntax_fuse_functions.
* TODO: Support setting convert_query_to_cnf.
* TODO: Support setting optimize_using_constraints.
* TODO: Support setting optimize_substitute_columns.
* TODO: Support GROUP BY injective function elimination.
* TODO: Support GROUP BY functions of other keys elimination.
* TODO: Support setting optimize_move_functions_out_of_any.
* TODO: Support setting optimize_aggregators_of_group_by_keys.
* TODO: Support setting optimize_duplicate_order_by_and_distinct.
* TODO: Support setting optimize_redundant_functions_in_order_by.
* TODO: Support setting optimize_monotonous_functions_in_order_by.
* TODO: Support setting optimize_if_transform_strings_to_enum.
* TODO: Support settings.optimize_syntax_fuse_functions.
* TODO: Support settings.optimize_or_like_chain.
* TODO: Add optimizations based on function semantics. Example: SELECT * FROM test_table WHERE id != id. (id is not nullable column).
*/
QueryTreePassManager::QueryTreePassManager(ContextPtr context_) : WithContext(context_) {}
void QueryTreePassManager::addPass(QueryTreePassPtr pass)
{
passes.push_back(std::move(pass));
}
void QueryTreePassManager::run(QueryTreeNodePtr query_tree_node)
{
auto current_context = getContext();
size_t passes_size = passes.size();
for (size_t i = 0; i < passes_size; ++i)
passes[i]->run(query_tree_node, current_context);
}
void QueryTreePassManager::run(QueryTreeNodePtr query_tree_node, size_t up_to_pass_index)
{
size_t passes_size = passes.size();
if (up_to_pass_index > passes_size)
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Requested to run passes up to {} pass. There are only {} passes",
up_to_pass_index,
passes_size);
auto current_context = getContext();
for (size_t i = 0; i < up_to_pass_index; ++i)
passes[i]->run(query_tree_node, current_context);
}
void QueryTreePassManager::dump(WriteBuffer & buffer)
{
size_t passes_size = passes.size();
for (size_t i = 0; i < passes_size; ++i)
{
auto & pass = passes[i];
buffer << "Pass " << (i + 1) << ' ' << pass->getName() << " - " << pass->getDescription();
if (i + 1 != passes_size)
buffer << '\n';
}
}
void QueryTreePassManager::dump(WriteBuffer & buffer, size_t up_to_pass_index)
{
size_t passes_size = passes.size();
if (up_to_pass_index > passes_size)
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Requested to dump passes up to {} pass. There are only {} passes",
up_to_pass_index,
passes_size);
for (size_t i = 0; i < up_to_pass_index; ++i)
{
auto & pass = passes[i];
buffer << "Pass " << (i + 1) << " " << pass->getName() << " - " << pass->getDescription();
if (i + 1 != up_to_pass_index)
buffer << '\n';
}
}
void addQueryTreePasses(QueryTreePassManager & manager)
{
auto context = manager.getContext();
const auto & settings = context->getSettingsRef();
manager.addPass(std::make_shared<QueryAnalysisPass>());
if (settings.optimize_functions_to_subcolumns)
manager.addPass(std::make_shared<FunctionToSubcolumnsPass>());
if (settings.count_distinct_optimization)
manager.addPass(std::make_shared<CountDistinctPass>());
if (settings.optimize_rewrite_sum_if_to_count_if)
manager.addPass(std::make_shared<SumIfToCountIfPass>());
if (settings.optimize_normalize_count_variants)
manager.addPass(std::make_shared<NormalizeCountVariantsPass>());
manager.addPass(std::make_shared<CustomizeFunctionsPass>());
if (settings.optimize_arithmetic_operations_in_aggregate_functions)
manager.addPass(std::make_shared<AggregateFunctionsArithmericOperationsPass>());
if (settings.optimize_injective_functions_inside_uniq)
manager.addPass(std::make_shared<UniqInjectiveFunctionsEliminationPass>());
if (settings.optimize_multiif_to_if)
manager.addPass(std::make_shared<MultiIfToIfPass>());
manager.addPass(std::make_shared<IfConstantConditionPass>());
if (settings.optimize_if_chain_to_multiif)
manager.addPass(std::make_shared<IfChainToMultiIfPass>());
manager.addPass(std::make_shared<OrderByTupleEliminationPass>());
manager.addPass(std::make_shared<OrderByLimitByDuplicateEliminationPass>());
}
}

View File

@ -0,0 +1,49 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
#include <Interpreters/Context_fwd.h>
namespace DB
{
/** Query tree pass manager provide functionality to register and run passes
* on query tree.
*/
class QueryTreePassManager : public WithContext
{
public:
explicit QueryTreePassManager(ContextPtr context_);
/// Get registered passes
const std::vector<QueryTreePassPtr> & getPasses() const
{
return passes;
}
/// Add query tree pass
void addPass(QueryTreePassPtr pass);
/// Run query tree passes on query tree
void run(QueryTreeNodePtr query_tree_node);
/** Run query tree passes on query tree up to up_to_pass_index.
* Throws exception if up_to_pass_index is greater than passes size.
*/
void run(QueryTreeNodePtr query_tree_node, size_t up_to_pass_index);
/// Dump query tree passes
void dump(WriteBuffer & buffer);
/** Dump query tree passes to up_to_pass_index.
* Throws exception if up_to_pass_index is greater than passes size.
*/
void dump(WriteBuffer & buffer, size_t up_to_pass_index);
private:
std::vector<QueryTreePassPtr> passes;
};
void addQueryTreePasses(QueryTreePassManager & manager);
}

182
src/Analyzer/SetUtils.cpp Normal file
View File

@ -0,0 +1,182 @@
#include <Analyzer/SetUtils.h>
#include <Core/Block.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <Interpreters/convertFieldToType.h>
#include <Interpreters/Set.h>
namespace DB
{
namespace ErrorCodes
{
extern const int INCORRECT_ELEMENT_OF_SET;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
namespace
{
size_t getCompoundTypeDepth(const IDataType & type)
{
size_t result = 0;
const IDataType * current_type = &type;
while (true)
{
WhichDataType which_type(*current_type);
if (which_type.isArray())
{
current_type = assert_cast<const DataTypeArray &>(*current_type).getNestedType().get();
++result;
}
else if (which_type.isTuple())
{
const auto & tuple_elements = assert_cast<const DataTypeTuple &>(*current_type).getElements();
if (!tuple_elements.empty())
current_type = tuple_elements.at(0).get();
++result;
}
else
{
break;
}
}
return result;
}
template <typename Collection>
Block createBlockFromCollection(const Collection & collection, const DataTypes & block_types, bool transform_null_in)
{
size_t columns_size = block_types.size();
MutableColumns columns(columns_size);
for (size_t i = 0; i < columns_size; ++i)
{
columns[i] = block_types[i]->createColumn();
columns[i]->reserve(collection.size());
}
Row tuple_values;
for (const auto & value : collection)
{
if (columns_size == 1)
{
auto field = convertFieldToType(value, *block_types[0]);
bool need_insert_null = transform_null_in && block_types[0]->isNullable();
if (!field.isNull() || need_insert_null)
columns[0]->insert(std::move(field));
continue;
}
if (value.getType() != Field::Types::Tuple)
throw Exception(ErrorCodes::INCORRECT_ELEMENT_OF_SET,
"Invalid type in set. Expected tuple, got {}",
value.getTypeName());
const auto & tuple = value.template get<const Tuple &>();
size_t tuple_size = tuple.size();
if (tuple_size != columns_size)
throw Exception(ErrorCodes::INCORRECT_ELEMENT_OF_SET,
"Incorrect size of tuple in set: {} instead of {}",
tuple_size,
columns_size);
if (tuple_values.empty())
tuple_values.resize(tuple_size);
size_t i = 0;
for (; i < tuple_size; ++i)
{
tuple_values[i] = convertFieldToType(tuple[i], *block_types[i]);
bool need_insert_null = transform_null_in && block_types[i]->isNullable();
if (tuple_values[i].isNull() && !need_insert_null)
break;
}
if (i == tuple_size)
for (i = 0; i < tuple_size; ++i)
columns[i]->insert(tuple_values[i]);
}
Block res;
for (size_t i = 0; i < columns_size; ++i)
res.insert(ColumnWithTypeAndName{std::move(columns[i]), block_types[i], "argument_" + toString(i)});
return res;
}
}
SetPtr makeSetForConstantValue(const DataTypePtr & expression_type, const Field & value, const DataTypePtr & value_type, const Settings & settings)
{
DataTypes set_element_types = {expression_type};
const auto * lhs_tuple_type = typeid_cast<const DataTypeTuple *>(expression_type.get());
if (lhs_tuple_type && lhs_tuple_type->getElements().size() != 1)
set_element_types = lhs_tuple_type->getElements();
for (auto & set_element_type : set_element_types)
{
if (const auto * set_element_low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(set_element_type.get()))
set_element_type = set_element_low_cardinality_type->getDictionaryType();
}
size_t lhs_type_depth = getCompoundTypeDepth(*expression_type);
size_t rhs_type_depth = getCompoundTypeDepth(*value_type);
SizeLimits size_limits_for_set = {settings.max_rows_in_set, settings.max_bytes_in_set, settings.set_overflow_mode};
bool tranform_null_in = settings.transform_null_in;
Block result_block;
if (lhs_type_depth == rhs_type_depth)
{
/// 1 in 1; (1, 2) in (1, 2); identity(tuple(tuple(tuple(1)))) in tuple(tuple(tuple(1))); etc.
Array array{value};
result_block = createBlockFromCollection(array, set_element_types, tranform_null_in);
}
else if (lhs_type_depth + 1 == rhs_type_depth)
{
/// 1 in (1, 2); (1, 2) in ((1, 2), (3, 4))
WhichDataType rhs_which_type(value_type);
if (rhs_which_type.isArray())
result_block = createBlockFromCollection(value.get<const Array &>(), set_element_types, tranform_null_in);
else if (rhs_which_type.isTuple())
result_block = createBlockFromCollection(value.get<const Tuple &>(), set_element_types, tranform_null_in);
else
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Unsupported type at the right-side of IN. Expected Array or Tuple. Actual {}",
value_type->getName());
}
else
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Unsupported types for IN. First argument type {}. Second argument type {}",
expression_type->getName(),
value_type->getName());
}
auto set = std::make_shared<Set>(size_limits_for_set, false /*fill_set_elements*/, tranform_null_in);
set->setHeader(result_block.cloneEmpty().getColumnsWithTypeAndName());
set->insertFromBlock(result_block.getColumnsWithTypeAndName());
set->finishInsert();
return set;
}
}

30
src/Analyzer/SetUtils.h Normal file
View File

@ -0,0 +1,30 @@
#pragma once
#include <Core/Settings.h>
#include <DataTypes/IDataType.h>
#include <QueryPipeline/SizeLimits.h>
namespace DB
{
class Set;
using SetPtr = std::shared_ptr<Set>;
/** Make set for constant part of IN subquery.
* Throws exception if parameters are not valid for IN function.
*
* Example: SELECT id FROM test_table WHERE id IN (1, 2, 3, 4);
* Example: SELECT id FROM test_table WHERE id IN ((1, 2), (3, 4));
*
* @param expression_type - type of first argument of function IN.
* @param value - constant value of second argument of function IN.
* @param value_type - type of second argument of function IN.
* @param settings - query settings.
*
* @return SetPtr for constant value.
*/
SetPtr makeSetForConstantValue(const DataTypePtr & expression_type, const Field & value, const DataTypePtr & value_type, const Settings & settings);
}

168
src/Analyzer/SortNode.cpp Normal file
View File

@ -0,0 +1,168 @@
#include <Analyzer/SortNode.h>
#include <Common/SipHash.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTOrderByElement.h>
#include <Parsers/ASTLiteral.h>
namespace DB
{
const char * toString(SortDirection sort_direction)
{
switch (sort_direction)
{
case SortDirection::ASCENDING: return "ASCENDING";
case SortDirection::DESCENDING: return "DESCENDING";
}
}
SortNode::SortNode(QueryTreeNodePtr expression_,
SortDirection sort_direction_,
std::optional<SortDirection> nulls_sort_direction_,
std::shared_ptr<Collator> collator_,
bool with_fill_)
: IQueryTreeNode(children_size)
, sort_direction(sort_direction_)
, nulls_sort_direction(nulls_sort_direction_)
, collator(std::move(collator_))
, with_fill(with_fill_)
{
children[sort_expression_child_index] = std::move(expression_);
}
String SortNode::getName() const
{
String result = getExpression()->getName();
if (sort_direction == SortDirection::ASCENDING)
result += " ASC";
else
result += " DESC";
if (nulls_sort_direction)
{
if (*nulls_sort_direction == SortDirection::ASCENDING)
result += " NULLS FIRST";
else
result += " NULLS LAST";
}
if (with_fill)
result += " WITH FILL";
if (hasFillFrom())
result += " FROM " + getFillFrom()->getName();
if (hasFillStep())
result += " STEP " + getFillStep()->getName();
if (hasFillTo())
result += " TO " + getFillTo()->getName();
return result;
}
void SortNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "SORT id: " << format_state.getNodeId(this);
buffer << ", sort_direction: " << toString(sort_direction);
if (nulls_sort_direction)
buffer << ", nulls_sort_direction: " << toString(*nulls_sort_direction);
if (collator)
buffer << ", collator: " << collator->getLocale();
buffer << ", with_fill: " << with_fill;
buffer << '\n' << std::string(indent + 2, ' ') << "EXPRESSION\n";
getExpression()->dumpTreeImpl(buffer, format_state, indent + 4);
if (hasFillFrom())
{
buffer << '\n' << std::string(indent + 2, ' ') << "FILL FROM\n";
getFillFrom()->dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasFillTo())
{
buffer << '\n' << std::string(indent + 2, ' ') << "FILL TO\n";
getFillTo()->dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasFillStep())
{
buffer << '\n' << std::string(indent + 2, ' ') << "FILL STEP\n";
getFillStep()->dumpTreeImpl(buffer, format_state, indent + 4);
}
}
bool SortNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const SortNode &>(rhs);
if (sort_direction != rhs_typed.sort_direction ||
nulls_sort_direction != rhs_typed.nulls_sort_direction ||
with_fill != rhs_typed.with_fill)
return false;
if (!collator && !rhs_typed.collator)
return true;
else if (collator && !rhs_typed.collator)
return false;
else if (!collator && rhs_typed.collator)
return false;
return collator->getLocale() == rhs_typed.collator->getLocale();
}
void SortNode::updateTreeHashImpl(HashState & hash_state) const
{
hash_state.update(sort_direction);
hash_state.update(nulls_sort_direction);
hash_state.update(with_fill);
if (collator)
{
const auto & locale = collator->getLocale();
hash_state.update(locale.size());
hash_state.update(locale);
}
}
QueryTreeNodePtr SortNode::cloneImpl() const
{
return std::make_shared<SortNode>(nullptr /*expression*/, sort_direction, nulls_sort_direction, collator, with_fill);
}
ASTPtr SortNode::toASTImpl() const
{
auto result = std::make_shared<ASTOrderByElement>();
result->direction = sort_direction == SortDirection::ASCENDING ? 1 : -1;
result->nulls_direction = result->direction;
if (nulls_sort_direction)
result->nulls_direction = *nulls_sort_direction == SortDirection::ASCENDING ? 1 : -1;
result->nulls_direction_was_explicitly_specified = nulls_sort_direction.has_value();
result->with_fill = with_fill;
result->fill_from = hasFillFrom() ? getFillFrom()->toAST() : nullptr;
result->fill_to = hasFillTo() ? getFillTo()->toAST() : nullptr;
result->fill_step = hasFillStep() ? getFillStep()->toAST() : nullptr;
result->children.push_back(getExpression()->toAST());
if (collator)
{
result->children.push_back(std::make_shared<ASTLiteral>(Field(collator->getLocale())));
result->collation = result->children.back();
}
return result;
}
}

158
src/Analyzer/SortNode.h Normal file
View File

@ -0,0 +1,158 @@
#pragma once
#include <Columns/Collator.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h>
namespace DB
{
/** Sort node represents sort description for expression that is part of ORDER BY in query tree.
* Example: SELECT * FROM test_table ORDER BY sort_column_1, sort_column_2;
* Sort node optionally contain collation, fill from, fill to, and fill step.
*/
class SortNode;
using SortNodePtr = std::shared_ptr<SortNode>;
enum class SortDirection
{
ASCENDING = 0,
DESCENDING = 1
};
const char * toString(SortDirection sort_direction);
class SortNode final : public IQueryTreeNode
{
public:
/// Initialize sort node with sort expression
explicit SortNode(QueryTreeNodePtr expression_,
SortDirection sort_direction_ = SortDirection::ASCENDING,
std::optional<SortDirection> nulls_sort_direction_ = {},
std::shared_ptr<Collator> collator_ = nullptr,
bool with_fill = false);
/// Get sort expression
const QueryTreeNodePtr & getExpression() const
{
return children[sort_expression_child_index];
}
/// Get sort expression
QueryTreeNodePtr & getExpression()
{
return children[sort_expression_child_index];
}
/// Returns true if sort node has with fill, false otherwise
bool withFill() const
{
return with_fill;
}
/// Returns true if sort node has fill from, false otherwise
bool hasFillFrom() const
{
return children[fill_from_child_index] != nullptr;
}
/// Get fill from
const QueryTreeNodePtr & getFillFrom() const
{
return children[fill_from_child_index];
}
/// Get fill from
QueryTreeNodePtr & getFillFrom()
{
return children[fill_from_child_index];
}
/// Returns true if sort node has fill to, false otherwise
bool hasFillTo() const
{
return children[fill_to_child_index] != nullptr;
}
/// Get fill to
const QueryTreeNodePtr & getFillTo() const
{
return children[fill_to_child_index];
}
/// Get fill to
QueryTreeNodePtr & getFillTo()
{
return children[fill_to_child_index];
}
/// Returns true if sort node has fill step, false otherwise
bool hasFillStep() const
{
return children[fill_step_child_index] != nullptr;
}
/// Get fill step
const QueryTreeNodePtr & getFillStep() const
{
return children[fill_step_child_index];
}
/// Get fill step
QueryTreeNodePtr & getFillStep()
{
return children[fill_step_child_index];
}
/// Get collator
const std::shared_ptr<Collator> & getCollator() const
{
return collator;
}
/// Get sort direction
SortDirection getSortDirection() const
{
return sort_direction;
}
/// Get nulls sort direction
std::optional<SortDirection> getNullsSortDirection() const
{
return nulls_sort_direction;
}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::SORT;
}
String getName() const override;
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & hash_state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
static constexpr size_t sort_expression_child_index = 0;
static constexpr size_t fill_from_child_index = 1;
static constexpr size_t fill_to_child_index = 2;
static constexpr size_t fill_step_child_index = 3;
static constexpr size_t children_size = fill_step_child_index + 1;
SortDirection sort_direction = SortDirection::ASCENDING;
std::optional<SortDirection> nulls_sort_direction;
std::shared_ptr<Collator> collator;
bool with_fill = false;
};
}

View File

@ -0,0 +1,42 @@
#include <Analyzer/TableExpressionModifiers.h>
#include <Common/SipHash.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
namespace DB
{
void TableExpressionModifiers::dump(WriteBuffer & buffer) const
{
buffer << "final: " << has_final;
if (sample_size_ratio)
buffer << ", sample_size: " << ASTSampleRatio::toString(*sample_size_ratio);
if (sample_offset_ratio)
buffer << ", sample_offset: " << ASTSampleRatio::toString(*sample_offset_ratio);
}
void TableExpressionModifiers::updateTreeHash(SipHash & hash_state) const
{
hash_state.update(has_final);
hash_state.update(sample_size_ratio.has_value());
hash_state.update(sample_offset_ratio.has_value());
if (sample_size_ratio.has_value())
{
hash_state.update(sample_size_ratio->numerator);
hash_state.update(sample_size_ratio->denominator);
}
if (sample_offset_ratio.has_value())
{
hash_state.update(sample_offset_ratio->numerator);
hash_state.update(sample_offset_ratio->denominator);
}
}
}

View File

@ -0,0 +1,77 @@
#pragma once
#include <Parsers/ASTSampleRatio.h>
namespace DB
{
/** Modifiers that can be used for table, table function and subquery in JOIN TREE.
*
* Example: SELECT * FROM test_table SAMPLE 0.1 OFFSET 0.1 FINAL
*/
class TableExpressionModifiers
{
public:
using Rational = ASTSampleRatio::Rational;
TableExpressionModifiers(bool has_final_,
std::optional<Rational> sample_size_ratio_,
std::optional<Rational> sample_offset_ratio_)
: has_final(has_final_)
, sample_size_ratio(sample_size_ratio_)
, sample_offset_ratio(sample_offset_ratio_)
{}
/// Returns true if final is specified, false otherwise
bool hasFinal() const
{
return has_final;
}
/// Returns true if sample size ratio is specified, false otherwise
bool hasSampleSizeRatio() const
{
return sample_size_ratio.has_value();
}
/// Get sample size ratio
std::optional<Rational> getSampleSizeRatio() const
{
return sample_size_ratio;
}
/// Returns true if sample offset ratio is specified, false otherwise
bool hasSampleOffsetRatio() const
{
return sample_offset_ratio.has_value();
}
/// Get sample offset ratio
std::optional<Rational> getSampleOffsetRatio() const
{
return sample_offset_ratio;
}
/// Dump into buffer
void dump(WriteBuffer & buffer) const;
/// Update tree hash
void updateTreeHash(SipHash & hash_state) const;
private:
bool has_final = false;
std::optional<Rational> sample_size_ratio;
std::optional<Rational> sample_offset_ratio;
};
inline bool operator==(const TableExpressionModifiers & lhs, const TableExpressionModifiers & rhs)
{
return lhs.hasFinal() == rhs.hasFinal() && lhs.getSampleSizeRatio() == rhs.getSampleSizeRatio() && lhs.getSampleOffsetRatio() == rhs.getSampleOffsetRatio();
}
inline bool operator!=(const TableExpressionModifiers & lhs, const TableExpressionModifiers & rhs)
{
return !(lhs == rhs);
}
}

View File

@ -0,0 +1,148 @@
#include <Analyzer/TableFunctionNode.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Storages/IStorage.h>
#include <Parsers/ASTFunction.h>
#include <Interpreters/Context.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
TableFunctionNode::TableFunctionNode(String table_function_name_)
: IQueryTreeNode(children_size)
, table_function_name(table_function_name_)
, storage_id("system", "one")
{
children[arguments_child_index] = std::make_shared<ListNode>();
}
void TableFunctionNode::resolve(TableFunctionPtr table_function_value, StoragePtr storage_value, ContextPtr context)
{
table_function = std::move(table_function_value);
storage = std::move(storage_value);
storage_id = storage->getStorageID();
storage_snapshot = storage->getStorageSnapshot(storage->getInMemoryMetadataPtr(), context);
}
const StorageID & TableFunctionNode::getStorageID() const
{
if (!storage)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Table function node {} is not resolved", table_function_name);
return storage_id;
}
const StorageSnapshotPtr & TableFunctionNode::getStorageSnapshot() const
{
if (!storage)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Table function node {} is not resolved", table_function_name);
return storage_snapshot;
}
String TableFunctionNode::getName() const
{
String name = table_function_name;
const auto & arguments = getArguments();
name += '(';
name += arguments.getName();
name += ')';
return name;
}
void TableFunctionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "TABLE_FUNCTION id: " << format_state.getNodeId(this);
if (hasAlias())
buffer << ", alias: " << getAlias();
buffer << ", table_function_name: " << table_function_name;
if (table_expression_modifiers)
{
buffer << ", ";
table_expression_modifiers->dump(buffer);
}
const auto & arguments = getArguments();
if (!arguments.getNodes().empty())
{
buffer << '\n' << std::string(indent + 2, ' ') << "ARGUMENTS\n";
arguments.dumpTreeImpl(buffer, format_state, indent + 4);
}
}
bool TableFunctionNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const TableFunctionNode &>(rhs);
if (table_function_name != rhs_typed.table_function_name)
return false;
if (storage && rhs_typed.storage)
return storage_id == rhs_typed.storage_id;
if (table_expression_modifiers && rhs_typed.table_expression_modifiers && table_expression_modifiers != rhs_typed.table_expression_modifiers)
return false;
else if (table_expression_modifiers && !rhs_typed.table_expression_modifiers)
return false;
else if (!table_expression_modifiers && rhs_typed.table_expression_modifiers)
return false;
return true;
}
void TableFunctionNode::updateTreeHashImpl(HashState & state) const
{
state.update(table_function_name.size());
state.update(table_function_name);
if (storage)
{
auto full_name = storage_id.getFullNameNotQuoted();
state.update(full_name.size());
state.update(full_name);
}
if (table_expression_modifiers)
table_expression_modifiers->updateTreeHash(state);
}
QueryTreeNodePtr TableFunctionNode::cloneImpl() const
{
auto result = std::make_shared<TableFunctionNode>(table_function_name);
result->storage = storage;
result->storage_id = storage_id;
result->storage_snapshot = storage_snapshot;
result->table_expression_modifiers = table_expression_modifiers;
return result;
}
ASTPtr TableFunctionNode::toASTImpl() const
{
auto table_function_ast = std::make_shared<ASTFunction>();
table_function_ast->name = table_function_name;
const auto & arguments = getArguments();
table_function_ast->children.push_back(arguments.toAST());
table_function_ast->arguments = table_function_ast->children.back();
return table_function_ast;
}
}

View File

@ -0,0 +1,156 @@
#pragma once
#include <Storages/IStorage_fwd.h>
#include <Storages/TableLockHolder.h>
#include <Storages/StorageSnapshot.h>
#include <Interpreters/Context_fwd.h>
#include <Interpreters/StorageID.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h>
#include <Analyzer/TableExpressionModifiers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
/** Table function node represents table function in query tree.
* Example: SELECT a FROM table_function(arguments...).
*
* In query tree table function arguments are represented by ListNode.
*
* Table function resolution must be done during query analysis pass.
*/
class ITableFunction;
using TableFunctionPtr = std::shared_ptr<ITableFunction>;
class TableFunctionNode;
using TableFunctionNodePtr = std::shared_ptr<TableFunctionNode>;
class TableFunctionNode : public IQueryTreeNode
{
public:
/// Construct table function node with table function name
explicit TableFunctionNode(String table_function_name);
/// Get table function name
const String & getTableFunctionName() const
{
return table_function_name;
}
/// Get arguments
const ListNode & getArguments() const
{
return children[arguments_child_index]->as<const ListNode &>();
}
/// Get arguments
ListNode & getArguments()
{
return children[arguments_child_index]->as<ListNode &>();
}
/// Get arguments node
const QueryTreeNodePtr & getArgumentsNode() const
{
return children[arguments_child_index];
}
/// Get arguments node
QueryTreeNodePtr & getArgumentsNode()
{
return children[arguments_child_index];
}
/// Returns true, if table function is resolved, false otherwise
bool isResolved() const
{
return storage != nullptr && table_function != nullptr;
}
/// Get table function, returns nullptr if table function node is not resolved
const TableFunctionPtr & getTableFunction() const
{
return table_function;
}
/// Get storage, returns nullptr if table function node is not resolved
const StoragePtr & getStorage() const
{
return storage;
}
/// Get storage, throws exception if table function node is not resolved
const StoragePtr & getStorageOrThrow() const
{
if (!storage)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Table function node is not resolved");
return storage;
}
/// Resolve table function with table function, storage and context
void resolve(TableFunctionPtr table_function_value, StoragePtr storage_value, ContextPtr context);
/// Get storage id, throws exception if function node is not resolved
const StorageID & getStorageID() const;
/// Get storage snapshot, throws exception if function node is not resolved
const StorageSnapshotPtr & getStorageSnapshot() const;
/// Return true if table function node has table expression modifiers, false otherwise
bool hasTableExpressionModifiers() const
{
return table_expression_modifiers.has_value();
}
/// Get table expression modifiers
const std::optional<TableExpressionModifiers> & getTableExpressionModifiers() const
{
return table_expression_modifiers;
}
/// Set table expression modifiers
void setTableExpressionModifiers(TableExpressionModifiers table_expression_modifiers_value)
{
table_expression_modifiers = std::move(table_expression_modifiers_value);
}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::TABLE_FUNCTION;
}
String getName() const override;
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
String table_function_name;
TableFunctionPtr table_function;
StoragePtr storage;
StorageID storage_id;
StorageSnapshotPtr storage_snapshot;
std::optional<TableExpressionModifiers> table_expression_modifiers;
static constexpr size_t arguments_child_index = 0;
static constexpr size_t children_size = arguments_child_index + 1;
};
}

View File

@ -0,0 +1,87 @@
#include <Analyzer/TableNode.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Parsers/ASTIdentifier.h>
#include <Storages/IStorage.h>
#include <Interpreters/Context.h>
namespace DB
{
TableNode::TableNode(StoragePtr storage_, StorageID storage_id_, TableLockHolder storage_lock_, StorageSnapshotPtr storage_snapshot_)
: IQueryTreeNode(children_size)
, storage(std::move(storage_))
, storage_id(std::move(storage_id_))
, storage_lock(std::move(storage_lock_))
, storage_snapshot(std::move(storage_snapshot_))
{}
TableNode::TableNode(StoragePtr storage_, TableLockHolder storage_lock_, StorageSnapshotPtr storage_snapshot_)
: TableNode(storage_, storage_->getStorageID(), std::move(storage_lock_), std::move(storage_snapshot_))
{
}
void TableNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "TABLE id: " << format_state.getNodeId(this);
if (hasAlias())
buffer << ", alias: " << getAlias();
buffer << ", table_name: " << storage_id.getFullNameNotQuoted();
if (table_expression_modifiers)
{
buffer << ", ";
table_expression_modifiers->dump(buffer);
}
}
bool TableNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const TableNode &>(rhs);
if (table_expression_modifiers && rhs_typed.table_expression_modifiers && table_expression_modifiers != rhs_typed.table_expression_modifiers)
return false;
else if (table_expression_modifiers && !rhs_typed.table_expression_modifiers)
return false;
else if (!table_expression_modifiers && rhs_typed.table_expression_modifiers)
return false;
return storage_id == rhs_typed.storage_id;
}
void TableNode::updateTreeHashImpl(HashState & state) const
{
auto full_name = storage_id.getFullNameNotQuoted();
state.update(full_name.size());
state.update(full_name);
if (table_expression_modifiers)
table_expression_modifiers->updateTreeHash(state);
}
String TableNode::getName() const
{
return storage->getStorageID().getFullNameNotQuoted();
}
QueryTreeNodePtr TableNode::cloneImpl() const
{
auto result_table_node = std::make_shared<TableNode>(storage, storage_id, storage_lock, storage_snapshot);
result_table_node->table_expression_modifiers = table_expression_modifiers;
return result_table_node;
}
ASTPtr TableNode::toASTImpl() const
{
return std::make_shared<ASTTableIdentifier>(storage_id.getDatabaseName(), storage_id.getTableName());
}
}

103
src/Analyzer/TableNode.h Normal file
View File

@ -0,0 +1,103 @@
#pragma once
#include <Storages/IStorage_fwd.h>
#include <Storages/TableLockHolder.h>
#include <Storages/StorageSnapshot.h>
#include <Interpreters/Context_fwd.h>
#include <Interpreters/StorageID.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/TableExpressionModifiers.h>
namespace DB
{
/** Table node represents table in query tree.
* Example: SELECT a FROM test_table.
* test_table - is identifier, that during query analysis pass must be resolved into table node.
*/
class TableNode;
using TableNodePtr = std::shared_ptr<TableNode>;
class TableNode : public IQueryTreeNode
{
public:
/// Construct table node with storage, storage id, storage lock, storage snapshot
explicit TableNode(StoragePtr storage_, StorageID storage_id_, TableLockHolder storage_lock_, StorageSnapshotPtr storage_snapshot_);
/// Construct table node with storage, storage lock, storage snapshot
explicit TableNode(StoragePtr storage_, TableLockHolder storage_lock_, StorageSnapshotPtr storage_snapshot_);
/// Get storage
const StoragePtr & getStorage() const
{
return storage;
}
/// Get storage id
const StorageID & getStorageID() const
{
return storage_id;
}
/// Get storage snapshot
const StorageSnapshotPtr & getStorageSnapshot() const
{
return storage_snapshot;
}
/// Get storage lock
const TableLockHolder & getStorageLock() const
{
return storage_lock;
}
/// Return true if table node has table expression modifiers, false otherwise
bool hasTableExpressionModifiers() const
{
return table_expression_modifiers.has_value();
}
/// Get table expression modifiers
const std::optional<TableExpressionModifiers> & getTableExpressionModifiers() const
{
return table_expression_modifiers;
}
/// Set table expression modifiers
void setTableExpressionModifiers(TableExpressionModifiers table_expression_modifiers_value)
{
table_expression_modifiers = std::move(table_expression_modifiers_value);
}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::TABLE;
}
String getName() const override;
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
StoragePtr storage;
StorageID storage_id;
TableLockHolder storage_lock;
StorageSnapshotPtr storage_snapshot;
std::optional<TableExpressionModifiers> table_expression_modifiers;
static constexpr size_t children_size = 0;
};
}

254
src/Analyzer/UnionNode.cpp Normal file
View File

@ -0,0 +1,254 @@
#include <Analyzer/UnionNode.h>
#include <Common/SipHash.h>
#include <Common/FieldVisitorToString.h>
#include <Core/NamesAndTypes.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/ASTSubquery.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTFunction.h>
#include <Core/ColumnWithTypeAndName.h>
#include <DataTypes/getLeastSupertype.h>
#include <Analyzer/QueryNode.h>
#include <Analyzer/Utils.h>
namespace DB
{
namespace ErrorCodes
{
extern const int TYPE_MISMATCH;
}
UnionNode::UnionNode()
: IQueryTreeNode(children_size)
{
children[queries_child_index] = std::make_shared<ListNode>();
}
NamesAndTypes UnionNode::computeProjectionColumns() const
{
std::vector<NamesAndTypes> projections;
NamesAndTypes query_node_projection;
const auto & query_nodes = getQueries().getNodes();
projections.reserve(query_nodes.size());
for (const auto & query_node : query_nodes)
{
if (auto * query_node_typed = query_node->as<QueryNode>())
query_node_projection = query_node_typed->getProjectionColumns();
else if (auto * union_node_typed = query_node->as<UnionNode>())
query_node_projection = union_node_typed->computeProjectionColumns();
projections.push_back(query_node_projection);
if (query_node_projection.size() != projections.front().size())
throw Exception(ErrorCodes::TYPE_MISMATCH, "UNION different number of columns in queries");
}
NamesAndTypes result_columns;
size_t projections_size = projections.size();
DataTypes projection_column_types;
projection_column_types.resize(projections_size);
size_t columns_size = query_node_projection.size();
for (size_t column_index = 0; column_index < columns_size; ++column_index)
{
for (size_t projection_index = 0; projection_index < projections_size; ++projection_index)
projection_column_types[projection_index] = projections[projection_index][column_index].type;
auto result_type = getLeastSupertype(projection_column_types);
result_columns.emplace_back(projections.front()[column_index].name, std::move(result_type));
}
return result_columns;
}
String UnionNode::getName() const
{
WriteBufferFromOwnString buffer;
auto query_nodes = getQueries().getNodes();
size_t query_nodes_size = query_nodes.size();
for (size_t i = 0; i < query_nodes_size; ++i)
{
const auto & query_node = query_nodes[i];
buffer << query_node->getName();
if (i == 0)
continue;
auto query_union_mode = union_modes.at(i - 1);
if (query_union_mode == SelectUnionMode::UNION_DEFAULT)
buffer << "UNION";
else if (query_union_mode == SelectUnionMode::UNION_ALL)
buffer << "UNION ALL";
else if (query_union_mode == SelectUnionMode::UNION_DISTINCT)
buffer << "UNION DISTINCT";
else if (query_union_mode == SelectUnionMode::EXCEPT_DEFAULT)
buffer << "EXCEPT";
else if (query_union_mode == SelectUnionMode::EXCEPT_ALL)
buffer << "EXCEPT ALL";
else if (query_union_mode == SelectUnionMode::EXCEPT_DISTINCT)
buffer << "EXCEPT DISTINCT";
else if (query_union_mode == SelectUnionMode::INTERSECT_DEFAULT)
buffer << "INTERSECT";
else if (query_union_mode == SelectUnionMode::INTERSECT_ALL)
buffer << "INTERSECT ALL";
else if (query_union_mode == SelectUnionMode::INTERSECT_DISTINCT)
buffer << "INTERSECT DISTINCT";
}
return buffer.str();
}
void UnionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "UNION id: " << format_state.getNodeId(this);
if (hasAlias())
buffer << ", alias: " << getAlias();
if (is_subquery)
buffer << ", is_subquery: " << is_subquery;
if (is_cte)
buffer << ", is_cte: " << is_cte;
if (!cte_name.empty())
buffer << ", cte_name: " << cte_name;
if (constant_value)
{
buffer << ", constant_value: " << constant_value->getValue().dump();
buffer << ", constant_value_type: " << constant_value->getType()->getName();
}
if (table_expression_modifiers)
{
buffer << ", ";
table_expression_modifiers->dump(buffer);
}
buffer << ", union_mode: " << toString(union_mode);
size_t union_modes_size = union_modes.size();
buffer << '\n' << std::string(indent + 2, ' ') << "UNION MODES " << union_modes_size << '\n';
for (size_t i = 0; i < union_modes_size; ++i)
{
buffer << std::string(indent + 4, ' ');
auto query_union_mode = union_modes[i];
buffer << toString(query_union_mode);
if (i + 1 != union_modes_size)
buffer << '\n';
}
buffer << '\n' << std::string(indent + 2, ' ') << "QUERIES\n";
getQueriesNode()->dumpTreeImpl(buffer, format_state, indent + 4);
}
bool UnionNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const UnionNode &>(rhs);
if (constant_value && rhs_typed.constant_value && *constant_value != *rhs_typed.constant_value)
return false;
else if (constant_value && !rhs_typed.constant_value)
return false;
else if (!constant_value && rhs_typed.constant_value)
return false;
if (table_expression_modifiers && rhs_typed.table_expression_modifiers && table_expression_modifiers != rhs_typed.table_expression_modifiers)
return false;
else if (table_expression_modifiers && !rhs_typed.table_expression_modifiers)
return false;
else if (!table_expression_modifiers && rhs_typed.table_expression_modifiers)
return false;
return is_subquery == rhs_typed.is_subquery && is_cte == rhs_typed.is_cte && cte_name == rhs_typed.cte_name &&
union_mode == rhs_typed.union_mode && union_modes == rhs_typed.union_modes;
}
void UnionNode::updateTreeHashImpl(HashState & state) const
{
state.update(is_subquery);
state.update(is_cte);
state.update(cte_name.size());
state.update(cte_name);
state.update(static_cast<size_t>(union_mode));
state.update(union_modes.size());
for (const auto & query_union_mode : union_modes)
state.update(static_cast<size_t>(query_union_mode));
if (constant_value)
{
auto constant_dump = applyVisitor(FieldVisitorToString(), constant_value->getValue());
state.update(constant_dump.size());
state.update(constant_dump);
auto constant_value_type_name = constant_value->getType()->getName();
state.update(constant_value_type_name.size());
state.update(constant_value_type_name);
}
if (table_expression_modifiers)
table_expression_modifiers->updateTreeHash(state);
}
QueryTreeNodePtr UnionNode::cloneImpl() const
{
auto result_union_node = std::make_shared<UnionNode>();
result_union_node->is_subquery = is_subquery;
result_union_node->is_cte = is_cte;
result_union_node->cte_name = cte_name;
result_union_node->union_mode = union_mode;
result_union_node->union_modes = union_modes;
result_union_node->union_modes_set = union_modes_set;
result_union_node->constant_value = constant_value;
result_union_node->table_expression_modifiers = table_expression_modifiers;
return result_union_node;
}
ASTPtr UnionNode::toASTImpl() const
{
auto select_with_union_query = std::make_shared<ASTSelectWithUnionQuery>();
select_with_union_query->union_mode = union_mode;
if (union_mode != SelectUnionMode::UNION_DEFAULT &&
union_mode != SelectUnionMode::EXCEPT_DEFAULT &&
union_mode != SelectUnionMode::INTERSECT_DEFAULT)
select_with_union_query->is_normalized = true;
select_with_union_query->list_of_modes = union_modes;
select_with_union_query->set_of_modes = union_modes_set;
select_with_union_query->children.push_back(getQueriesNode()->toAST());
select_with_union_query->list_of_selects = select_with_union_query->children.back();
return select_with_union_query;
}
}

203
src/Analyzer/UnionNode.h Normal file
View File

@ -0,0 +1,203 @@
#pragma once
#include <Core/NamesAndTypes.h>
#include <Core/Field.h>
#include <Analyzer/Identifier.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h>
#include <Analyzer/TableExpressionModifiers.h>
#include <Parsers/SelectUnionMode.h>
namespace DB
{
namespace ErrorCodes
{
extern const int UNSUPPORTED_METHOD;
}
/** Union node represents union of queries in query tree.
*
* Example: (SELECT id FROM test_table) UNION ALL (SELECT id FROM test_table_2);
* Example: (SELECT id FROM test_table) UNION DISTINCT (SELECT id FROM test_table_2);
* Example: (SELECT id FROM test_table) EXCEPT ALL (SELECT id FROM test_table_2);
* Example: (SELECT id FROM test_table) EXCEPT DISTINCT (SELECT id FROM test_table_2);
* Example: (SELECT id FROM test_table) INTERSECT ALL (SELECT id FROM test_table_2);
* Example: (SELECT id FROM test_table) INTERSECT DISTINCT (SELECT id FROM test_table_2);
*
* Union node can be used as CTE.
* Example: WITH cte_subquery AS ((SELECT id FROM test_table) UNION ALL (SELECT id FROM test_table_2)) SELECT * FROM cte_subquery;
*
* Union node can be used as scalar subquery.
* Example: SELECT (SELECT 1 UNION DISTINCT SELECT 1);
*
* During query analysis pass union node queries must be resolved.
*/
class UnionNode;
using UnionNodePtr = std::shared_ptr<UnionNode>;
class UnionNode final : public IQueryTreeNode
{
public:
explicit UnionNode();
/// Returns true if union node is subquery, false otherwise
bool isSubquery() const
{
return is_subquery;
}
/// Set union node is subquery value
void setIsSubquery(bool is_subquery_value)
{
is_subquery = is_subquery_value;
}
/// Returns true if union node is CTE, false otherwise
bool isCTE() const
{
return is_cte;
}
/// Set union node is CTE
void setIsCTE(bool is_cte_value)
{
is_cte = is_cte_value;
}
/// Get union node CTE name
const std::string & getCTEName() const
{
return cte_name;
}
/// Set union node CTE name
void setCTEName(std::string cte_name_value)
{
cte_name = std::move(cte_name_value);
}
/// Get union mode
SelectUnionMode getUnionMode() const
{
return union_mode;
}
/// Set union mode value
void setUnionMode(SelectUnionMode union_mode_value)
{
union_mode = union_mode_value;
}
/// Get union modes
const SelectUnionModes & getUnionModes() const
{
return union_modes;
}
/// Set union modes value
void setUnionModes(const SelectUnionModes & union_modes_value)
{
union_modes = union_modes_value;
union_modes_set = SelectUnionModesSet(union_modes.begin(), union_modes.end());
}
/// Get union node queries
const ListNode & getQueries() const
{
return children[queries_child_index]->as<const ListNode &>();
}
/// Get union node queries
ListNode & getQueries()
{
return children[queries_child_index]->as<ListNode &>();
}
/// Get union node queries node
const QueryTreeNodePtr & getQueriesNode() const
{
return children[queries_child_index];
}
/// Get union node queries node
QueryTreeNodePtr & getQueriesNode()
{
return children[queries_child_index];
}
/// Return true if union node has table expression modifiers, false otherwise
bool hasTableExpressionModifiers() const
{
return table_expression_modifiers.has_value();
}
/// Get table expression modifiers
const std::optional<TableExpressionModifiers> & getTableExpressionModifiers() const
{
return table_expression_modifiers;
}
/// Set table expression modifiers
void setTableExpressionModifiers(TableExpressionModifiers table_expression_modifiers_value)
{
table_expression_modifiers = std::move(table_expression_modifiers_value);
}
/// Compute union node projection columns
NamesAndTypes computeProjectionColumns() const;
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::UNION;
}
String getName() const override;
DataTypePtr getResultType() const override
{
if (constant_value)
return constant_value->getType();
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Method getResultType is not supported for non scalar union node");
}
/// Perform constant folding for scalar union node
void performConstantFolding(ConstantValuePtr constant_folded_value)
{
constant_value = std::move(constant_folded_value);
}
ConstantValuePtr getConstantValueOrNull() const override
{
return constant_value;
}
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState &) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
bool is_subquery = false;
bool is_cte = false;
std::string cte_name;
SelectUnionMode union_mode;
SelectUnionModes union_modes;
SelectUnionModesSet union_modes_set;
ConstantValuePtr constant_value;
std::optional<TableExpressionModifiers> table_expression_modifiers;
static constexpr size_t queries_child_index = 0;
static constexpr size_t children_size = queries_child_index + 1;
};
}

334
src/Analyzer/Utils.cpp Normal file
View File

@ -0,0 +1,334 @@
#include <Analyzer/Utils.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTSubquery.h>
#include <Parsers/ASTFunction.h>
#include <Analyzer/IdentifierNode.h>
#include <Analyzer/JoinNode.h>
#include <Analyzer/ArrayJoinNode.h>
#include <Analyzer/ColumnNode.h>
#include <Analyzer/TableNode.h>
#include <Analyzer/TableFunctionNode.h>
#include <Analyzer/QueryNode.h>
#include <Analyzer/UnionNode.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
bool isNodePartOfTree(const IQueryTreeNode * node, const IQueryTreeNode * root)
{
std::vector<const IQueryTreeNode *> nodes_to_process;
nodes_to_process.push_back(root);
while (!nodes_to_process.empty())
{
const auto * subtree_node = nodes_to_process.back();
nodes_to_process.pop_back();
if (subtree_node == node)
return true;
for (const auto & child : subtree_node->getChildren())
{
if (child)
nodes_to_process.push_back(child.get());
}
}
return false;
}
bool isNameOfInFunction(const std::string & function_name)
{
bool is_special_function_in = function_name == "in" ||
function_name == "globalIn" ||
function_name == "notIn" ||
function_name == "globalNotIn" ||
function_name == "nullIn" ||
function_name == "globalNullIn" ||
function_name == "notNullIn" ||
function_name == "globalNotNullIn" ||
function_name == "inIgnoreSet" ||
function_name == "globalInIgnoreSet" ||
function_name == "notInIgnoreSet" ||
function_name == "globalNotInIgnoreSet" ||
function_name == "nullInIgnoreSet" ||
function_name == "globalNullInIgnoreSet" ||
function_name == "notNullInIgnoreSet" ||
function_name == "globalNotNullInIgnoreSet";
return is_special_function_in;
}
static ASTPtr convertIntoTableExpressionAST(const QueryTreeNodePtr & table_expression_node)
{
ASTPtr table_expression_node_ast;
auto node_type = table_expression_node->getNodeType();
if (node_type == QueryTreeNodeType::IDENTIFIER)
{
const auto & identifier_node = table_expression_node->as<IdentifierNode &>();
const auto & identifier = identifier_node.getIdentifier();
if (identifier.getPartsSize() == 1)
table_expression_node_ast = std::make_shared<ASTTableIdentifier>(identifier[0]);
else if (identifier.getPartsSize() == 2)
table_expression_node_ast = std::make_shared<ASTTableIdentifier>(identifier[0], identifier[1]);
else
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Identifier for table expression must contain 1 or 2 parts. Actual '{}'",
identifier.getFullName());
}
else
{
table_expression_node_ast = table_expression_node->toAST();
}
auto result_table_expression = std::make_shared<ASTTableExpression>();
result_table_expression->children.push_back(table_expression_node_ast);
std::optional<TableExpressionModifiers> table_expression_modifiers;
if (node_type == QueryTreeNodeType::QUERY || node_type == QueryTreeNodeType::UNION)
{
if (auto * query_node = table_expression_node->as<QueryNode>())
table_expression_modifiers = query_node->getTableExpressionModifiers();
else if (auto * union_node = table_expression_node->as<UnionNode>())
table_expression_modifiers = union_node->getTableExpressionModifiers();
result_table_expression->subquery = result_table_expression->children.back();
}
else if (node_type == QueryTreeNodeType::TABLE || node_type == QueryTreeNodeType::IDENTIFIER)
{
if (auto * table_node = table_expression_node->as<TableNode>())
table_expression_modifiers = table_node->getTableExpressionModifiers();
else if (auto * identifier_node = table_expression_node->as<IdentifierNode>())
table_expression_modifiers = identifier_node->getTableExpressionModifiers();
result_table_expression->database_and_table_name = result_table_expression->children.back();
}
else if (node_type == QueryTreeNodeType::TABLE_FUNCTION)
{
if (auto * table_function_node = table_expression_node->as<TableFunctionNode>())
table_expression_modifiers = table_function_node->getTableExpressionModifiers();
result_table_expression->table_function = result_table_expression->children.back();
}
else
{
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Expected identifier, table, query, union or table function. Actual {}",
table_expression_node->formatASTForErrorMessage());
}
if (table_expression_modifiers)
{
result_table_expression->final = table_expression_modifiers->hasFinal();
const auto & sample_size_ratio = table_expression_modifiers->getSampleSizeRatio();
if (sample_size_ratio.has_value())
result_table_expression->sample_size = std::make_shared<ASTSampleRatio>(*sample_size_ratio);
const auto & sample_offset_ratio = table_expression_modifiers->getSampleOffsetRatio();
if (sample_offset_ratio.has_value())
result_table_expression->sample_offset = std::make_shared<ASTSampleRatio>(*sample_offset_ratio);
}
return result_table_expression;
}
void addTableExpressionOrJoinIntoTablesInSelectQuery(ASTPtr & tables_in_select_query_ast, const QueryTreeNodePtr & table_expression)
{
auto table_expression_node_type = table_expression->getNodeType();
switch (table_expression_node_type)
{
case QueryTreeNodeType::IDENTIFIER:
[[fallthrough]];
case QueryTreeNodeType::TABLE:
[[fallthrough]];
case QueryTreeNodeType::QUERY:
[[fallthrough]];
case QueryTreeNodeType::UNION:
[[fallthrough]];
case QueryTreeNodeType::TABLE_FUNCTION:
{
auto table_expression_ast = convertIntoTableExpressionAST(table_expression);
auto tables_in_select_query_element_ast = std::make_shared<ASTTablesInSelectQueryElement>();
tables_in_select_query_element_ast->children.push_back(std::move(table_expression_ast));
tables_in_select_query_element_ast->table_expression = tables_in_select_query_element_ast->children.back();
tables_in_select_query_ast->children.push_back(std::move(tables_in_select_query_element_ast));
break;
}
case QueryTreeNodeType::ARRAY_JOIN:
[[fallthrough]];
case QueryTreeNodeType::JOIN:
{
auto table_expression_tables_in_select_query_ast = table_expression->toAST();
tables_in_select_query_ast->children.reserve(table_expression_tables_in_select_query_ast->children.size());
for (auto && table_element_ast : table_expression_tables_in_select_query_ast->children)
tables_in_select_query_ast->children.push_back(std::move(table_element_ast));
break;
}
default:
{
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Unexpected node type for table expression. Expected identifier, table, table function, query, union, join or array join. Actual {}",
table_expression->getNodeTypeName());
}
}
}
QueryTreeNodes extractTableExpressions(const QueryTreeNodePtr & join_tree_node)
{
QueryTreeNodes result;
std::deque<QueryTreeNodePtr> nodes_to_process;
nodes_to_process.push_back(join_tree_node);
while (!nodes_to_process.empty())
{
auto node_to_process = std::move(nodes_to_process.front());
nodes_to_process.pop_front();
auto node_type = node_to_process->getNodeType();
switch (node_type)
{
case QueryTreeNodeType::TABLE:
[[fallthrough]];
case QueryTreeNodeType::QUERY:
[[fallthrough]];
case QueryTreeNodeType::UNION:
[[fallthrough]];
case QueryTreeNodeType::TABLE_FUNCTION:
{
result.push_back(std::move(node_to_process));
break;
}
case QueryTreeNodeType::ARRAY_JOIN:
{
auto & array_join_node = node_to_process->as<ArrayJoinNode &>();
nodes_to_process.push_front(array_join_node.getTableExpression());
break;
}
case QueryTreeNodeType::JOIN:
{
auto & join_node = node_to_process->as<JoinNode &>();
nodes_to_process.push_front(join_node.getRightTableExpression());
nodes_to_process.push_front(join_node.getLeftTableExpression());
break;
}
default:
{
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Unexpected node type for table expression. Expected table, table function, query, union, join or array join. Actual {}",
node_to_process->getNodeTypeName());
}
}
}
return result;
}
namespace
{
void buildTableExpressionsStackImpl(const QueryTreeNodePtr & join_tree_node, QueryTreeNodes & result)
{
auto node_type = join_tree_node->getNodeType();
switch (node_type)
{
case QueryTreeNodeType::TABLE:
[[fallthrough]];
case QueryTreeNodeType::QUERY:
[[fallthrough]];
case QueryTreeNodeType::UNION:
[[fallthrough]];
case QueryTreeNodeType::TABLE_FUNCTION:
{
result.push_back(join_tree_node);
break;
}
case QueryTreeNodeType::ARRAY_JOIN:
{
auto & array_join_node = join_tree_node->as<ArrayJoinNode &>();
buildTableExpressionsStackImpl(array_join_node.getTableExpression(), result);
result.push_back(join_tree_node);
break;
}
case QueryTreeNodeType::JOIN:
{
auto & join_node = join_tree_node->as<JoinNode &>();
buildTableExpressionsStackImpl(join_node.getLeftTableExpression(), result);
buildTableExpressionsStackImpl(join_node.getRightTableExpression(), result);
result.push_back(join_tree_node);
break;
}
default:
{
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Unexpected node type for table expression. Expected table, table function, query, union, join or array join. Actual {}",
join_tree_node->getNodeTypeName());
}
}
}
}
QueryTreeNodes buildTableExpressionsStack(const QueryTreeNodePtr & join_tree_node)
{
QueryTreeNodes result;
buildTableExpressionsStackImpl(join_tree_node, result);
return result;
}
QueryTreeNodePtr getColumnSourceForJoinNodeWithUsing(const QueryTreeNodePtr & join_node)
{
QueryTreeNodePtr column_source_node = join_node;
while (true)
{
auto column_source_node_type = column_source_node->getNodeType();
if (column_source_node_type == QueryTreeNodeType::TABLE ||
column_source_node_type == QueryTreeNodeType::TABLE_FUNCTION ||
column_source_node_type == QueryTreeNodeType::QUERY ||
column_source_node_type == QueryTreeNodeType::UNION)
{
break;
}
else if (column_source_node_type == QueryTreeNodeType::ARRAY_JOIN)
{
auto & array_join_node = column_source_node->as<ArrayJoinNode &>();
column_source_node = array_join_node.getTableExpression();
continue;
}
else if (column_source_node_type == QueryTreeNodeType::JOIN)
{
auto & join_node_typed = column_source_node->as<JoinNode &>();
column_source_node = isRight(join_node_typed.getKind()) ? join_node_typed.getRightTableExpression() : join_node_typed.getLeftTableExpression();
continue;
}
else
{
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Unexpected node type for table expression. Expected table, table function, query, union, join or array join. Actual {}",
column_source_node->getNodeTypeName());
}
}
return column_source_node;
}
}

39
src/Analyzer/Utils.h Normal file
View File

@ -0,0 +1,39 @@
#pragma once
#include <Analyzer/IQueryTreeNode.h>
namespace DB
{
/// Returns true if node part of root tree, false otherwise
bool isNodePartOfTree(const IQueryTreeNode * node, const IQueryTreeNode * root);
/// Returns true if function name is name of IN function or its variations, false otherwise
bool isNameOfInFunction(const std::string & function_name);
/** Add table expression in tables in select query children.
* If table expression node is not of identifier node, table node, query node, table function node, join node or array join node type throws logical error exception.
*/
void addTableExpressionOrJoinIntoTablesInSelectQuery(ASTPtr & tables_in_select_query_ast, const QueryTreeNodePtr & table_expression);
/// Extract table, table function, query, union from join tree
QueryTreeNodes extractTableExpressions(const QueryTreeNodePtr & join_tree_node);
/** Build table expressions stack that consists from table, table function, query, union, join, array join from join tree.
*
* Example: SELECT * FROM t1 INNER JOIN t2 INNER JOIN t3.
* Result table expressions stack:
* 1. t1 INNER JOIN t2 INNER JOIN t3
* 2. t3
* 3. t1 INNER JOIN t2
* 4. t2
* 5. t1
*/
QueryTreeNodes buildTableExpressionsStack(const QueryTreeNodePtr & join_tree_node);
/** Get column source for JOIN node with USING.
* Example: SELECT id FROM test_table_1 AS t1 INNER JOIN test_table_2 AS t2 USING (id);
*/
QueryTreeNodePtr getColumnSourceForJoinNodeWithUsing(const QueryTreeNodePtr & join_node);
}

View File

@ -0,0 +1,78 @@
#include <Analyzer/WindowFunctionsUtils.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_AGGREGATION;
}
namespace
{
class CollectWindowFunctionNodeVisitor : public ConstInDepthQueryTreeVisitor<CollectWindowFunctionNodeVisitor>
{
public:
explicit CollectWindowFunctionNodeVisitor(QueryTreeNodes * window_function_nodes_)
: window_function_nodes(window_function_nodes_)
{}
explicit CollectWindowFunctionNodeVisitor(String assert_no_window_functions_place_message_)
: assert_no_window_functions_place_message(std::move(assert_no_window_functions_place_message_))
{}
void visitImpl(const QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
if (!function_node || !function_node->isWindowFunction())
return;
if (!assert_no_window_functions_place_message.empty())
throw Exception(ErrorCodes::ILLEGAL_AGGREGATION,
"Window function {} is found {} in query",
function_node->formatASTForErrorMessage(),
assert_no_window_functions_place_message);
if (window_function_nodes)
window_function_nodes->push_back(node);
}
static bool needChildVisit(const QueryTreeNodePtr &, const QueryTreeNodePtr & child_node)
{
return !(child_node->getNodeType() == QueryTreeNodeType::QUERY || child_node->getNodeType() == QueryTreeNodeType::UNION);
}
private:
QueryTreeNodes * window_function_nodes = nullptr;
String assert_no_window_functions_place_message;
};
}
QueryTreeNodes collectWindowFunctionNodes(const QueryTreeNodePtr & node)
{
QueryTreeNodes window_function_nodes;
CollectWindowFunctionNodeVisitor visitor(&window_function_nodes);
visitor.visit(node);
return window_function_nodes;
}
void collectWindowFunctionNodes(const QueryTreeNodePtr & node, QueryTreeNodes & result)
{
CollectWindowFunctionNodeVisitor visitor(&result);
visitor.visit(node);
}
void assertNoWindowFunctionNodes(const QueryTreeNodePtr & node, const String & assert_no_window_functions_place_message)
{
CollectWindowFunctionNodeVisitor visitor(assert_no_window_functions_place_message);
visitor.visit(node);
}
}

View File

@ -0,0 +1,23 @@
#pragma once
#include <Analyzer/IQueryTreeNode.h>
namespace DB
{
/** Collect window function nodes in node children.
* Do not visit subqueries.
*/
QueryTreeNodes collectWindowFunctionNodes(const QueryTreeNodePtr & node);
/** Collect window function nodes in node children and add them into result.
* Do not visit subqueries.
*/
void collectWindowFunctionNodes(const QueryTreeNodePtr & node, QueryTreeNodes & result);
/** Assert that there are no window function nodes in node children.
* Do not visit subqueries.
*/
void assertNoWindowFunctionNodes(const QueryTreeNodePtr & node, const String & assert_no_window_functions_place_message);
}

213
src/Analyzer/WindowNode.cpp Normal file
View File

@ -0,0 +1,213 @@
#include <Analyzer/WindowNode.h>
#include <Common/SipHash.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Parsers/ASTWindowDefinition.h>
namespace DB
{
WindowNode::WindowNode(WindowFrame window_frame_)
: IQueryTreeNode(children_size)
, window_frame(std::move(window_frame_))
{
children[partition_by_child_index] = std::make_shared<ListNode>();
children[order_by_child_index] = std::make_shared<ListNode>();
}
String WindowNode::getName() const
{
String result;
if (hasPartitionBy())
{
result += "PARTITION BY";
result += getPartitionBy().getName();
}
if (hasOrderBy())
{
result += "ORDER BY";
result += getOrderBy().getName();
}
if (!window_frame.is_default)
{
if (hasPartitionBy() || hasOrderBy())
result += ' ';
if (window_frame.type == WindowFrame::FrameType::ROWS)
result += "ROWS";
else if (window_frame.type == WindowFrame::FrameType::GROUPS)
result += "GROUPS";
else if (window_frame.type == WindowFrame::FrameType::RANGE)
result += "RANGE";
result += " BETWEEN ";
if (window_frame.begin_type == WindowFrame::BoundaryType::Current)
{
result += "CURRENT ROW";
}
else if (window_frame.begin_type == WindowFrame::BoundaryType::Unbounded)
{
result += "UNBOUNDED";
result += " ";
result += (window_frame.begin_preceding ? "PRECEDING" : "FOLLOWING");
}
else
{
result += getFrameBeginOffsetNode()->getName();
result += " ";
result += (window_frame.begin_preceding ? "PRECEDING" : "FOLLOWING");
}
result += " AND ";
if (window_frame.end_type == WindowFrame::BoundaryType::Current)
{
result += "CURRENT ROW";
}
else if (window_frame.end_type == WindowFrame::BoundaryType::Unbounded)
{
result += "UNBOUNDED";
result += " ";
result += (window_frame.end_preceding ? "PRECEDING" : "FOLLOWING");
}
else
{
result += getFrameEndOffsetNode()->getName();
result += " ";
result += (window_frame.end_preceding ? "PRECEDING" : "FOLLOWING");
}
}
return result;
}
void WindowNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "WINDOW id: " << format_state.getNodeId(this);
if (hasAlias())
buffer << ", alias: " << getAlias();
if (!parent_window_name.empty())
buffer << ", parent_window_name: " << parent_window_name;
buffer << ", frame_type: " << window_frame.type;
auto window_frame_bound_type_to_string = [](WindowFrame::BoundaryType boundary_type, bool boundary_preceding)
{
std::string value;
if (boundary_type == WindowFrame::BoundaryType::Unbounded)
value = "unbounded";
else if (boundary_type == WindowFrame::BoundaryType::Current)
value = "current";
else if (boundary_type == WindowFrame::BoundaryType::Offset)
value = "offset";
if (boundary_type != WindowFrame::BoundaryType::Current)
{
if (boundary_preceding)
value += " preceding";
else
value += " following";
}
return value;
};
buffer << ", frame_begin_type: " << window_frame_bound_type_to_string(window_frame.begin_type, window_frame.begin_preceding);
buffer << ", frame_end_type: " << window_frame_bound_type_to_string(window_frame.end_type, window_frame.end_preceding);
if (hasPartitionBy())
{
buffer << '\n' << std::string(indent + 2, ' ') << "PARTITION BY\n";
getPartitionBy().dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasOrderBy())
{
buffer << '\n' << std::string(indent + 2, ' ') << "ORDER BY\n";
getOrderBy().dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasFrameBeginOffset())
{
buffer << '\n' << std::string(indent + 2, ' ') << "FRAME BEGIN OFFSET\n";
getFrameBeginOffsetNode()->dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasFrameEndOffset())
{
buffer << '\n' << std::string(indent + 2, ' ') << "FRAME END OFFSET\n";
getFrameEndOffsetNode()->dumpTreeImpl(buffer, format_state, indent + 4);
}
}
bool WindowNode::isEqualImpl(const IQueryTreeNode & rhs) const
{
const auto & rhs_typed = assert_cast<const WindowNode &>(rhs);
return window_frame == rhs_typed.window_frame && parent_window_name == rhs_typed.parent_window_name;
}
void WindowNode::updateTreeHashImpl(HashState & hash_state) const
{
hash_state.update(window_frame.is_default);
hash_state.update(window_frame.type);
hash_state.update(window_frame.begin_type);
hash_state.update(window_frame.begin_preceding);
hash_state.update(window_frame.end_type);
hash_state.update(window_frame.end_preceding);
hash_state.update(parent_window_name);
}
QueryTreeNodePtr WindowNode::cloneImpl() const
{
auto window_node = std::make_shared<WindowNode>(window_frame);
window_node->parent_window_name = parent_window_name;
return window_node;
}
ASTPtr WindowNode::toASTImpl() const
{
auto window_definition = std::make_shared<ASTWindowDefinition>();
window_definition->parent_window_name = parent_window_name;
window_definition->children.push_back(getPartitionByNode()->toAST());
window_definition->partition_by = window_definition->children.back();
window_definition->children.push_back(getOrderByNode()->toAST());
window_definition->order_by = window_definition->children.back();
window_definition->frame_is_default = window_frame.is_default;
window_definition->frame_type = window_frame.type;
window_definition->frame_begin_type = window_frame.begin_type;
window_definition->frame_begin_preceding = window_frame.begin_preceding;
if (hasFrameBeginOffset())
{
window_definition->children.push_back(getFrameBeginOffsetNode()->toAST());
window_definition->frame_begin_offset = window_definition->children.back();
}
window_definition->frame_end_type = window_frame.end_type;
window_definition->frame_end_preceding = window_frame.end_preceding;
if (hasFrameEndOffset())
{
window_definition->children.push_back(getFrameEndOffsetNode()->toAST());
window_definition->frame_end_offset = window_definition->children.back();
}
return window_definition;
}
}

193
src/Analyzer/WindowNode.h Normal file
View File

@ -0,0 +1,193 @@
#pragma once
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h>
#include <Interpreters/WindowDescription.h>
namespace DB
{
/** Window node represents window function window description.
*
* Example: SELECT * FROM test_table WINDOW window AS (PARTITION BY id);
* window AS (PARTITION BY id) - window node.
*
* Example: SELECT count() OVER (PARTITION BY id) FROM test_table;
* PARTITION BY id - window node.
*
* Window node can also refer to its parent window node.
* Example: SELECT count() OVER (parent_window ORDER BY id) FROM test_table WINDOW parent_window AS (PARTITION BY id);
* parent_window ORDER BY id - window node.
*
* Window node initially initialized with window frame.
*
* If window frame has OFFSET begin type, additionally frame begin offset node must be initialized.
* If window frame has OFFSET end type, additionally frame end offset node must be initialized.
* During query analysis pass they must be resolved, validated and window node window frame offset constants must be updated.
*/
class WindowNode;
using WindowNodePtr = std::shared_ptr<WindowNode>;
class WindowNode final : public IQueryTreeNode
{
public:
/// Initialize window node with window frame
explicit WindowNode(WindowFrame window_frame_);
/// Get window node window frame
const WindowFrame & getWindowFrame() const
{
return window_frame;
}
/// Get window node window frame
WindowFrame & getWindowFrame()
{
return window_frame;
}
/// Returns true if window node has parent window name, false otherwise
bool hasParentWindowName() const
{
return parent_window_name.empty();
}
/// Get parent window name
const String & getParentWindowName() const
{
return parent_window_name;
}
/// Set parent window name
void setParentWindowName(String parent_window_name_value)
{
parent_window_name = std::move(parent_window_name_value);
}
/// Returns true if window node has order by, false otherwise
bool hasOrderBy() const
{
return !getOrderBy().getNodes().empty();
}
/// Get order by
const ListNode & getOrderBy() const
{
return children[order_by_child_index]->as<const ListNode &>();
}
/// Get order by
ListNode & getOrderBy()
{
return children[order_by_child_index]->as<ListNode &>();
}
/// Get order by node
const QueryTreeNodePtr & getOrderByNode() const
{
return children[order_by_child_index];
}
/// Get order by node
QueryTreeNodePtr & getOrderByNode()
{
return children[order_by_child_index];
}
/// Returns true if window node has partition by, false otherwise
bool hasPartitionBy() const
{
return !getPartitionBy().getNodes().empty();
}
/// Get partition by
const ListNode & getPartitionBy() const
{
return children[partition_by_child_index]->as<const ListNode &>();
}
/// Get partition by
ListNode & getPartitionBy()
{
return children[partition_by_child_index]->as<ListNode &>();
}
/// Get partition by node
const QueryTreeNodePtr & getPartitionByNode() const
{
return children[partition_by_child_index];
}
/// Get partition by node
QueryTreeNodePtr & getPartitionByNode()
{
return children[partition_by_child_index];
}
/// Returns true if window node has FRAME begin offset, false otherwise
bool hasFrameBeginOffset() const
{
return getFrameBeginOffsetNode() != nullptr;
}
/// Get FRAME begin offset node
const QueryTreeNodePtr & getFrameBeginOffsetNode() const
{
return children[frame_begin_offset_child_index];
}
/// Get FRAME begin offset node
QueryTreeNodePtr & getFrameBeginOffsetNode()
{
return children[frame_begin_offset_child_index];
}
/// Returns true if window node has FRAME end offset, false otherwise
bool hasFrameEndOffset() const
{
return getFrameEndOffsetNode() != nullptr;
}
/// Get FRAME end offset node
const QueryTreeNodePtr & getFrameEndOffsetNode() const
{
return children[frame_end_offset_child_index];
}
/// Get FRAME end offset node
QueryTreeNodePtr & getFrameEndOffsetNode()
{
return children[frame_end_offset_child_index];
}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::WINDOW;
}
String getName() const override;
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
protected:
bool isEqualImpl(const IQueryTreeNode & rhs) const override;
void updateTreeHashImpl(HashState & hash_state) const override;
QueryTreeNodePtr cloneImpl() const override;
ASTPtr toASTImpl() const override;
private:
static constexpr size_t order_by_child_index = 0;
static constexpr size_t partition_by_child_index = 1;
static constexpr size_t frame_begin_offset_child_index = 3;
static constexpr size_t frame_end_offset_child_index = 4;
static constexpr size_t children_size = frame_end_offset_child_index + 1;
WindowFrame window_frame;
String parent_window_name;
};
}

View File

@ -0,0 +1,3 @@
add_executable (query_analyzer query_analyzer.cpp)
target_include_directories (query_analyzer SYSTEM BEFORE PRIVATE ${SPARSEHASH_INCLUDE_DIR})
target_link_libraries (query_analyzer PRIVATE dbms)

View File

@ -0,0 +1,9 @@
#include <iostream>
int main(int argc, char ** argv)
{
(void)(argc);
(void)(argv);
return 0;
}

View File

View File

@ -0,0 +1,227 @@
#include <gtest/gtest.h>
#include <Analyzer/Identifier.h>
using namespace DB;
TEST(Identifier, IdentifierBasics)
{
{
Identifier identifier;
ASSERT_TRUE(identifier.empty());
ASSERT_TRUE(identifier.isEmpty());
ASSERT_EQ(identifier.getPartsSize(), 0);
ASSERT_FALSE(identifier.isShort());
ASSERT_FALSE(identifier.isCompound());
ASSERT_FALSE(identifier.startsWith("test"));
ASSERT_FALSE(identifier.endsWith("test"));
ASSERT_EQ(identifier.begin(), identifier.end());
ASSERT_EQ(identifier.getFullName(), "");
}
{
Identifier identifier("value");
ASSERT_FALSE(identifier.empty());
ASSERT_FALSE(identifier.isEmpty());
ASSERT_EQ(identifier.getPartsSize(), 1);
ASSERT_TRUE(identifier.isShort());
ASSERT_FALSE(identifier.isCompound());
ASSERT_EQ(identifier.front(), "value");
ASSERT_EQ(identifier.back(), "value");
ASSERT_FALSE(identifier.startsWith("test"));
ASSERT_FALSE(identifier.endsWith("test"));
ASSERT_TRUE(identifier.startsWith("value"));
ASSERT_TRUE(identifier.endsWith("value"));
ASSERT_EQ(identifier[0], "value");
ASSERT_NE(identifier.begin(), identifier.end());
ASSERT_EQ(identifier.getFullName(), "value");
}
{
Identifier identifier("value1.value2");
ASSERT_FALSE(identifier.empty());
ASSERT_FALSE(identifier.isEmpty());
ASSERT_EQ(identifier.getPartsSize(), 2);
ASSERT_FALSE(identifier.isShort());
ASSERT_TRUE(identifier.isCompound());
ASSERT_EQ(identifier.front(), "value1");
ASSERT_EQ(identifier.back(), "value2");
ASSERT_FALSE(identifier.startsWith("test"));
ASSERT_FALSE(identifier.endsWith("test"));
ASSERT_TRUE(identifier.startsWith("value1"));
ASSERT_TRUE(identifier.endsWith("value2"));
ASSERT_EQ(identifier[0], "value1");
ASSERT_EQ(identifier[1], "value2");
ASSERT_NE(identifier.begin(), identifier.end());
ASSERT_EQ(identifier.getFullName(), "value1.value2");
}
{
Identifier identifier1("value1.value2");
Identifier identifier2("value1.value2");
ASSERT_EQ(identifier1, identifier2);
}
{
Identifier identifier1("value1.value2");
Identifier identifier2("value1.value3");
ASSERT_NE(identifier1, identifier2);
}
}
TEST(Identifier, IdentifierPopParts)
{
{
Identifier identifier("value1.value2.value3");
ASSERT_EQ(identifier.getFullName(), "value1.value2.value3");
identifier.popLast();
ASSERT_EQ(identifier.getFullName(), "value1.value2");
identifier.popLast();
ASSERT_EQ(identifier.getFullName(), "value1");
identifier.popLast();
ASSERT_EQ(identifier.getFullName(), "");
ASSERT_TRUE(identifier.isEmpty());
}
{
Identifier identifier("value1.value2.value3");
ASSERT_EQ(identifier.getFullName(), "value1.value2.value3");
identifier.popFirst();
ASSERT_EQ(identifier.getFullName(), "value2.value3");
identifier.popFirst();
ASSERT_EQ(identifier.getFullName(), "value3");
identifier.popFirst();
ASSERT_EQ(identifier.getFullName(), "");
ASSERT_TRUE(identifier.isEmpty());
}
{
Identifier identifier("value1.value2.value3");
ASSERT_EQ(identifier.getFullName(), "value1.value2.value3");
identifier.popLast();
ASSERT_EQ(identifier.getFullName(), "value1.value2");
identifier.popFirst();
ASSERT_EQ(identifier.getFullName(), "value2");
identifier.popLast();
ASSERT_EQ(identifier.getFullName(), "");
ASSERT_TRUE(identifier.isEmpty());
}
}
TEST(Identifier, IdentifierViewBasics)
{
{
Identifier identifier;
IdentifierView identifier_view(identifier);
ASSERT_TRUE(identifier_view.empty());
ASSERT_TRUE(identifier_view.isEmpty());
ASSERT_EQ(identifier_view.getPartsSize(), 0);
ASSERT_FALSE(identifier_view.isShort());
ASSERT_FALSE(identifier_view.isCompound());
ASSERT_FALSE(identifier_view.startsWith("test"));
ASSERT_FALSE(identifier_view.endsWith("test"));
ASSERT_EQ(identifier_view.begin(), identifier_view.end());
ASSERT_EQ(identifier_view.getFullName(), "");
}
{
Identifier identifier("value");
IdentifierView identifier_view(identifier);
ASSERT_FALSE(identifier_view.empty());
ASSERT_FALSE(identifier_view.isEmpty());
ASSERT_EQ(identifier_view.getPartsSize(), 1);
ASSERT_TRUE(identifier_view.isShort());
ASSERT_FALSE(identifier_view.isCompound());
ASSERT_EQ(identifier_view.front(), "value");
ASSERT_EQ(identifier_view.back(), "value");
ASSERT_FALSE(identifier_view.startsWith("test"));
ASSERT_FALSE(identifier_view.endsWith("test"));
ASSERT_TRUE(identifier_view.startsWith("value"));
ASSERT_TRUE(identifier_view.endsWith("value"));
ASSERT_EQ(identifier_view[0], "value");
ASSERT_NE(identifier_view.begin(), identifier_view.end());
ASSERT_EQ(identifier_view.getFullName(), "value");
}
{
Identifier identifier("value1.value2");
IdentifierView identifier_view(identifier);
ASSERT_FALSE(identifier_view.empty());
ASSERT_FALSE(identifier_view.isEmpty());
ASSERT_EQ(identifier_view.getPartsSize(), 2);
ASSERT_FALSE(identifier_view.isShort());
ASSERT_TRUE(identifier_view.isCompound());
ASSERT_FALSE(identifier_view.startsWith("test"));
ASSERT_FALSE(identifier_view.endsWith("test"));
ASSERT_TRUE(identifier_view.startsWith("value1"));
ASSERT_TRUE(identifier_view.endsWith("value2"));
ASSERT_EQ(identifier_view[0], "value1");
ASSERT_EQ(identifier_view[1], "value2");
ASSERT_NE(identifier_view.begin(), identifier_view.end());
ASSERT_EQ(identifier_view.getFullName(), "value1.value2");
}
{
Identifier identifier1("value1.value2");
IdentifierView identifier_view1(identifier1);
Identifier identifier2("value1.value2");
IdentifierView identifier_view2(identifier2);
ASSERT_EQ(identifier_view1, identifier_view2);
}
{
Identifier identifier1("value1.value2");
IdentifierView identifier_view1(identifier1);
Identifier identifier2("value1.value3");
IdentifierView identifier_view2(identifier2);
ASSERT_NE(identifier_view1, identifier_view2);
}
}
TEST(Identifier, IdentifierViewPopParts)
{
{
Identifier identifier("value1.value2.value3");
IdentifierView identifier_view(identifier);
ASSERT_EQ(identifier_view.getFullName(), "value1.value2.value3");
identifier_view.popLast();
ASSERT_EQ(identifier_view.getFullName(), "value1.value2");
identifier_view.popLast();
ASSERT_EQ(identifier_view.getFullName(), "value1");
identifier_view.popLast();
ASSERT_EQ(identifier_view.getFullName(), "");
ASSERT_TRUE(identifier_view.isEmpty());
}
{
Identifier identifier("value1.value2.value3");
IdentifierView identifier_view(identifier);
ASSERT_EQ(identifier_view.getFullName(), "value1.value2.value3");
identifier_view.popFirst();
ASSERT_EQ(identifier_view.getFullName(), "value2.value3");
identifier_view.popFirst();
ASSERT_EQ(identifier_view.getFullName(), "value3");
identifier_view.popFirst();
ASSERT_EQ(identifier_view.getFullName(), "");
ASSERT_TRUE(identifier_view.isEmpty());
}
{
Identifier identifier("value1.value2.value3");
IdentifierView identifier_view(identifier);
ASSERT_EQ(identifier_view.getFullName(), "value1.value2.value3");
identifier_view.popLast();
ASSERT_EQ(identifier_view.getFullName(), "value1.value2");
identifier_view.popFirst();
ASSERT_EQ(identifier_view.getFullName(), "value2");
identifier_view.popLast();
ASSERT_EQ(identifier_view.getFullName(), "");
ASSERT_TRUE(identifier_view.isEmpty());
}
}

View File

@ -0,0 +1,86 @@
#include <gtest/gtest.h>
#include <DataTypes/DataTypesNumber.h>
#include <Analyzer/Identifier.h>
#include <Analyzer/ColumnNode.h>
#include <Analyzer/ListNode.h>
using namespace DB;
class SourceNode final : public IQueryTreeNode
{
public:
SourceNode() : IQueryTreeNode(0 /*children_size*/) {}
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::TABLE;
}
void dumpTreeImpl(WriteBuffer &, FormatState &, size_t) const override
{
}
bool isEqualImpl(const IQueryTreeNode &) const override
{
return true;
}
void updateTreeHashImpl(HashState &) const override
{
}
QueryTreeNodePtr cloneImpl() const override
{
return std::make_shared<SourceNode>();
}
ASTPtr toASTImpl() const override
{
return nullptr;
}
};
TEST(QueryTreeNode, Clone)
{
{
auto source_node = std::make_shared<SourceNode>();
NameAndTypePair column_name_and_type("value", std::make_shared<DataTypeUInt64>());
auto column_node = std::make_shared<ColumnNode>(column_name_and_type, source_node);
ASSERT_EQ(column_node->getColumnSource().get(), source_node.get());
auto cloned_column_node = column_node->clone();
/// If in subtree source was not cloned, source pointer must remain same
ASSERT_NE(column_node.get(), cloned_column_node.get());
ASSERT_EQ(cloned_column_node->as<ColumnNode &>().getColumnSource().get(), source_node.get());
}
{
auto root_node = std::make_shared<ListNode>();
auto source_node = std::make_shared<SourceNode>();
NameAndTypePair column_name_and_type("value", std::make_shared<DataTypeUInt64>());
auto column_node = std::make_shared<ColumnNode>(column_name_and_type, source_node);
root_node->getNodes().push_back(source_node);
root_node->getNodes().push_back(column_node);
ASSERT_EQ(column_node->getColumnSource().get(), source_node.get());
auto cloned_root_node = std::static_pointer_cast<ListNode>(root_node->clone());
auto cloned_source_node = cloned_root_node->getNodes()[0];
auto cloned_column_node = std::static_pointer_cast<ColumnNode>(cloned_root_node->getNodes()[1]);
/** If in subtree source was cloned.
* Source pointer for node that was cloned must remain same.
* Source pointer for cloned node must be updated.
*/
ASSERT_NE(column_node.get(), cloned_column_node.get());
ASSERT_NE(source_node.get(), cloned_source_node.get());
ASSERT_EQ(column_node->getColumnSource().get(), source_node.get());
ASSERT_EQ(cloned_column_node->getColumnSource().get(), cloned_source_node.get());
}
}

View File

@ -66,6 +66,8 @@ add_subdirectory (Storages)
add_subdirectory (Parsers)
add_subdirectory (IO)
add_subdirectory (Functions)
add_subdirectory (Analyzer)
add_subdirectory (Planner)
add_subdirectory (Interpreters)
add_subdirectory (AggregateFunctions)
add_subdirectory (Client)
@ -254,6 +256,9 @@ add_object_library(clickhouse_datatypes_serializations DataTypes/Serializations)
add_object_library(clickhouse_databases Databases)
add_object_library(clickhouse_databases_mysql Databases/MySQL)
add_object_library(clickhouse_disks Disks)
add_object_library(clickhouse_analyzer Analyzer)
add_object_library(clickhouse_analyzer_passes Analyzer/Passes)
add_object_library(clickhouse_planner Planner)
add_object_library(clickhouse_interpreters Interpreters)
add_object_library(clickhouse_interpreters_cache Interpreters/Cache)
add_object_library(clickhouse_interpreters_access Interpreters/Access)

View File

@ -302,6 +302,7 @@ static constexpr UInt64 operator""_GiB(unsigned long long value)
M(Float, opentelemetry_start_trace_probability, 0., "Probability to start an OpenTelemetry trace for an incoming query.", 0) \
M(Bool, opentelemetry_trace_processors, false, "Collect OpenTelemetry spans for processors.", 0) \
M(Bool, prefer_column_name_to_alias, false, "Prefer using column names instead of aliases if possible.", 0) \
M(Bool, use_analyzer, false, "Use analyzer", 0) \
M(Bool, prefer_global_in_and_join, false, "If enabled, all IN/JOIN operators will be rewritten as GLOBAL IN/JOIN. It's useful when the to-be-joined tables are only available on the initiator and we need to always scatter their data on-the-fly during distributed processing with the GLOBAL keyword. It's also useful to reduce the need to access the external sources joining external tables.", 0) \
\
\

View File

@ -185,8 +185,10 @@ public:
const auto * data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
if (!data_type_function)
throw Exception("First argument for function " + getName() + " must be a function",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"First argument for function {} must be a function. Actual {}",
getName(),
arguments[0].type->getName());
/// The types of the remaining arguments are already checked in getLambdaArgumentTypes.

View File

@ -13,6 +13,11 @@
namespace DB
{
namespace ErrorCodes
{
extern const int UNSUPPORTED_METHOD;
}
class FunctionGroupingBase : public IFunction
{
protected:
@ -71,6 +76,22 @@ public:
}
};
class FunctionGrouping : public FunctionGroupingBase
{
public:
explicit FunctionGrouping(bool force_compatibility_)
: FunctionGroupingBase(ColumnNumbers(), force_compatibility_)
{}
String getName() const override { return "grouping"; }
ColumnPtr executeImpl(const ColumnsWithTypeAndName &, const DataTypePtr &, size_t) const override
{
throw Exception(ErrorCodes::UNSUPPORTED_METHOD,
"Method executeImpl is not supported for 'grouping' function");
}
};
class FunctionGroupingOrdinary : public FunctionGroupingBase
{
public:

View File

@ -82,7 +82,10 @@ public:
const DataTypeTuple * tuple = checkAndGetDataType<DataTypeTuple>(tuple_col);
if (!tuple)
throw Exception("First argument for function " + getName() + " must be tuple or array of tuple.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"First argument for function {} must be tuple or array of tuple. Actual {}",
getName(),
arguments[0].type->getName());
auto index = getElementNum(arguments[1].column, *tuple, number_of_arguments);
if (index.has_value())
@ -137,7 +140,10 @@ public:
const DataTypeTuple * tuple_type_concrete = checkAndGetDataType<DataTypeTuple>(tuple_type);
const ColumnTuple * tuple_col_concrete = checkAndGetColumn<ColumnTuple>(tuple_col);
if (!tuple_type_concrete || !tuple_col_concrete)
throw Exception("First argument for function " + getName() + " must be tuple or array of tuple.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"First argument for function {} must be tuple or array of tuple. Actual {}",
getName(),
first_arg.type->getName());
auto index = getElementNum(arguments[1].column, *tuple_type_concrete, arguments.size());
@ -221,20 +227,18 @@ private:
std::optional<size_t> getElementNum(const ColumnPtr & index_column, const DataTypeTuple & tuple, const size_t argument_size) const
{
if (
checkAndGetColumnConst<ColumnUInt8>(index_column.get())
|| checkAndGetColumnConst<ColumnUInt16>(index_column.get())
|| checkAndGetColumnConst<ColumnUInt32>(index_column.get())
|| checkAndGetColumnConst<ColumnUInt64>(index_column.get())
)
if (checkAndGetColumnConst<ColumnUInt8>(index_column.get())
|| checkAndGetColumnConst<ColumnUInt16>(index_column.get())
|| checkAndGetColumnConst<ColumnUInt32>(index_column.get())
|| checkAndGetColumnConst<ColumnUInt64>(index_column.get()))
{
size_t index = index_column->getUInt(0);
if (index == 0)
throw Exception("Indices in tuples are 1-based.", ErrorCodes::ILLEGAL_INDEX);
throw Exception(ErrorCodes::ILLEGAL_INDEX, "Indices in tuples are 1-based.");
if (index > tuple.getElements().size())
throw Exception("Index for tuple element is out of range.", ErrorCodes::ILLEGAL_INDEX);
throw Exception(ErrorCodes::ILLEGAL_INDEX, "Index for tuple element is out of range.");
return std::optional<size_t>(index - 1);
}
@ -253,7 +257,9 @@ private:
return std::nullopt;
}
else
throw Exception("Second argument to " + getName() + " must be a constant UInt or String", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Second argument to {} must be a constant UInt or String",
getName());
}
};

View File

@ -93,6 +93,16 @@ DirectKeyValueJoin::DirectKeyValueJoin(std::shared_ptr<TableJoin> table_join_,
LOG_TRACE(log, "Using direct join");
}
DirectKeyValueJoin::DirectKeyValueJoin(
std::shared_ptr<TableJoin> table_join_,
const Block & right_sample_block_,
std::shared_ptr<const IKeyValueEntity> storage_,
const Block & right_sample_block_with_storage_column_names_)
: DirectKeyValueJoin(table_join_, right_sample_block_, storage_)
{
right_sample_block_with_storage_column_names = right_sample_block_with_storage_column_names_;
}
bool DirectKeyValueJoin::addJoinedBlock(const Block &, bool)
{
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Unreachable code reached");
@ -114,14 +124,15 @@ void DirectKeyValueJoin::joinBlock(Block & block, std::shared_ptr<ExtraBlock> &)
return;
Block original_right_block = originalRightBlock(right_sample_block, *table_join);
const Names & attribute_names = original_right_block.getNames();
Block right_block_to_use = right_sample_block_with_storage_column_names ? right_sample_block_with_storage_column_names : original_right_block;
const Names & attribute_names = right_block_to_use.getNames();
NullMap null_map;
Chunk joined_chunk = storage->getByKeys({key_col}, null_map, attribute_names);
/// Expected right block may differ from structure in storage, because of `join_use_nulls` or we just select not all joined attributes
Block sample_storage_block = storage->getSampleBlock(attribute_names);
MutableColumns result_columns = convertBlockStructure(sample_storage_block, original_right_block, joined_chunk.mutateColumns(), null_map);
MutableColumns result_columns = convertBlockStructure(sample_storage_block, right_block_to_use, joined_chunk.mutateColumns(), null_map);
for (size_t i = 0; i < result_columns.size(); ++i)
{

View File

@ -25,6 +25,12 @@ public:
const Block & right_sample_block_,
std::shared_ptr<const IKeyValueEntity> storage_);
DirectKeyValueJoin(
std::shared_ptr<TableJoin> table_join_,
const Block & right_sample_block_,
std::shared_ptr<const IKeyValueEntity> storage_,
const Block & right_sample_block_with_storage_column_names_);
virtual const TableJoin & getTableJoin() const override { return *table_join; }
virtual bool addJoinedBlock(const Block &, bool) override;
@ -52,6 +58,7 @@ private:
std::shared_ptr<TableJoin> table_join;
std::shared_ptr<const IKeyValueEntity> storage;
Block right_sample_block;
Block right_sample_block_with_storage_column_names;
Block sample_block_with_columns_to_add;
Poco::Logger * log;

View File

@ -1073,8 +1073,8 @@ void ExpressionActionsChain::JoinStep::finalize(const NameSet & required_output_
}
/// Result will also contain joined columns.
for (const auto & column_name : analyzed_join->columnsAddedByJoin())
required_names.emplace(column_name);
for (const auto & column : analyzed_join->columnsAddedByJoin())
required_names.emplace(column.name);
for (const auto & column : result_columns)
{

View File

@ -7,6 +7,7 @@
#include <TableFunctions/ITableFunction.h>
#include <TableFunctions/TableFunctionFactory.h>
#include <Interpreters/InterpreterSelectWithUnionQuery.h>
#include <Interpreters/InterpreterSelectQueryAnalyzer.h>
#include <Interpreters/Context.h>
#include <Interpreters/InterpreterDescribeQuery.h>
#include <Interpreters/IdentifierSemantic.h>
@ -17,7 +18,6 @@
#include <Parsers/TablePropertiesQueriesASTs.h>
#include <DataTypes/NestedUtils.h>
namespace DB
{
@ -60,10 +60,9 @@ Block InterpreterDescribeQuery::getSampleBlock(bool include_subcolumns)
return block;
}
BlockIO InterpreterDescribeQuery::execute()
{
ColumnsDescription columns;
std::vector<ColumnDescription> columns;
StorageSnapshotPtr storage_snapshot;
const auto & ast = query_ptr->as<ASTDescribeQuery &>();
@ -72,14 +71,34 @@ BlockIO InterpreterDescribeQuery::execute()
if (table_expression.subquery)
{
auto names_and_types = InterpreterSelectWithUnionQuery::getSampleBlock(
table_expression.subquery->children.at(0), getContext()).getNamesAndTypesList();
columns = ColumnsDescription(std::move(names_and_types));
NamesAndTypesList names_and_types;
auto select_query = table_expression.subquery->children.at(0);
auto current_context = getContext();
if (settings.use_analyzer)
{
SelectQueryOptions select_query_options;
names_and_types = InterpreterSelectQueryAnalyzer(select_query, select_query_options, current_context).getSampleBlock().getNamesAndTypesList();
}
else
{
names_and_types = InterpreterSelectWithUnionQuery::getSampleBlock(select_query, current_context).getNamesAndTypesList();
}
for (auto && [name, type] : names_and_types)
{
ColumnDescription description;
description.name = std::move(name);
description.type = std::move(type);
columns.emplace_back(std::move(description));
}
}
else if (table_expression.table_function)
{
TableFunctionPtr table_function_ptr = TableFunctionFactory::instance().get(table_expression.table_function, getContext());
columns = table_function_ptr->getActualTableStructure(getContext());
auto table_function_column_descriptions = table_function_ptr->getActualTableStructure(getContext());
for (const auto & table_function_column_description : table_function_column_descriptions)
columns.emplace_back(table_function_column_description);
}
else
{
@ -90,7 +109,9 @@ BlockIO InterpreterDescribeQuery::execute()
auto metadata_snapshot = table->getInMemoryMetadataPtr();
storage_snapshot = table->getStorageSnapshot(metadata_snapshot, getContext());
columns = metadata_snapshot->getColumns();
auto metadata_column_descriptions = metadata_snapshot->getColumns();
for (const auto & metadata_column_description : metadata_column_descriptions)
columns.emplace_back(metadata_column_description);
}
bool extend_object_types = settings.describe_extend_object_types && storage_snapshot;

View File

@ -7,6 +7,7 @@
#include <Interpreters/InDepthNodeVisitor.h>
#include <Interpreters/InterpreterSelectWithUnionQuery.h>
#include <Interpreters/InterpreterSelectQuery.h>
#include <Interpreters/InterpreterSelectQueryAnalyzer.h>
#include <Interpreters/InterpreterInsertQuery.h>
#include <Interpreters/Context.h>
#include <Interpreters/TableOverrideUtils.h>
@ -28,6 +29,9 @@
#include <Common/JSONBuilder.h>
#include <Analyzer/QueryTreeBuilder.h>
#include <Analyzer/QueryTreePassManager.h>
namespace DB
{
@ -155,6 +159,30 @@ struct QueryASTSettings
{"graph", graph},
{"optimize", optimize}
};
std::unordered_map<std::string, std::reference_wrapper<Int64>> integer_settings;
};
struct QueryTreeSettings
{
bool run_passes = false;
bool dump_passes = false;
bool dump_ast = false;
Int64 passes = -1;
constexpr static char name[] = "QUERY TREE";
std::unordered_map<std::string, std::reference_wrapper<bool>> boolean_settings =
{
{"run_passes", run_passes},
{"dump_passes", dump_passes},
{"dump_ast", dump_ast}
};
std::unordered_map<std::string, std::reference_wrapper<Int64>> integer_settings =
{
{"passes", passes}
};
};
struct QueryPlanSettings
@ -177,6 +205,8 @@ struct QueryPlanSettings
{"json", json},
{"sorting", query_plan_options.sorting},
};
std::unordered_map<std::string, std::reference_wrapper<Int64>> integer_settings;
};
struct QueryPipelineSettings
@ -193,18 +223,31 @@ struct QueryPipelineSettings
{"graph", graph},
{"compact", compact},
};
std::unordered_map<std::string, std::reference_wrapper<Int64>> integer_settings;
};
template <typename Settings>
struct ExplainSettings : public Settings
{
using Settings::boolean_settings;
using Settings::integer_settings;
bool has(const std::string & name_) const
{
return hasBooleanSetting(name_) || hasIntegerSetting(name_);
}
bool hasBooleanSetting(const std::string & name_) const
{
return boolean_settings.count(name_) > 0;
}
bool hasIntegerSetting(const std::string & name_) const
{
return integer_settings.count(name_) > 0;
}
void setBooleanSetting(const std::string & name_, bool value)
{
auto it = boolean_settings.find(name_);
@ -214,6 +257,15 @@ struct ExplainSettings : public Settings
it->second.get() = value;
}
void setIntegerSetting(const std::string & name_, Int64 value)
{
auto it = integer_settings.find(name_);
if (it == integer_settings.end())
throw Exception("Unknown setting for ExplainSettings: " + name_, ErrorCodes::LOGICAL_ERROR);
it->second.get() = value;
}
std::string getSettingsList() const
{
std::string res;
@ -224,6 +276,13 @@ struct ExplainSettings : public Settings
res += setting.first;
}
for (const auto & setting : integer_settings)
{
if (!res.empty())
res += ", ";
res += setting.first;
}
return res;
}
@ -246,15 +305,23 @@ ExplainSettings<Settings> checkAndGetSettings(const ASTPtr & ast_settings)
if (change.value.getType() != Field::Types::UInt64)
throw Exception(ErrorCodes::INVALID_SETTING_VALUE,
"Invalid type {} for setting \"{}\" only boolean settings are supported",
"Invalid type {} for setting \"{}\" only integer settings are supported",
change.value.getTypeName(), change.name);
auto value = change.value.get<UInt64>();
if (value > 1)
throw Exception("Invalid value " + std::to_string(value) + " for setting \"" + change.name +
"\". Only boolean settings are supported", ErrorCodes::INVALID_SETTING_VALUE);
if (settings.hasBooleanSetting(change.name))
{
auto value = change.value.get<UInt64>();
if (value > 1)
throw Exception("Invalid value " + std::to_string(value) + " for setting \"" + change.name +
"\". Expected boolean type", ErrorCodes::INVALID_SETTING_VALUE);
settings.setBooleanSetting(change.name, value);
settings.setBooleanSetting(change.name, value);
}
else
{
auto value = change.value.get<UInt64>();
settings.setIntegerSetting(change.name, value);
}
}
return settings;
@ -304,6 +371,46 @@ QueryPipeline InterpreterExplainQuery::executeImpl()
ast.getExplainedQuery()->format(IAST::FormatSettings(buf, false));
break;
}
case ASTExplainQuery::QueryTree:
{
if (ast.getExplainedQuery()->as<ASTSelectWithUnionQuery>() == nullptr)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Only SELECT is supported for EXPLAIN QUERY TREE query");
auto settings = checkAndGetSettings<QueryTreeSettings>(ast.getSettings());
auto query_tree = buildQueryTree(ast.getExplainedQuery(), getContext());
if (settings.run_passes)
{
auto query_tree_pass_manager = QueryTreePassManager(getContext());
addQueryTreePasses(query_tree_pass_manager);
size_t pass_index = settings.passes < 0 ? query_tree_pass_manager.getPasses().size() : static_cast<size_t>(settings.passes);
if (settings.dump_passes)
{
query_tree_pass_manager.dump(buf, pass_index);
if (pass_index > 0)
buf << '\n';
}
query_tree_pass_manager.run(query_tree, pass_index);
query_tree->dumpTree(buf);
}
else
{
query_tree->dumpTree(buf);
}
if (settings.dump_ast)
{
buf << '\n';
buf << '\n';
query_tree->toAST()->format(IAST::FormatSettings(buf, false));
}
break;
}
case ASTExplainQuery::QueryPlan:
{
if (!dynamic_cast<const ASTSelectWithUnionQuery *>(ast.getExplainedQuery().get()))
@ -312,8 +419,16 @@ QueryPipeline InterpreterExplainQuery::executeImpl()
auto settings = checkAndGetSettings<QueryPlanSettings>(ast.getSettings());
QueryPlan plan;
InterpreterSelectWithUnionQuery interpreter(ast.getExplainedQuery(), getContext(), options);
interpreter.buildQueryPlan(plan);
if (getContext()->getSettingsRef().use_analyzer)
{
InterpreterSelectQueryAnalyzer interpreter(ast.getExplainedQuery(), options, getContext());
plan = std::move(interpreter).extractQueryPlan();
}
else
{
InterpreterSelectWithUnionQuery interpreter(ast.getExplainedQuery(), getContext(), options);
interpreter.buildQueryPlan(plan);
}
if (settings.optimize)
plan.optimize(QueryPlanOptimizationSettings::fromContext(getContext()));
@ -347,8 +462,17 @@ QueryPipeline InterpreterExplainQuery::executeImpl()
auto settings = checkAndGetSettings<QueryPipelineSettings>(ast.getSettings());
QueryPlan plan;
InterpreterSelectWithUnionQuery interpreter(ast.getExplainedQuery(), getContext(), options);
interpreter.buildQueryPlan(plan);
if (getContext()->getSettingsRef().use_analyzer)
{
InterpreterSelectQueryAnalyzer interpreter(ast.getExplainedQuery(), options, getContext());
plan = std::move(interpreter).extractQueryPlan();
}
else
{
InterpreterSelectWithUnionQuery interpreter(ast.getExplainedQuery(), getContext(), options);
interpreter.buildQueryPlan(plan);
}
auto pipeline = plan.buildQueryPipeline(
QueryPlanOptimizationSettings::fromContext(getContext()),
BuildQueryPipelineSettings::fromContext(getContext()));

View File

@ -63,6 +63,7 @@
#include <Interpreters/InterpreterOptimizeQuery.h>
#include <Interpreters/InterpreterRenameQuery.h>
#include <Interpreters/InterpreterSelectQuery.h>
#include <Interpreters/InterpreterSelectQueryAnalyzer.h>
#include <Interpreters/InterpreterSelectWithUnionQuery.h>
#include <Interpreters/InterpreterSetQuery.h>
#include <Interpreters/InterpreterShowCreateQuery.h>
@ -118,6 +119,9 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, ContextMut
if (query->as<ASTSelectQuery>())
{
if (context->getSettingsRef().use_analyzer)
return std::make_unique<InterpreterSelectQueryAnalyzer>(query, options, context);
/// This is internal part of ASTSelectWithUnionQuery.
/// Even if there is SELECT without union, it is represented by ASTSelectWithUnionQuery with single ASTSelectQuery as a child.
return std::make_unique<InterpreterSelectQuery>(query, context, options);
@ -125,6 +129,10 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, ContextMut
else if (query->as<ASTSelectWithUnionQuery>())
{
ProfileEvents::increment(ProfileEvents::SelectQuery);
if (context->getSettingsRef().use_analyzer)
return std::make_unique<InterpreterSelectQueryAnalyzer>(query, options, context);
return std::make_unique<InterpreterSelectWithUnionQuery>(query, context, options);
}
else if (query->as<ASTSelectIntersectExceptQuery>())

View File

@ -0,0 +1,120 @@
#include <Interpreters/InterpreterSelectQueryAnalyzer.h>
#include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTSubquery.h>
#include <Analyzer/QueryTreeBuilder.h>
#include <Analyzer/QueryTreePassManager.h>
#include <Processors/QueryPlan/IQueryPlanStep.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <Interpreters/Context.h>
#include <Interpreters/QueryLog.h>
namespace DB
{
namespace ErrorCodes
{
extern const int UNSUPPORTED_METHOD;
}
namespace
{
ASTPtr normalizeAndValidateQuery(const ASTPtr & query)
{
if (query->as<ASTSelectWithUnionQuery>() || query->as<ASTSelectQuery>())
{
return query;
}
else if (auto * subquery = query->as<ASTSubquery>())
{
return subquery->children[0];
}
else
{
throw Exception(ErrorCodes::UNSUPPORTED_METHOD,
"Expected ASTSelectWithUnionQuery or ASTSelectQuery. Actual {}",
query->formatForErrorMessage());
}
}
QueryTreeNodePtr buildQueryTreeAndRunPasses(const ASTPtr & query, const ContextPtr & context)
{
auto query_tree = buildQueryTree(query, context);
QueryTreePassManager query_tree_pass_manager(context);
addQueryTreePasses(query_tree_pass_manager);
query_tree_pass_manager.run(query_tree);
return query_tree;
}
}
InterpreterSelectQueryAnalyzer::InterpreterSelectQueryAnalyzer(
const ASTPtr & query_,
const SelectQueryOptions & select_query_options_,
ContextPtr context_)
: WithContext(context_)
, query(normalizeAndValidateQuery(query_))
, query_tree(buildQueryTreeAndRunPasses(query, context_))
, select_query_options(select_query_options_)
, planner(query_tree, select_query_options, context_)
{
}
InterpreterSelectQueryAnalyzer::InterpreterSelectQueryAnalyzer(
const QueryTreeNodePtr & query_tree_,
const SelectQueryOptions & select_query_options_,
ContextPtr context_)
: WithContext(context_)
, query(query_tree_->toAST())
, query_tree(query_tree_)
, select_query_options(select_query_options_)
, planner(query_tree, select_query_options, context_)
{
}
Block InterpreterSelectQueryAnalyzer::getSampleBlock()
{
planner.buildQueryPlanIfNeeded();
return planner.getQueryPlan().getCurrentDataStream().header;
}
BlockIO InterpreterSelectQueryAnalyzer::execute()
{
planner.buildQueryPlanIfNeeded();
auto & query_plan = planner.getQueryPlan();
QueryPlanOptimizationSettings optimization_settings;
BuildQueryPipelineSettings build_pipeline_settings;
auto pipeline_builder = query_plan.buildQueryPipeline(optimization_settings, build_pipeline_settings);
BlockIO result;
result.pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder));
if (!select_query_options.ignore_quota && (select_query_options.to_stage == QueryProcessingStage::Complete))
result.pipeline.setQuota(getContext()->getQuota());
return result;
}
QueryPlan && InterpreterSelectQueryAnalyzer::extractQueryPlan() &&
{
planner.buildQueryPlanIfNeeded();
return std::move(planner).extractQueryPlan();
}
void InterpreterSelectQueryAnalyzer::extendQueryLogElemImpl(QueryLogElement & elem, const ASTPtr &, ContextPtr) const
{
elem.query_kind = "Select";
}
}

View File

@ -0,0 +1,49 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Interpreters/SelectQueryOptions.h>
#include <Analyzer/QueryTreePassManager.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Interpreters/Context_fwd.h>
#include <Planner/Planner.h>
namespace DB
{
class InterpreterSelectQueryAnalyzer : public IInterpreter, public WithContext
{
public:
/// Initialize interpreter with query AST
InterpreterSelectQueryAnalyzer(const ASTPtr & query_,
const SelectQueryOptions & select_query_options_,
ContextPtr context_);
/// Initialize interpreter with query tree
InterpreterSelectQueryAnalyzer(const QueryTreeNodePtr & query_tree_,
const SelectQueryOptions & select_query_options_,
ContextPtr context_);
Block getSampleBlock();
BlockIO execute() override;
QueryPlan && extractQueryPlan() &&;
bool supportsTransactions() const override { return true; }
bool ignoreLimits() const override { return select_query_options.ignore_limits; }
bool ignoreQuota() const override { return select_query_options.ignore_quota; }
void extendQueryLogElemImpl(QueryLogElement & elem, const ASTPtr &, ContextPtr) const override;
private:
ASTPtr query;
QueryTreeNodePtr query_tree;
SelectQueryOptions select_query_options;
Planner planner;
};
}

View File

@ -45,7 +45,7 @@ void NormalizeSelectWithUnionQueryMatcher::visit(ASTSelectWithUnionQuery & ast,
SelectUnionModesSet current_set_of_modes;
bool distinct_found = false;
for (ssize_t i = union_modes.size() - 1; i >= 0; --i)
for (Int64 i = union_modes.size() - 1; i >= 0; --i)
{
current_set_of_modes.insert(union_modes[i]);
if (const auto * union_ast = typeid_cast<const ASTSelectWithUnionQuery *>(select_list[i + 1].get()))

Some files were not shown because too many files have changed in this diff Show More