diff --git a/src/Interpreters/ActionsVisitor.cpp b/src/Interpreters/ActionsVisitor.cpp index b382b26dcec..1c82bc62f24 100644 --- a/src/Interpreters/ActionsVisitor.cpp +++ b/src/Interpreters/ActionsVisitor.cpp @@ -549,7 +549,7 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data & data.addAction(ExpressionAction::copyColumn(arg->getColumnName(), result_name)); NameSet joined_columns; joined_columns.insert(result_name); - data.addAction(ExpressionAction::arrayJoin(joined_columns, false, data.context)); + data.addAction(ExpressionAction::arrayJoin(std::make_shared(joined_columns, false, data.context))); } return; diff --git a/src/Interpreters/ArrayJoinAction.h b/src/Interpreters/ArrayJoinAction.h index d70c0c14a15..be5be738bb9 100644 --- a/src/Interpreters/ArrayJoinAction.h +++ b/src/Interpreters/ArrayJoinAction.h @@ -12,8 +12,9 @@ class Context; class IFunctionOverloadResolver; using FunctionOverloadResolverPtr = std::shared_ptr; -struct ArrayJoinAction +class ArrayJoinAction { +public: NameSet columns; bool is_left = false; bool is_unaligned = false; @@ -32,4 +33,6 @@ struct ArrayJoinAction void finalize(NameSet & needed_columns, NameSet & unmodified_columns, NameSet & final_columns); }; +using ArrayJoinActionPtr = std::shared_ptr; + } diff --git a/src/Interpreters/ExpressionActions.cpp b/src/Interpreters/ExpressionActions.cpp index 0e1d0c51704..f35e6266110 100644 --- a/src/Interpreters/ExpressionActions.cpp +++ b/src/Interpreters/ExpressionActions.cpp @@ -143,11 +143,11 @@ ExpressionAction ExpressionAction::addAliases(const NamesWithAliases & aliased_c return a; } -ExpressionAction ExpressionAction::arrayJoin(const NameSet & array_joined_columns, bool array_join_is_left, const Context & context) +ExpressionAction ExpressionAction::arrayJoin(ArrayJoinActionPtr array_join_) { ExpressionAction a; a.type = ARRAY_JOIN; - a.array_join = std::make_shared(array_joined_columns, array_join_is_left, context); + a.array_join = std::move(array_join_); return a; } diff --git a/src/Interpreters/ExpressionActions.h b/src/Interpreters/ExpressionActions.h index 3697dc02ad3..372a17f58df 100644 --- a/src/Interpreters/ExpressionActions.h +++ b/src/Interpreters/ExpressionActions.h @@ -98,7 +98,7 @@ public: bool is_function_compiled = false; /// For ARRAY JOIN - std::shared_ptr array_join; + ArrayJoinActionPtr array_join; /// For JOIN std::shared_ptr table_join; @@ -117,7 +117,7 @@ public: static ExpressionAction project(const NamesWithAliases & projected_columns_); static ExpressionAction project(const Names & projected_columns_); static ExpressionAction addAliases(const NamesWithAliases & aliased_columns_); - static ExpressionAction arrayJoin(const NameSet & array_joined_columns, bool array_join_is_left, const Context & context); + static ExpressionAction arrayJoin(ArrayJoinActionPtr array_join_); static ExpressionAction ordinaryJoin(std::shared_ptr table_join, JoinPtr join); /// Which columns necessary to perform this action. diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index acbf6255fba..7cabef1df9c 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -173,25 +173,33 @@ void ExpressionAnalyzer::analyzeAggregation() if (select_query) { + NamesAndTypesList array_join_columns; + bool is_array_join_left; - ASTPtr array_join_expression_list = select_query->arrayJoinExpressionList(is_array_join_left); - if (array_join_expression_list) + if (ASTPtr array_join_expression_list = select_query->arrayJoinExpressionList(is_array_join_left)) { getRootActionsNoMakeSet(array_join_expression_list, true, temp_actions, false); - addMultipleArrayJoinAction(temp_actions, is_array_join_left); + if (auto array_join = addMultipleArrayJoinAction(temp_actions, is_array_join_left)) + temp_actions->add(ExpressionAction::arrayJoin(array_join)); - array_join_columns.clear(); for (auto & column : temp_actions->getSampleBlock().getNamesAndTypesList()) if (syntax->array_join_result_to_source.count(column.name)) array_join_columns.emplace_back(column); } + columns_after_array_join = sourceColumns(); + columns_after_array_join.insert(columns_after_array_join.end(), array_join_columns.begin(), array_join_columns.end()); + const ASTTablesInSelectQueryElement * join = select_query->join(); if (join) { getRootActionsNoMakeSet(analyzedJoin().leftKeysList(), true, temp_actions, false); addJoinAction(temp_actions); } + + columns_after_join = columns_after_array_join; + const auto & added_by_join = analyzedJoin().columnsAddedByJoin(); + columns_after_join.insert(columns_after_join.end(), added_by_join.begin(), added_by_join.end()); } has_aggregation = makeAggregateDescriptions(temp_actions); @@ -281,16 +289,6 @@ void ExpressionAnalyzer::initGlobalSubqueriesAndExternalTables(bool do_global) } -NamesAndTypesList ExpressionAnalyzer::sourceWithJoinedColumns() const -{ - auto result_columns = sourceColumns(); - result_columns.insert(result_columns.end(), array_join_columns.begin(), array_join_columns.end()); - result_columns.insert(result_columns.end(), - analyzedJoin().columnsAddedByJoin().begin(), analyzedJoin().columnsAddedByJoin().end()); - return result_columns; -} - - void SelectQueryExpressionAnalyzer::tryMakeSetForIndexFromSubquery(const ASTPtr & subquery_or_table_name) { auto set_key = PreparedSetKey::forSubquery(*subquery_or_table_name); @@ -374,7 +372,7 @@ void SelectQueryExpressionAnalyzer::makeSetsForIndex(const ASTPtr & node) } else { - ExpressionActionsPtr temp_actions = std::make_shared(sourceWithJoinedColumns(), context); + ExpressionActionsPtr temp_actions = std::make_shared(columns_after_join, context); getRootActions(left_in_operand, true, temp_actions); Block sample_block_with_calculated_columns = temp_actions->getSampleBlock(); @@ -455,7 +453,7 @@ const ASTSelectQuery * SelectQueryExpressionAnalyzer::getAggregatingQuery() cons } /// "Big" ARRAY JOIN. -void ExpressionAnalyzer::addMultipleArrayJoinAction(ExpressionActionsPtr & actions, bool array_join_is_left) const +ArrayJoinActionPtr ExpressionAnalyzer::addMultipleArrayJoinAction(ExpressionActionsPtr & actions, bool array_join_is_left) const { NameSet result_columns; for (const auto & result_source : syntax->array_join_result_to_source) @@ -468,25 +466,27 @@ void ExpressionAnalyzer::addMultipleArrayJoinAction(ExpressionActionsPtr & actio result_columns.insert(result_source.first); } - actions->add(ExpressionAction::arrayJoin(result_columns, array_join_is_left, context)); + return std::make_shared(result_columns, array_join_is_left, context); } -bool SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & chain, bool only_types) +ArrayJoinActionPtr SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & chain, bool only_types) { const auto * select_query = getSelectQuery(); bool is_array_join_left; ASTPtr array_join_expression_list = select_query->arrayJoinExpressionList(is_array_join_left); if (!array_join_expression_list) - return false; + return nullptr; ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns()); getRootActions(array_join_expression_list, only_types, step.actions); - addMultipleArrayJoinAction(step.actions, is_array_join_left); + auto array_join = addMultipleArrayJoinAction(step.actions, is_array_join_left); + for (const auto & column : array_join->columns) + step.required_output.emplace_back(column); - return true; + return array_join; } void ExpressionAnalyzer::addJoinAction(ExpressionActionsPtr & actions, JoinPtr join) const @@ -496,7 +496,7 @@ void ExpressionAnalyzer::addJoinAction(ExpressionActionsPtr & actions, JoinPtr j bool SelectQueryExpressionAnalyzer::appendJoinLeftKeys(ExpressionActionsChain & chain, bool only_types) { - ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns()); + ExpressionActionsChain::Step & step = chain.lastStep(columns_after_array_join); getRootActions(analyzedJoin().leftKeysList(), only_types, step.actions); return true; @@ -506,7 +506,7 @@ bool SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain) { JoinPtr table_join = makeTableJoin(*syntax->ast_join); - ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns()); + ExpressionActionsChain::Step & step = chain.lastStep(columns_after_array_join); addJoinAction(step.actions, table_join); return true; @@ -720,7 +720,7 @@ bool SelectQueryExpressionAnalyzer::appendWhere(ExpressionActionsChain & chain, if (!select_query->where()) return false; - ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns()); + ExpressionActionsChain::Step & step = chain.lastStep(columns_after_join); auto where_column_name = select_query->where()->getColumnName(); step.required_output.push_back(where_column_name); @@ -744,7 +744,7 @@ bool SelectQueryExpressionAnalyzer::appendGroupBy(ExpressionActionsChain & chain if (!select_query->groupBy()) return false; - ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns()); + ExpressionActionsChain::Step & step = chain.lastStep(columns_after_join); ASTs asts = select_query->groupBy()->children; for (const auto & ast : asts) @@ -755,10 +755,9 @@ bool SelectQueryExpressionAnalyzer::appendGroupBy(ExpressionActionsChain & chain if (optimize_aggregation_in_order) { - auto all_columns = sourceWithJoinedColumns(); for (auto & child : asts) { - group_by_elements_actions.emplace_back(std::make_shared(all_columns, context)); + group_by_elements_actions.emplace_back(std::make_shared(columns_after_join, context)); getRootActions(child, only_types, group_by_elements_actions.back()); } } @@ -770,7 +769,7 @@ void SelectQueryExpressionAnalyzer::appendAggregateFunctionsArguments(Expression { const auto * select_query = getAggregatingQuery(); - ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns()); + ExpressionActionsChain::Step & step = chain.lastStep(columns_after_join); for (const auto & desc : aggregate_descriptions) for (const auto & name : desc.argument_names) @@ -844,10 +843,9 @@ bool SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain if (optimize_read_in_order) { - auto all_columns = sourceWithJoinedColumns(); for (auto & child : select_query->orderBy()->children) { - order_by_elements_actions.emplace_back(std::make_shared(all_columns, context)); + order_by_elements_actions.emplace_back(std::make_shared(columns_after_join, context)); getRootActions(child, only_types, order_by_elements_actions.back()); } } @@ -1093,7 +1091,13 @@ ExpressionAnalysisResult::ExpressionAnalysisResult( chain.addStep(); } - query_analyzer.appendArrayJoin(chain, only_types || !first_stage); + array_join = query_analyzer.appendArrayJoin(chain, only_types || !first_stage); + if (array_join) + { + before_array_join = chain.getLastActions(true); + if (before_array_join) + chain.addStep(); + } if (query_analyzer.hasTableJoin()) { diff --git a/src/Interpreters/ExpressionAnalyzer.h b/src/Interpreters/ExpressionAnalyzer.h index a37235f2f77..7ffe06ebd9e 100644 --- a/src/Interpreters/ExpressionAnalyzer.h +++ b/src/Interpreters/ExpressionAnalyzer.h @@ -34,6 +34,9 @@ struct ASTTablesInSelectQueryElement; struct StorageInMemoryMetadata; using StorageMetadataPtr = std::shared_ptr; +class ArrayJoinAction; +using ArrayJoinActionPtr = std::shared_ptr; + /// Create columns in block or return false if not possible bool sanitizeBlock(Block & block, bool throw_if_cannot_create_column = false); @@ -43,9 +46,12 @@ struct ExpressionAnalyzerData SubqueriesForSets subqueries_for_sets; PreparedSets prepared_sets; + /// Columns after ARRAY JOIN. It there is no ARRAY JOIN, it's source_columns. + NamesAndTypesList columns_after_array_join; + /// Columns after Columns after ARRAY JOIN and JOIN. If there is no JOIN, it's columns_after_array_join. + NamesAndTypesList columns_after_join; /// Columns after ARRAY JOIN, JOIN, and/or aggregation. NamesAndTypesList aggregated_columns; - NamesAndTypesList array_join_columns; bool has_aggregation = false; NamesAndTypesList aggregation_keys; @@ -128,12 +134,10 @@ protected: const TableJoin & analyzedJoin() const { return *syntax->analyzed_join; } const NamesAndTypesList & sourceColumns() const { return syntax->required_source_columns; } const std::vector & aggregates() const { return syntax->aggregates; } - NamesAndTypesList sourceWithJoinedColumns() const; - /// Find global subqueries in the GLOBAL IN/JOIN sections. Fills in external_tables. void initGlobalSubqueriesAndExternalTables(bool do_global); - void addMultipleArrayJoinAction(ExpressionActionsPtr & actions, bool is_left) const; + ArrayJoinActionPtr addMultipleArrayJoinAction(ExpressionActionsPtr & actions, bool is_left) const; void addJoinAction(ExpressionActionsPtr & actions, JoinPtr = {}) const; @@ -175,6 +179,8 @@ struct ExpressionAnalysisResult bool optimize_read_in_order = false; bool optimize_aggregation_in_order = false; + ExpressionActionsPtr before_array_join; + ArrayJoinActionPtr array_join; ExpressionActionsPtr before_join; ExpressionActionsPtr join; ExpressionActionsPtr before_where; @@ -305,7 +311,7 @@ private: */ /// Before aggregation: - bool appendArrayJoin(ExpressionActionsChain & chain, bool only_types); + ArrayJoinActionPtr appendArrayJoin(ExpressionActionsChain & chain, bool only_types); bool appendJoinLeftKeys(ExpressionActionsChain & chain, bool only_types); bool appendJoin(ExpressionActionsChain & chain); /// Add preliminary rows filtration. Actions are created in other expression analyzer to prevent any possible alias injection. diff --git a/src/Interpreters/InterpreterSelectQuery.cpp b/src/Interpreters/InterpreterSelectQuery.cpp index 19a4e998dc7..604bf55649a 100644 --- a/src/Interpreters/InterpreterSelectQuery.cpp +++ b/src/Interpreters/InterpreterSelectQuery.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -862,6 +863,25 @@ void InterpreterSelectQuery::executeImpl(QueryPlan & query_plan, const BlockInpu query_plan.addStep(std::move(row_level_security_step)); } + if (expressions.before_array_join) + { + QueryPlanStepPtr before_array_join_step = std::make_unique( + query_plan.getCurrentDataStream(), + expressions.before_array_join); + before_array_join_step->setStepDescription("Before ARRAY JOIN"); + query_plan.addStep(std::move(before_array_join_step)); + } + + if (expressions.array_join) + { + QueryPlanStepPtr array_join_step = std::make_unique( + query_plan.getCurrentDataStream(), + expressions.array_join); + + array_join_step->setStepDescription("ARRAY JOIN"); + query_plan.addStep(std::move(array_join_step)); + } + if (expressions.before_join) { QueryPlanStepPtr before_join_step = std::make_unique( diff --git a/src/Processors/QueryPlan/ArrayJoinStep.cpp b/src/Processors/QueryPlan/ArrayJoinStep.cpp new file mode 100644 index 00000000000..2948d4cc842 --- /dev/null +++ b/src/Processors/QueryPlan/ArrayJoinStep.cpp @@ -0,0 +1,62 @@ +#include +#include +#include +#include +#include + +namespace DB +{ + +static ITransformingStep::Traits getTraits() +{ + return ITransformingStep::Traits + { + { + .preserves_distinct_columns = false, + .returns_single_stream = false, + .preserves_number_of_streams = true, + .preserves_sorting = false, + }, + { + .preserves_number_of_rows = false, + } + }; +} + +ArrayJoinStep::ArrayJoinStep(const DataStream & input_stream_, ArrayJoinActionPtr array_join_) + : ITransformingStep( + input_stream_, + ArrayJoinTransform::transformHeader(input_stream_.header, array_join_), + getTraits()) + , array_join(std::move(array_join_)) +{ +} + +void ArrayJoinStep::transformPipeline(QueryPipeline & pipeline) +{ + pipeline.addSimpleTransform([&](const Block & header, QueryPipeline::StreamType stream_type) + { + bool on_totals = stream_type == QueryPipeline::StreamType::Totals; + return std::make_shared(header, array_join, on_totals); + }); +} + +void ArrayJoinStep::describeActions(FormatSettings & settings) const +{ + String prefix(settings.offset, ' '); + bool first = true; + + settings.out << prefix << (array_join->is_left ? "LEFT " : "") << "ARRAY JOIN "; + for (const auto & column : array_join->columns) + { + if (!first) + settings.out << ", "; + first = false; + + + settings.out << column; + } + settings.out << '\n'; +} + +} diff --git a/src/Processors/QueryPlan/ArrayJoinStep.h b/src/Processors/QueryPlan/ArrayJoinStep.h new file mode 100644 index 00000000000..9a9504a5d54 --- /dev/null +++ b/src/Processors/QueryPlan/ArrayJoinStep.h @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace DB +{ + +class ArrayJoinAction; +using ArrayJoinActionPtr = std::shared_ptr; + +class ArrayJoinStep : public ITransformingStep +{ +public: + explicit ArrayJoinStep(const DataStream & input_stream_, ArrayJoinActionPtr array_join_); + String getName() const override { return "Expression"; } + + void transformPipeline(QueryPipeline & pipeline) override; + + void describeActions(FormatSettings & settings) const override; + +private: + ArrayJoinActionPtr array_join; +}; + +} diff --git a/src/Processors/Transforms/ArrayJoinTransform.cpp b/src/Processors/Transforms/ArrayJoinTransform.cpp new file mode 100644 index 00000000000..ba8e4949f7c --- /dev/null +++ b/src/Processors/Transforms/ArrayJoinTransform.cpp @@ -0,0 +1,37 @@ +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +Block ArrayJoinTransform::transformHeader(Block header, const ArrayJoinActionPtr & array_join) +{ + array_join->execute(header, true); + return header; +} + +ArrayJoinTransform::ArrayJoinTransform( + const Block & header_, + ArrayJoinActionPtr array_join_, + bool /*on_totals_*/) + : ISimpleTransform(header_, transformHeader(header_, array_join_), false) + , array_join(std::move(array_join_)) +{ + /// TODO +// if (on_totals_) +// throw Exception("ARRAY JOIN is not supported for totals", ErrorCodes::LOGICAL_ERROR); +} + +void ArrayJoinTransform::transform(Chunk & chunk) +{ + auto block = getInputPort().getHeader().cloneWithColumns(chunk.detachColumns()); + array_join->execute(block, false); + chunk.setColumns(block.getColumns(), block.rows()); +} + +} diff --git a/src/Processors/Transforms/ArrayJoinTransform.h b/src/Processors/Transforms/ArrayJoinTransform.h new file mode 100644 index 00000000000..0d81d5e458c --- /dev/null +++ b/src/Processors/Transforms/ArrayJoinTransform.h @@ -0,0 +1,30 @@ +#pragma once +#include + +namespace DB +{ + +class ArrayJoinAction; +using ArrayJoinActionPtr = std::shared_ptr; + +/// Execute ARRAY JOIN +class ArrayJoinTransform : public ISimpleTransform +{ +public: + ArrayJoinTransform( + const Block & header_, + ArrayJoinActionPtr array_join_, + bool on_totals_ = false); + + String getName() const override { return "ArrayJoinTransform"; } + + static Block transformHeader(Block header, const ArrayJoinActionPtr & array_join); + +protected: + void transform(Chunk & chunk) override; + +private: + ArrayJoinActionPtr array_join; +}; + +} diff --git a/src/Processors/ya.make b/src/Processors/ya.make index 4c25ad5bf3f..45b9986a9bb 100644 --- a/src/Processors/ya.make +++ b/src/Processors/ya.make @@ -88,6 +88,7 @@ SRCS( QueryPipeline.cpp QueryPlan/AddingDelayedSourceStep.cpp QueryPlan/AggregatingStep.cpp + QueryPlan/ArrayJoinStep.cpp QueryPlan/ConvertingStep.cpp QueryPlan/CreatingSetsStep.cpp QueryPlan/CubeStep.cpp @@ -124,6 +125,7 @@ SRCS( Transforms/AddingSelectorTransform.cpp Transforms/AggregatingInOrderTransform.cpp Transforms/AggregatingTransform.cpp + Transforms/ArrayJoinTransform.cpp Transforms/ConvertingTransform.cpp Transforms/CopyTransform.cpp Transforms/CreatingSetsTransform.cpp