diff --git a/src/Analyzer/FunctionNode.cpp b/src/Analyzer/FunctionNode.cpp index 54e68acab55..06f0f248847 100644 --- a/src/Analyzer/FunctionNode.cpp +++ b/src/Analyzer/FunctionNode.cpp @@ -14,6 +14,14 @@ namespace DB { +FunctionNode::FunctionNode(String function_name_) + : function_name(function_name_) +{ + children.resize(2); + children[parameters_child_index] = std::make_shared(); + children[arguments_child_index] = std::make_shared(); +} + void FunctionNode::resolveAsFunction(FunctionOverloadResolverPtr function_value, DataTypePtr result_type_value) { aggregate_function = nullptr; diff --git a/src/Analyzer/FunctionNode.h b/src/Analyzer/FunctionNode.h index 2020afe89df..7ffb2af5c5d 100644 --- a/src/Analyzer/FunctionNode.h +++ b/src/Analyzer/FunctionNode.h @@ -36,13 +36,7 @@ public: /** Initialize function node with function name. * Later during query analysis path function must be resolved. */ - explicit FunctionNode(String function_name_) - : function_name(function_name_) - { - children.resize(2); - children[parameters_child_index] = std::make_shared(); - children[arguments_child_index] = std::make_shared(); - } + explicit FunctionNode(String function_name_); /// Get name const String & getFunctionName() const diff --git a/src/Analyzer/IQueryTreeNode.cpp b/src/Analyzer/IQueryTreeNode.cpp index a69afa2eb4f..6643a254bab 100644 --- a/src/Analyzer/IQueryTreeNode.cpp +++ b/src/Analyzer/IQueryTreeNode.cpp @@ -30,6 +30,7 @@ const char * toString(QueryTreeNodeType type) case QueryTreeNodeType::FUNCTION: return "FUNCTION"; case QueryTreeNodeType::COLUMN: return "COLUMN"; case QueryTreeNodeType::LAMBDA: return "LAMBDA"; + case QueryTreeNodeType::SORT_COLUMN: return "SORT_COLUMN"; case QueryTreeNodeType::TABLE: return "TABLE"; case QueryTreeNodeType::TABLE_FUNCTION: return "TABLE_FUNCTION"; case QueryTreeNodeType::QUERY: return "QUERY"; diff --git a/src/Analyzer/IQueryTreeNode.h b/src/Analyzer/IQueryTreeNode.h index 2343716e82b..38be6af5300 100644 --- a/src/Analyzer/IQueryTreeNode.h +++ b/src/Analyzer/IQueryTreeNode.h @@ -35,6 +35,7 @@ enum class QueryTreeNodeType FUNCTION, COLUMN, LAMBDA, + SORT_COLUMN, TABLE, TABLE_FUNCTION, QUERY, diff --git a/src/Analyzer/QueryAnalysisPass.cpp b/src/Analyzer/QueryAnalysisPass.cpp index 93acfc98907..2b4f019f0d4 100644 --- a/src/Analyzer/QueryAnalysisPass.cpp +++ b/src/Analyzer/QueryAnalysisPass.cpp @@ -40,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -88,6 +89,7 @@ namespace ErrorCodes extern const int INCORRECT_ELEMENT_OF_SET; extern const int TYPE_MISMATCH; extern const int AMBIGUOUS_IDENTIFIER; + extern const int INVALID_WITH_FILL_EXPRESSION; } /** Query analyzer implementation overview. Please check documentation in QueryAnalysisPass.h before. @@ -812,6 +814,8 @@ private: void resolveExpressionNodeList(QueryTreeNodePtr & node_list, IdentifierResolveScope & scope, bool allow_lambda_expression, bool allow_table_expression); + void resolveSortColumnsNodeList(QueryTreeNodePtr & sort_columns_node_list, IdentifierResolveScope & scope); + void initializeQueryJoinTreeNode(QueryTreeNodePtr & join_tree_node, IdentifierResolveScope & scope); void initializeTableExpressionColumns(QueryTreeNodePtr & table_expression_node, IdentifierResolveScope & scope); @@ -2846,6 +2850,14 @@ void QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, IdentifierRes /// Lambda must be resolved by caller break; } + + case QueryTreeNodeType::SORT_COLUMN: + { + throw Exception(ErrorCodes::LOGICAL_ERROR, + "Sort column {} is not allowed in expression. In scope {}", + node->formatASTForErrorMessage(), + scope.scope_node->formatASTForErrorMessage()); + } case QueryTreeNodeType::TABLE: { if (!allow_table_expression) @@ -2887,14 +2899,14 @@ void QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, IdentifierRes case QueryTreeNodeType::ARRAY_JOIN: { throw Exception(ErrorCodes::LOGICAL_ERROR, - "Array join is not allowed {} in expression. In scope {}", + "Array join {} is not allowed in expression. In scope {}", node->formatASTForErrorMessage(), scope.scope_node->formatASTForErrorMessage()); } case QueryTreeNodeType::JOIN: { throw Exception(ErrorCodes::LOGICAL_ERROR, - "Join is not allowed {} in expression. In scope {}", + "Join {} is not allowed in expression. In scope {}", node->formatASTForErrorMessage(), scope.scope_node->formatASTForErrorMessage()); } @@ -2958,6 +2970,58 @@ void QueryAnalyzer::resolveExpressionNodeList(QueryTreeNodePtr & node_list, Iden node_list = std::move(result_node_list); } +/** Resolve sort columns nodes list. + */ +void QueryAnalyzer::resolveSortColumnsNodeList(QueryTreeNodePtr & sort_columns_node_list, IdentifierResolveScope & scope) +{ + auto & sort_columns_node_list_typed = sort_columns_node_list->as(); + for (auto & node : sort_columns_node_list_typed.getNodes()) + { + auto & sort_column_node = node->as(); + resolveExpressionNode(sort_column_node.getExpression(), scope, false /*allow_lambda_expression*/, false /*allow_table_expression*/); + + if (sort_column_node.hasFillFrom()) + { + resolveExpressionNode(sort_column_node.getFillFrom(), scope, false /*allow_lambda_expression*/, false /*allow_table_expression*/); + + const auto * constant_node = sort_column_node.getFillFrom()->as(); + if (!constant_node || !isColumnedAsNumber(constant_node->getResultType())) + throw Exception(ErrorCodes::INVALID_WITH_FILL_EXPRESSION, + "WITH FILL FROM expression must be constant with numeric type. Actual {}. In scope {}", + sort_column_node.getFillFrom()->formatASTForErrorMessage(), + scope.scope_node->formatASTForErrorMessage()); + } + if (sort_column_node.hasFillTo()) + { + resolveExpressionNode(sort_column_node.getFillTo(), scope, false /*allow_lambda_expression*/, false /*allow_table_expression*/); + const auto * constant_node = sort_column_node.getFillTo()->as(); + if (!constant_node || !isColumnedAsNumber(constant_node->getResultType())) + throw Exception(ErrorCodes::INVALID_WITH_FILL_EXPRESSION, + "WITH FILL TO expression must be constant with numeric type. Actual {}. In scope {}", + sort_column_node.getFillFrom()->formatASTForErrorMessage(), + scope.scope_node->formatASTForErrorMessage()); + } + if (sort_column_node.hasFillStep()) + { + resolveExpressionNode(sort_column_node.getFillStep(), scope, false /*allow_lambda_expression*/, false /*allow_table_expression*/); + const auto * constant_node = sort_column_node.getFillStep()->as(); + if (!constant_node) + throw Exception(ErrorCodes::INVALID_WITH_FILL_EXPRESSION, + "WITH FILL TO expression must be constant with numeric or interval type. Actual {}. In scope {}", + sort_column_node.getFillStep()->formatASTForErrorMessage(), + scope.scope_node->formatASTForErrorMessage()); + + bool is_number = isColumnedAsNumber(constant_node->getResultType()); + bool is_interval = WhichDataType(constant_node->getResultType()).isInterval(); + if (!is_number && !is_interval) + throw Exception(ErrorCodes::INVALID_WITH_FILL_EXPRESSION, + "WITH FILL TO expression must be constant with numeric or interval type. Actual {}. In scope {}", + sort_column_node.getFillStep()->formatASTForErrorMessage(), + scope.scope_node->formatASTForErrorMessage()); + } + } +} + /** Initialize query join tree node. * * 1. Resolve identifiers. @@ -3447,7 +3511,7 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier QueryExpressionsAliasVisitor::Data data{scope}; QueryExpressionsAliasVisitor visitor(data); - if (!query_node_typed.getWith().getNodes().empty()) + if (query_node_typed.hasWith()) visitor.visit(query_node_typed.getWithNode()); if (!query_node_typed.getProjection().getNodes().empty()) @@ -3459,6 +3523,12 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier if (query_node_typed.getWhere()) visitor.visit(query_node_typed.getWhere()); + if (query_node_typed.hasGroupBy()) + visitor.visit(query_node_typed.getGroupByNode()); + + if (query_node_typed.hasOrderBy()) + visitor.visit(query_node_typed.getOrderByNode()); + /// Register CTE subqueries and remove them from WITH section auto & with_nodes = query_node_typed.getWith().getNodes(); @@ -3511,7 +3581,7 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier /// Resolve query node sections. - if (!query_node_typed.getWith().getNodes().empty()) + if (query_node_typed.hasWith()) resolveExpressionNodeList(query_node_typed.getWithNode(), scope, true /*allow_lambda_expression*/, false /*allow_table_expression*/); if (query_node_typed.getPrewhere()) @@ -3520,8 +3590,11 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier if (query_node_typed.getWhere()) resolveExpressionNode(query_node_typed.getWhere(), scope, false /*allow_lambda_expression*/, false /*allow_table_expression*/); - if (!query_node_typed.getGroupBy().getNodes().empty()) - resolveExpressionNodeList(query_node_typed.getGroupByNode(), scope, true /*allow_lambda_expression*/, false /*allow_table_expression*/); + if (query_node_typed.hasGroupBy()) + resolveExpressionNodeList(query_node_typed.getGroupByNode(), scope, false /*allow_lambda_expression*/, false /*allow_table_expression*/); + + if (query_node_typed.hasOrderBy()) + resolveSortColumnsNodeList(query_node_typed.getOrderByNode(), scope); resolveExpressionNodeList(query_node_typed.getProjectionNode(), scope, false /*allow_lambda_expression*/, false /*allow_table_expression*/); diff --git a/src/Analyzer/QueryNode.cpp b/src/Analyzer/QueryNode.cpp index 35bee05b85e..e2d3e490ad7 100644 --- a/src/Analyzer/QueryNode.cpp +++ b/src/Analyzer/QueryNode.cpp @@ -27,6 +27,7 @@ QueryNode::QueryNode() children[with_child_index] = std::make_shared(); children[projection_child_index] = std::make_shared(); children[group_by_child_index] = std::make_shared(); + children[order_by_child_index] = std::make_shared(); } NamesAndTypesList QueryNode::computeProjectionColumns() const @@ -56,12 +57,13 @@ String QueryNode::getName() const { WriteBufferFromOwnString buffer; - if (!getWith().getNodes().empty()) + if (hasWith()) { buffer << getWith().getName(); + buffer << ' '; } - buffer << " SELECT "; + buffer << "SELECT "; buffer << getProjection().getName(); if (getJoinTree()) @@ -82,10 +84,8 @@ String QueryNode::getName() const buffer << getWhere()->getName(); } - if (!getGroupBy().getNodes().empty()) - { + if (hasGroupBy()) buffer << getGroupBy().getName(); - } return buffer.str(); } @@ -104,7 +104,7 @@ void QueryNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, s if (!cte_name.empty()) buffer << ", cte_name: " << cte_name; - if (!getWith().getNodes().empty()) + if (hasWith()) { buffer << '\n' << std::string(indent + 2, ' ') << "WITH\n"; getWith().dumpTreeImpl(buffer, format_state, indent + 4); @@ -132,11 +132,17 @@ void QueryNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, s getWhere()->dumpTreeImpl(buffer, format_state, indent + 4); } - if (!getGroupBy().getNodes().empty()) + if (hasGroupBy()) { buffer << '\n' << std::string(indent + 2, ' ') << "GROUP BY\n"; getGroupBy().dumpTreeImpl(buffer, format_state, indent + 4); } + + if (hasOrderBy()) + { + buffer << '\n' << std::string(indent + 2, ' ') << "ORDER BY\n"; + getOrderBy().dumpTreeImpl(buffer, format_state, indent + 4); + } } bool QueryNode::isEqualImpl(const IQueryTreeNode & rhs) const @@ -161,7 +167,7 @@ ASTPtr QueryNode::toASTImpl() const auto select_query = std::make_shared(); select_query->distinct = is_distinct; - if (!getWith().getNodes().empty()) + if (hasWith()) select_query->setExpression(ASTSelectQuery::Expression::WITH, getWith().toAST()); select_query->setExpression(ASTSelectQuery::Expression::SELECT, children[projection_child_index]->toAST()); @@ -176,17 +182,20 @@ ASTPtr QueryNode::toASTImpl() const if (getWhere()) select_query->setExpression(ASTSelectQuery::Expression::WHERE, getWhere()->toAST()); - if (!getGroupBy().getNodes().empty()) + if (hasGroupBy()) select_query->setExpression(ASTSelectQuery::Expression::GROUP_BY, getGroupBy().toAST()); + if (hasOrderBy()) + select_query->setExpression(ASTSelectQuery::Expression::ORDER_BY, getOrderBy().toAST()); + auto result_select_query = std::make_shared(); result_select_query->union_mode = SelectUnionMode::Unspecified; auto list_of_selects = std::make_shared(); list_of_selects->children.push_back(std::move(select_query)); - result_select_query->list_of_selects = std::move(list_of_selects); - result_select_query->children.push_back(std::move(select_query)); + result_select_query->children.push_back(std::move(list_of_selects)); + result_select_query->list_of_selects = result_select_query->children.back(); if (is_subquery) { diff --git a/src/Analyzer/QueryNode.h b/src/Analyzer/QueryNode.h index 12684a8d059..e08c6a85287 100644 --- a/src/Analyzer/QueryNode.h +++ b/src/Analyzer/QueryNode.h @@ -57,6 +57,11 @@ public: return is_distinct; } + bool hasWith() const + { + return !getWith().getNodes().empty(); + } + const ListNode & getWith() const { return children[with_child_index]->as(); @@ -162,6 +167,31 @@ public: return children[group_by_child_index]; } + bool hasOrderBy() const + { + return !getOrderBy().getNodes().empty(); + } + + const ListNode & getOrderBy() const + { + return children[order_by_child_index]->as(); + } + + ListNode & getOrderBy() + { + return children[order_by_child_index]->as(); + } + + const QueryTreeNodePtr & getOrderByNode() const + { + return children[order_by_child_index]; + } + + QueryTreeNodePtr & getOrderByNode() + { + return children[order_by_child_index]; + } + /// Compute query node columns using projection section NamesAndTypesList computeProjectionColumns() const; diff --git a/src/Analyzer/QueryTreeBuilder.cpp b/src/Analyzer/QueryTreeBuilder.cpp index 45769d3029d..ee5da55859e 100644 --- a/src/Analyzer/QueryTreeBuilder.cpp +++ b/src/Analyzer/QueryTreeBuilder.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -29,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -78,6 +80,8 @@ private: QueryTreeNodePtr buildSelectExpression(const ASTPtr & select_query, bool is_subquery, const std::string & cte_name) const; + QueryTreeNodePtr buildSortColumnList(const ASTPtr & order_by_expression_list) const; + QueryTreeNodePtr buildExpressionList(const ASTPtr & expression_list) const; QueryTreeNodePtr buildExpression(const ASTPtr & expression) const; @@ -215,9 +219,50 @@ QueryTreeNodePtr QueryTreeBuilder::buildSelectExpression(const ASTPtr & select_q if (group_by_list) current_query_tree->getGroupByNode() = buildExpressionList(group_by_list); + auto select_order_by_list = select_query_typed.orderBy(); + if (select_order_by_list) + current_query_tree->getOrderByNode() = buildSortColumnList(select_order_by_list); + return current_query_tree; } +QueryTreeNodePtr QueryTreeBuilder::buildSortColumnList(const ASTPtr & order_by_expression_list) const +{ + auto list_node = std::make_shared(); + + auto & expression_list_typed = order_by_expression_list->as(); + list_node->getNodes().reserve(expression_list_typed.children.size()); + + for (auto & expression : expression_list_typed.children) + { + const auto & order_by_element = expression->as(); + + auto sort_direction = order_by_element.direction == 1 ? SortDirection::ASCENDING : SortDirection::DESCENDING; + std::optional nulls_sort_direction; + if (order_by_element.nulls_direction_was_explicitly_specified) + nulls_sort_direction = order_by_element.nulls_direction == 1 ? SortDirection::ASCENDING : SortDirection::DESCENDING; + + std::shared_ptr collator; + if (order_by_element.collation) + collator = std::make_shared(order_by_element.collation->as().value.get()); + + const auto & sort_expression_ast = order_by_element.children.at(0); + auto sort_expression = buildExpression(sort_expression_ast); + auto sort_column_node = std::make_shared(std::move(sort_expression), sort_direction, nulls_sort_direction, std::move(collator)); + + if (order_by_element.fill_from) + sort_column_node->getFillFrom() = buildExpression(order_by_element.fill_from); + if (order_by_element.fill_to) + sort_column_node->getFillTo() = buildExpression(order_by_element.fill_to); + if (order_by_element.fill_step) + sort_column_node->getFillStep() = buildExpression(order_by_element.fill_step); + + list_node->getNodes().push_back(std::move(sort_column_node)); + } + + return list_node; +} + QueryTreeNodePtr QueryTreeBuilder::buildExpressionList(const ASTPtr & expression_list) const { auto list_node = std::make_shared(); @@ -456,7 +501,7 @@ QueryTreeNodePtr QueryTreeBuilder::buildJoinTree(const ASTPtr & tables_in_select for (const auto & argument : function_arguments_list) { if (argument->as() || argument->as() || argument->as()) - node->getArguments().getNodes().push_back(buildSelectOrUnionExpression(argument, true /*is_subquery*/, {} /*cte_name*/)); + node->getArguments().getNodes().push_back(buildSelectOrUnionExpression(argument, false /*is_subquery*/, {} /*cte_name*/)); else node->getArguments().getNodes().push_back(buildExpression(argument)); } diff --git a/src/Analyzer/SortColumnNode.cpp b/src/Analyzer/SortColumnNode.cpp new file mode 100644 index 00000000000..c55946db4f8 --- /dev/null +++ b/src/Analyzer/SortColumnNode.cpp @@ -0,0 +1,158 @@ +#include + +#include + +#include +#include + +#include +#include +#include + +namespace DB +{ + +const char * toString(SortDirection sort_direction) +{ + switch (sort_direction) + { + case SortDirection::ASCENDING: return "ASCENDING"; + case SortDirection::DESCENDING: return "DESCENDING"; + } +} + +SortColumnNode::SortColumnNode(QueryTreeNodePtr expression_, + SortDirection sort_direction_, + std::optional nulls_sort_direction_, + std::shared_ptr collator_) + : sort_direction(sort_direction_) + , nulls_sort_direction(nulls_sort_direction_) + , collator(std::move(collator_)) +{ + children.resize(children_size); + children[sort_expression_child_index] = std::move(expression_); +} + +String SortColumnNode::getName() const +{ + String result = getExpression()->getName(); + + if (sort_direction == SortDirection::ASCENDING) + result += " ASC"; + else + result += " DESC"; + + if (nulls_sort_direction) + { + if (*nulls_sort_direction == SortDirection::ASCENDING) + result += " NULLS FIRST"; + else + result += " NULLS LAST"; + } + + if (hasWithFill()) + result += " WITH FILL"; + + if (hasFillFrom()) + result += " FROM " + getFillFrom()->getName(); + + if (hasFillStep()) + result += " STEP " + getFillStep()->getName(); + + if (hasFillTo()) + result += " TO " + getFillTo()->getName(); + + return result; +} + +void SortColumnNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const +{ + buffer << std::string(indent, ' ') << "SORT_COLUMN id: " << format_state.getNodeId(this); + + buffer << ", sort_direction: " << toString(sort_direction); + if (nulls_sort_direction) + buffer << ", nulls_sort_direction: " << toString(*nulls_sort_direction); + + if (collator) + buffer << ", collator: " << collator->getLocale(); + + buffer << '\n' << std::string(indent + 2, ' ') << "EXPRESSION\n"; + getExpression()->dumpTreeImpl(buffer, format_state, indent + 4); + + if (hasFillFrom()) + { + buffer << '\n' << std::string(indent + 2, ' ') << "FILL FROM\n"; + getFillFrom()->dumpTreeImpl(buffer, format_state, indent + 4); + } + + if (hasFillTo()) + { + buffer << '\n' << std::string(indent + 2, ' ') << "FILL TO\n"; + getFillTo()->dumpTreeImpl(buffer, format_state, indent + 4); + } + + if (hasFillStep()) + { + buffer << '\n' << std::string(indent + 2, ' ') << "FILL STEP\n"; + getFillStep()->dumpTreeImpl(buffer, format_state, indent + 4); + } +} + +bool SortColumnNode::isEqualImpl(const IQueryTreeNode & rhs) const +{ + const auto & rhs_typed = assert_cast(rhs); + if (sort_direction != rhs_typed.sort_direction || + nulls_sort_direction != rhs_typed.nulls_sort_direction) + return false; + + if (!collator && !rhs_typed.collator) + return true; + else if (collator && !rhs_typed.collator) + return false; + else if (!collator && rhs_typed.collator) + return false; + + return collator->getLocale() == rhs_typed.collator->getLocale(); +} + +void SortColumnNode::updateTreeHashImpl(HashState & hash_state) const +{ + hash_state.update(sort_direction); + hash_state.update(nulls_sort_direction); + + if (collator) + { + const auto & locale = collator->getLocale(); + + hash_state.update(locale.size()); + hash_state.update(locale); + } +} + +ASTPtr SortColumnNode::toASTImpl() const +{ + auto result = std::make_shared(); + result->direction = sort_direction == SortDirection::ASCENDING ? 1 : -1; + result->nulls_direction = result->direction; + if (nulls_sort_direction) + result->nulls_direction = *nulls_sort_direction == SortDirection::ASCENDING ? 1 : -1; + + result->nulls_direction_was_explicitly_specified = nulls_sort_direction.has_value(); + if (collator) + result->collation = std::make_shared(Field(collator->getLocale())); + + result->with_fill = hasWithFill(); + result->fill_from = hasFillFrom() ? getFillFrom()->toAST() : nullptr; + result->fill_to = hasFillTo() ? getFillTo()->toAST() : nullptr; + result->fill_step = hasFillStep() ? getFillStep()->toAST() : nullptr; + result->children.push_back(getExpression()->toAST()); + + return result; +} + +QueryTreeNodePtr SortColumnNode::cloneImpl() const +{ + return std::make_shared(nullptr, sort_direction, nulls_sort_direction, collator); +} + +} diff --git a/src/Analyzer/SortColumnNode.h b/src/Analyzer/SortColumnNode.h new file mode 100644 index 00000000000..04642a346a2 --- /dev/null +++ b/src/Analyzer/SortColumnNode.h @@ -0,0 +1,152 @@ +#pragma once + +#include + +#include +#include + +namespace DB +{ + +/** Sort column node represents sort column descripion that is part of ORDER BY in query tree. + * Example: SELECT * FROM test_table ORDER BY sort_column_1, sort_column_2; + * Sort column optionally contain collation, fill from, fill to, and fill step. + */ +class SortColumnNode; +using SortColumnNodePtr = std::shared_ptr; + +enum class SortDirection +{ + ASCENDING = 0, + DESCENDING = 1 +}; + +const char * toString(SortDirection sort_direction); + +class SortColumnNode final : public IQueryTreeNode +{ +public: + /// Initialize sort column node with sort expression + explicit SortColumnNode(QueryTreeNodePtr expression_, + SortDirection sort_direction_ = SortDirection::ASCENDING, + std::optional nulls_sort_direction_ = {}, + std::shared_ptr collator_ = nullptr); + + /// Get sort expression + const QueryTreeNodePtr & getExpression() const + { + return children[sort_expression_child_index]; + } + + /// Get sort expression + QueryTreeNodePtr & getExpression() + { + return children[sort_expression_child_index]; + } + + /// Has with fill + bool hasWithFill() const + { + return hasFillFrom() || hasFillStep() || hasFillTo(); + } + + /// Has fill from + bool hasFillFrom() const + { + return children[fill_from_child_index] != nullptr; + } + + /// Get fill from + const QueryTreeNodePtr & getFillFrom() const + { + return children[fill_from_child_index]; + } + + /// Get fill from + QueryTreeNodePtr & getFillFrom() + { + return children[fill_from_child_index]; + } + + /// Has fill to + bool hasFillTo() const + { + return children[fill_to_child_index] != nullptr; + } + + /// Get fill to + const QueryTreeNodePtr & getFillTo() const + { + return children[fill_to_child_index]; + } + + /// Get fill to + QueryTreeNodePtr & getFillTo() + { + return children[fill_to_child_index]; + } + + bool hasFillStep() const + { + return children[fill_step_child_index] != nullptr; + } + + /// Get fill step + const QueryTreeNodePtr & getFillStep() const + { + return children[fill_step_child_index]; + } + + /// Get fill to + QueryTreeNodePtr & getFillStep() + { + return children[fill_step_child_index]; + } + + /// Get collator + const std::shared_ptr & getCollator() const + { + return collator; + } + + SortDirection getSortDirection() const + { + return sort_direction; + } + + std::optional getNullsSortDirection() const + { + return nulls_sort_direction; + } + + QueryTreeNodeType getNodeType() const override + { + return QueryTreeNodeType::SORT_COLUMN; + } + + String getName() const override; + + void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override; + + bool isEqualImpl(const IQueryTreeNode & rhs) const override; + + void updateTreeHashImpl(HashState & hash_state) const override; + +protected: + ASTPtr toASTImpl() const override; + + QueryTreeNodePtr cloneImpl() const override; + +private: + static constexpr size_t sort_expression_child_index = 0; + static constexpr size_t fill_from_child_index = 1; + static constexpr size_t fill_to_child_index = 2; + static constexpr size_t fill_step_child_index = 3; + static constexpr size_t children_size = fill_step_child_index + 1; + + SortDirection sort_direction = SortDirection::ASCENDING; + std::optional nulls_sort_direction; + std::shared_ptr collator; +}; + +} diff --git a/src/Planner/Planner.cpp b/src/Planner/Planner.cpp index 3ddd45ad67c..d764bcba0ef 100644 --- a/src/Planner/Planner.cpp +++ b/src/Planner/Planner.cpp @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include @@ -37,6 +39,7 @@ #include #include #include +#include #include #include #include @@ -51,6 +54,7 @@ #include #include #include +#include #include namespace DB @@ -65,6 +69,7 @@ namespace ErrorCodes extern const int INVALID_JOIN_ON_EXPRESSION; extern const int ILLEGAL_AGGREGATION; extern const int NOT_AN_AGGREGATE; + extern const int INVALID_WITH_FILL_EXPRESSION; } /** ClickHouse query planner. @@ -88,6 +93,9 @@ namespace ErrorCodes * TODO: Support max streams * TODO: Support GROUPINS SETS, const aggregation keys, overflow row * TODO: Simplify buildings SETS for IN function + * TODO: Support interpolate, LIMIT BY. + * TODO: Support ORDER BY read in order optimization + * TODO: Support GROUP BY read in order optimization */ namespace @@ -1102,7 +1110,6 @@ void Planner::buildQueryPlanIfNeeded() ValidateGroupByColumnsVisitor::Data validate_group_by_visitor_data(aggregate_keys_set, *planner_context); ValidateGroupByColumnsVisitor validate_columns_visitor(validate_group_by_visitor_data); - validate_columns_visitor.visit(query_node.getProjectionNode()); auto aggregate_step = std::make_unique(std::move(group_by_actions), ActionsChainStep::AvailableOutputColumnsStrategy::OUTPUT_NODES, aggregates_columns); @@ -1110,6 +1117,34 @@ void Planner::buildQueryPlanIfNeeded() aggregate_step_index = actions_chain.getLastStepIndex(); } + std::optional before_order_by_step_index; + if (query_node.hasOrderBy()) + { + const auto * chain_available_output_columns = actions_chain.getLastStepAvailableOutputColumnsOrNull(); + const auto & order_by_input = chain_available_output_columns ? *chain_available_output_columns : query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName(); + + ActionsDAGPtr actions_dag = std::make_shared(order_by_input); + auto & actions_dag_outputs = actions_dag->getOutputs(); + actions_dag_outputs.clear(); + + PlannerActionsVisitor actions_visitor(planner_context); + + /** We add only sort column sort expression in before ORDER BY actions DAG. + * WITH fill expressions must be constant nodes. + */ + auto & order_by_node_list = query_node.getOrderByNode()->as(); + for (auto & sort_column_node : order_by_node_list.getNodes()) + { + auto & sort_column_node_typed = sort_column_node->as(); + auto expression_dag_index_nodes = actions_visitor.visit(actions_dag, sort_column_node_typed.getExpression()); + actions_dag_outputs.insert(actions_dag_outputs.end(), expression_dag_index_nodes.begin(), expression_dag_index_nodes.end()); + } + + auto actions_step_before_order_by = std::make_unique(std::move(actions_dag)); + actions_chain.addStep(std::move(actions_step_before_order_by)); + before_order_by_step_index = actions_chain.getLastStepIndex(); + } + const auto * chain_available_output_columns = actions_chain.getLastStepAvailableOutputColumnsOrNull(); const auto & projection_input = chain_available_output_columns ? *chain_available_output_columns : query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName(); auto projection_actions = convertExpressionNodeIntoDAG(query_node.getProjectionNode(), projection_input, planner_context); @@ -1152,13 +1187,13 @@ void Planner::buildQueryPlanIfNeeded() actions_chain.addStep(std::make_unique(std::move(projection_actions))); size_t projection_action_step_index = actions_chain.getLastStepIndex(); - // std::cout << "Chain dump before finalize" << std::endl; - // std::cout << actions_chain.dump() << std::endl; + std::cout << "Chain dump before finalize" << std::endl; + std::cout << actions_chain.dump() << std::endl; actions_chain.finalize(); - // std::cout << "Chain dump after finalize" << std::endl; - // std::cout << actions_chain.dump() << std::endl; + std::cout << "Chain dump after finalize" << std::endl; + std::cout << actions_chain.dump() << std::endl; if (where_action_step_index) { @@ -1283,6 +1318,55 @@ void Planner::buildQueryPlanIfNeeded() query_plan.addStep(std::move(distinct_step)); } + if (before_order_by_step_index) + { + auto & aggregate_actions_chain_node = actions_chain.at(*before_order_by_step_index); + auto expression_step_before_order_by = std::make_unique(query_plan.getCurrentDataStream(), + aggregate_actions_chain_node->getActions()); + expression_step_before_order_by->setStepDescription("Before ORDER BY"); + query_plan.addStep(std::move(expression_step_before_order_by)); + } + + if (query_node.hasOrderBy()) + { + SortDescription sort_description = extractSortDescription(query_node.getOrderByNode(), *planner_context); + String sort_description_dump = dumpSortDescription(sort_description); + + UInt64 limit = 0; + + const Settings & settings = planner_context->getQueryContext()->getSettingsRef(); + + /// Merge the sorted blocks. + auto sorting_step = std::make_unique( + query_plan.getCurrentDataStream(), + sort_description, + settings.max_block_size, + limit, + SizeLimits(settings.max_rows_to_sort, settings.max_bytes_to_sort, settings.sort_overflow_mode), + settings.max_bytes_before_remerge_sort, + settings.remerge_sort_lowered_memory_bytes_ratio, + settings.max_bytes_before_external_sort, + planner_context->getQueryContext()->getTemporaryVolume(), + settings.min_free_disk_space_for_temporary_data); + + sorting_step->setStepDescription("Sorting for ORDER BY"); + query_plan.addStep(std::move(sorting_step)); + + SortDescription fill_description; + for (auto & description : sort_description) + { + if (description.with_fill) + fill_description.push_back(description); + } + + if (!fill_description.empty()) + { + InterpolateDescriptionPtr interpolate_descr = nullptr; + auto filling_step = std::make_unique(query_plan.getCurrentDataStream(), std::move(fill_description), interpolate_descr); + query_plan.addStep(std::move(filling_step)); + } + } + auto projection_step = std::make_unique(query_plan.getCurrentDataStream(), actions_chain[projection_action_step_index]->getActions()); projection_step->setStepDescription("Projection"); query_plan.addStep(std::move(projection_step)); diff --git a/src/Planner/PlannerSorting.cpp b/src/Planner/PlannerSorting.cpp new file mode 100644 index 00000000000..869d1611500 --- /dev/null +++ b/src/Planner/PlannerSorting.cpp @@ -0,0 +1,155 @@ +#include + +#include + +#include + +#include + +#include +#include + +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int INVALID_WITH_FILL_EXPRESSION; +} + +namespace +{ + +std::pair extractWithFillValue(const QueryTreeNodePtr & node) +{ + const auto & constant_node = node->as(); + + std::pair result; + result.first = constant_node.getConstantValue(); + result.second = constant_node.getResultType(); + + if (!isColumnedAsNumber(result.second)) + throw Exception(ErrorCodes::INVALID_WITH_FILL_EXPRESSION, "WITH FILL expression must be constant with numeric type"); + + return result; +} + +std::pair> extractWithFillStepValue(const QueryTreeNodePtr & node) +{ + const auto & constant_node = node->as(); + const auto & constant_node_result_type = constant_node.getResultType(); + if (const auto * type_interval = typeid_cast(constant_node_result_type.get())) + return std::make_pair(constant_node.getConstantValue(), type_interval->getKind()); + + if (!isColumnedAsNumber(constant_node_result_type)) + throw Exception(ErrorCodes::INVALID_WITH_FILL_EXPRESSION, "WITH FILL expression must be constant with numeric type"); + + return {constant_node.getConstantValue(), {}}; +} + +FillColumnDescription extractWithFillDescription(const SortColumnNode & sort_column_node) +{ + FillColumnDescription fill_column_description; + + if (sort_column_node.hasFillFrom()) + { + auto extract_result = extractWithFillValue(sort_column_node.getFillFrom()); + fill_column_description.fill_from = std::move(extract_result.first); + fill_column_description.fill_from_type = std::move(extract_result.second); + } + + if (sort_column_node.hasFillTo()) + { + auto extract_result = extractWithFillValue(sort_column_node.getFillTo()); + fill_column_description.fill_to = std::move(extract_result.first); + fill_column_description.fill_to_type = std::move(extract_result.second); + } + + if (sort_column_node.hasFillStep()) + { + auto extract_result = extractWithFillStepValue(sort_column_node.getFillStep()); + fill_column_description.fill_step = std::move(extract_result.first); + fill_column_description.step_kind = std::move(extract_result.second); + } + else + { + fill_column_description.fill_step = Field(sort_column_node.getSortDirection() == SortDirection::ASCENDING ? 1 : -1); + } + + if (applyVisitor(FieldVisitorAccurateEquals(), fill_column_description.fill_step, Field{0})) + throw Exception(ErrorCodes::INVALID_WITH_FILL_EXPRESSION, + "WITH FILL STEP value cannot be zero"); + + if (sort_column_node.getSortDirection() == SortDirection::ASCENDING) + { + if (applyVisitor(FieldVisitorAccurateLess(), fill_column_description.fill_step, Field{0})) + throw Exception(ErrorCodes::INVALID_WITH_FILL_EXPRESSION, + "WITH FILL STEP value cannot be negative for sorting in ascending direction"); + + if (!fill_column_description.fill_from.isNull() && !fill_column_description.fill_to.isNull() && + applyVisitor(FieldVisitorAccurateLess(), fill_column_description.fill_to, fill_column_description.fill_from)) + { + throw Exception(ErrorCodes::INVALID_WITH_FILL_EXPRESSION, + "WITH FILL TO value cannot be less than FROM value for sorting in ascending direction"); + } + } + else + { + if (applyVisitor(FieldVisitorAccurateLess(), Field{0}, fill_column_description.fill_step)) + throw Exception(ErrorCodes::INVALID_WITH_FILL_EXPRESSION, + "WITH FILL STEP value cannot be positive for sorting in descending direction"); + + if (!fill_column_description.fill_from.isNull() && !fill_column_description.fill_to.isNull() && + applyVisitor(FieldVisitorAccurateLess(), fill_column_description.fill_from, fill_column_description.fill_to)) + { + throw Exception(ErrorCodes::INVALID_WITH_FILL_EXPRESSION, + "WITH FILL FROM value cannot be less than TO value for sorting in descending direction"); + } + } + + return fill_column_description; +} + +} + +SortDescription extractSortDescription(const QueryTreeNodePtr & order_by_node, const PlannerContext & planner_context) +{ + auto & order_by_list_node = order_by_node->as(); + + SortDescription sort_column_description; + sort_column_description.reserve(order_by_list_node.getNodes().size()); + + for (const auto & sort_column_node : order_by_list_node.getNodes()) + { + auto & sort_column_node_typed = sort_column_node->as(); + + auto column_name = calculateActionsDAGNodeName(sort_column_node_typed.getExpression().get(), planner_context); + std::shared_ptr collator = sort_column_node_typed.getCollator(); + int direction = sort_column_node_typed.getSortDirection() == SortDirection::ASCENDING ? 1 : -1; + int nulls_direction = direction; + + auto nulls_sort_direction = sort_column_node_typed.getNullsSortDirection(); + if (nulls_sort_direction) + nulls_direction = *nulls_sort_direction == SortDirection::ASCENDING ? 1 : -1; + + if (sort_column_node_typed.hasWithFill()) + { + FillColumnDescription fill_description = extractWithFillDescription(sort_column_node_typed); + sort_column_description.emplace_back(column_name, direction, nulls_direction, collator, true /*with_fill*/, fill_description); + } + else + { + sort_column_description.emplace_back(column_name, direction, nulls_direction, collator); + } + } + + const auto & settings = planner_context.getQueryContext()->getSettingsRef(); + sort_column_description.compile_sort_description = settings.compile_sort_description; + sort_column_description.min_count_to_compile_sort_description = settings.min_count_to_compile_sort_description; + + return sort_column_description; +} + +} diff --git a/src/Planner/PlannerSorting.h b/src/Planner/PlannerSorting.h new file mode 100644 index 00000000000..ae3d2d1acfe --- /dev/null +++ b/src/Planner/PlannerSorting.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +#include + +namespace DB +{ + +/// Extract sort description from query order by node +SortDescription extractSortDescription(const QueryTreeNodePtr & order_by_node, const PlannerContext & planner_context); + +} + diff --git a/src/TableFunctions/TableFunctionViewIfPermitted.cpp b/src/TableFunctions/TableFunctionViewIfPermitted.cpp index 176db0915d4..ba3d2cb9d16 100644 --- a/src/TableFunctions/TableFunctionViewIfPermitted.cpp +++ b/src/TableFunctions/TableFunctionViewIfPermitted.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include #include @@ -39,14 +38,7 @@ void TableFunctionViewIfPermitted::parseArguments(const ASTPtr & ast_function, C getName()); const auto & arguments = function->arguments->children; - auto select_argument = arguments[0]; - auto * subquery = arguments[0]->as(); - - if (subquery) - select_argument = subquery->children[0]; - - auto * select = select_argument->as(); - + auto * select = arguments[0]->as(); if (!select) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Table function '{}' requires a SELECT query as its first argument", getName()); create.set(create.select, select->clone());