From 7339ff156e824c379f4dbfccb74546e7686b16e5 Mon Sep 17 00:00:00 2001 From: Nikolai Kochetov Date: Thu, 5 Sep 2024 13:22:45 +0000 Subject: [PATCH] Refactor ArrayJoin step. --- src/Functions/FunctionStringOrArrayToT.h | 3 +- src/Functions/array/arrayResize.cpp | 176 ++++++++---------- src/Functions/array/emptyArrayToSingle.cpp | 39 +--- src/Functions/array/length.cpp | 60 +----- src/Interpreters/ArrayJoinAction.cpp | 63 +++++-- src/Interpreters/ArrayJoinAction.h | 6 +- src/Interpreters/ExpressionAnalyzer.cpp | 5 +- src/Interpreters/InterpreterSelectQuery.cpp | 8 +- src/Planner/PlannerJoinTree.cpp | 10 +- src/Processors/QueryPlan/ArrayJoinStep.cpp | 20 +- src/Processors/QueryPlan/ArrayJoinStep.h | 10 +- .../Optimizations/filterPushDown.cpp | 3 +- .../Optimizations/liftUpArrayJoin.cpp | 4 +- .../Optimizations/optimizeReadInOrder.cpp | 4 +- .../Transforms/ArrayJoinTransform.cpp | 6 +- .../Transforms/ArrayJoinTransform.h | 2 +- 16 files changed, 185 insertions(+), 234 deletions(-) diff --git a/src/Functions/FunctionStringOrArrayToT.h b/src/Functions/FunctionStringOrArrayToT.h index 40f780d82a8..cd98e0f5875 100644 --- a/src/Functions/FunctionStringOrArrayToT.h +++ b/src/Functions/FunctionStringOrArrayToT.h @@ -27,7 +27,8 @@ class FunctionStringOrArrayToT : public IFunction { public: static constexpr auto name = Name::name; - static FunctionPtr create(ContextPtr) + static FunctionPtr create(ContextPtr) { return createImpl(); } + static FunctionPtr createImpl() { return std::make_shared(); } diff --git a/src/Functions/array/arrayResize.cpp b/src/Functions/array/arrayResize.cpp index 8f4ea69fc5d..fe928f22d38 100644 --- a/src/Functions/array/arrayResize.cpp +++ b/src/Functions/array/arrayResize.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -21,117 +21,99 @@ namespace ErrorCodes extern const int ILLEGAL_TYPE_OF_ARGUMENT; } -class FunctionArrayResize : public IFunction +DataTypePtr FunctionArrayResize::getReturnTypeImpl(const DataTypes & arguments) const { -public: - static constexpr auto name = "arrayResize"; - static FunctionPtr create(ContextPtr) { return std::make_shared(); } + const size_t number_of_arguments = arguments.size(); - String getName() const override { return name; } + if (number_of_arguments < 2 || number_of_arguments > 3) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Number of arguments for function {} doesn't match: passed {}, should be 2 or 3", + getName(), number_of_arguments); - bool isVariadic() const override { return true; } - size_t getNumberOfArguments() const override { return 0; } + if (arguments[0]->onlyNull()) + return arguments[0]; - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + const auto * array_type = typeid_cast(arguments[0].get()); + if (!array_type) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "First argument for function {} must be an array but it has type {}.", + getName(), arguments[0]->getName()); - DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + if (WhichDataType(array_type->getNestedType()).isNothing()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} cannot resize {}", getName(), array_type->getName()); + + if (!isInteger(removeNullable(arguments[1])) && !arguments[1]->onlyNull()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Argument {} for function {} must be integer but it has type {}.", + toString(1), getName(), arguments[1]->getName()); + + if (number_of_arguments == 2) + return arguments[0]; + else /* if (number_of_arguments == 3) */ + return std::make_shared(getLeastSupertype(DataTypes{array_type->getNestedType(), arguments[2]})); +} + +ColumnPtr FunctionArrayResize::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, size_t input_rows_count) const +{ + if (return_type->onlyNull()) + return return_type->createColumnConstWithDefaultValue(input_rows_count); + + auto result_column = return_type->createColumn(); + + auto array_column = arguments[0].column; + auto size_column = arguments[1].column; + + if (!arguments[0].type->equals(*return_type)) + array_column = castColumn(arguments[0], return_type); + + const DataTypePtr & return_nested_type = typeid_cast(*return_type).getNestedType(); + size_t size = array_column->size(); + + ColumnPtr appended_column; + if (arguments.size() == 3) { - const size_t number_of_arguments = arguments.size(); + appended_column = arguments[2].column; + if (!arguments[2].type->equals(*return_nested_type)) + appended_column = castColumn(arguments[2], return_nested_type); + } + else + appended_column = return_nested_type->createColumnConstWithDefaultValue(size); - if (number_of_arguments < 2 || number_of_arguments > 3) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - "Number of arguments for function {} doesn't match: passed {}, should be 2 or 3", - getName(), number_of_arguments); + std::unique_ptr array_source; + std::unique_ptr value_source; - if (arguments[0]->onlyNull()) - return arguments[0]; + bool is_const = false; - const auto * array_type = typeid_cast(arguments[0].get()); - if (!array_type) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "First argument for function {} must be an array but it has type {}.", - getName(), arguments[0]->getName()); - - if (WhichDataType(array_type->getNestedType()).isNothing()) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} cannot resize {}", getName(), array_type->getName()); - - if (!isInteger(removeNullable(arguments[1])) && !arguments[1]->onlyNull()) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Argument {} for function {} must be integer but it has type {}.", - toString(1), getName(), arguments[1]->getName()); - - if (number_of_arguments == 2) - return arguments[0]; - else /* if (number_of_arguments == 3) */ - return std::make_shared(getLeastSupertype(DataTypes{array_type->getNestedType(), arguments[2]})); + if (const auto * const_array_column = typeid_cast(array_column.get())) + { + is_const = true; + array_column = const_array_column->getDataColumnPtr(); } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, size_t input_rows_count) const override + if (const auto * argument_column_array = typeid_cast(array_column.get())) + array_source = GatherUtils::createArraySource(*argument_column_array, is_const, size); + else + throw Exception(ErrorCodes::LOGICAL_ERROR, "First arguments for function {} must be array.", getName()); + + + bool is_appended_const = false; + if (const auto * const_appended_column = typeid_cast(appended_column.get())) { - if (return_type->onlyNull()) - return return_type->createColumnConstWithDefaultValue(input_rows_count); - - auto result_column = return_type->createColumn(); - - auto array_column = arguments[0].column; - auto size_column = arguments[1].column; - - if (!arguments[0].type->equals(*return_type)) - array_column = castColumn(arguments[0], return_type); - - const DataTypePtr & return_nested_type = typeid_cast(*return_type).getNestedType(); - size_t size = array_column->size(); - - ColumnPtr appended_column; - if (arguments.size() == 3) - { - appended_column = arguments[2].column; - if (!arguments[2].type->equals(*return_nested_type)) - appended_column = castColumn(arguments[2], return_nested_type); - } - else - appended_column = return_nested_type->createColumnConstWithDefaultValue(size); - - std::unique_ptr array_source; - std::unique_ptr value_source; - - bool is_const = false; - - if (const auto * const_array_column = typeid_cast(array_column.get())) - { - is_const = true; - array_column = const_array_column->getDataColumnPtr(); - } - - if (const auto * argument_column_array = typeid_cast(array_column.get())) - array_source = GatherUtils::createArraySource(*argument_column_array, is_const, size); - else - throw Exception(ErrorCodes::LOGICAL_ERROR, "First arguments for function {} must be array.", getName()); - - - bool is_appended_const = false; - if (const auto * const_appended_column = typeid_cast(appended_column.get())) - { - is_appended_const = true; - appended_column = const_appended_column->getDataColumnPtr(); - } - - value_source = GatherUtils::createValueSource(*appended_column, is_appended_const, size); - - auto sink = GatherUtils::createArraySink(typeid_cast(*result_column), size); - - if (isColumnConst(*size_column)) - GatherUtils::resizeConstantSize(*array_source, *value_source, *sink, size_column->getInt(0)); - else - GatherUtils::resizeDynamicSize(*array_source, *value_source, *sink, *size_column); - - return result_column; + is_appended_const = true; + appended_column = const_appended_column->getDataColumnPtr(); } - bool useDefaultImplementationForConstants() const override { return true; } - bool useDefaultImplementationForNulls() const override { return false; } -}; + value_source = GatherUtils::createValueSource(*appended_column, is_appended_const, size); + auto sink = GatherUtils::createArraySink(typeid_cast(*result_column), size); + + if (isColumnConst(*size_column)) + GatherUtils::resizeConstantSize(*array_source, *value_source, *sink, size_column->getInt(0)); + else + GatherUtils::resizeDynamicSize(*array_source, *value_source, *sink, *size_column); + + return result_column; +} REGISTER_FUNCTION(ArrayResize) { diff --git a/src/Functions/array/emptyArrayToSingle.cpp b/src/Functions/array/emptyArrayToSingle.cpp index 2071abf9911..5699a4024a1 100644 --- a/src/Functions/array/emptyArrayToSingle.cpp +++ b/src/Functions/array/emptyArrayToSingle.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -20,35 +20,6 @@ namespace ErrorCodes extern const int ILLEGAL_TYPE_OF_ARGUMENT; } - -/** emptyArrayToSingle(arr) - replace empty arrays with arrays of one element with a default value. - */ -class FunctionEmptyArrayToSingle : public IFunction -{ -public: - static constexpr auto name = "emptyArrayToSingle"; - static FunctionPtr create(ContextPtr) { return std::make_shared(); } - - String getName() const override { return name; } - - size_t getNumberOfArguments() const override { return 1; } - bool useDefaultImplementationForConstants() const override { return true; } - bool useDefaultImplementationForLowCardinalityColumns() const override { return false; } - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } - - DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override - { - const DataTypeArray * array_type = checkAndGetDataType(arguments[0].get()); - if (!array_type) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument for function {} must be array.", getName()); - - return arguments[0]; - } - - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override; -}; - - namespace { namespace FunctionEmptyArrayToSingleImpl @@ -366,6 +337,14 @@ namespace } } +DataTypePtr FunctionEmptyArrayToSingle::getReturnTypeImpl(const DataTypes & arguments) const +{ + const DataTypeArray * array_type = checkAndGetDataType(arguments[0].get()); + if (!array_type) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument for function {} must be array.", getName()); + + return arguments[0]; +} ColumnPtr FunctionEmptyArrayToSingle::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const { diff --git a/src/Functions/array/length.cpp b/src/Functions/array/length.cpp index 760506194fa..949a5441e58 100644 --- a/src/Functions/array/length.cpp +++ b/src/Functions/array/length.cpp @@ -1,65 +1,7 @@ -#include -#include -#include - +#include namespace DB { -namespace ErrorCodes -{ - extern const int ILLEGAL_TYPE_OF_ARGUMENT; -} - -/** Calculates the length of a string in bytes. - */ -struct LengthImpl -{ - static constexpr auto is_fixed_to_constant = true; - - static void vector(const ColumnString::Chars & /*data*/, const ColumnString::Offsets & offsets, PaddedPODArray & res, size_t input_rows_count) - { - for (size_t i = 0; i < input_rows_count; ++i) - res[i] = offsets[i] - 1 - offsets[i - 1]; - } - - static void vectorFixedToConstant(const ColumnString::Chars & /*data*/, size_t n, UInt64 & res, size_t) - { - res = n; - } - - static void vectorFixedToVector(const ColumnString::Chars & /*data*/, size_t /*n*/, PaddedPODArray & /*res*/, size_t) - { - } - - static void array(const ColumnString::Offsets & offsets, PaddedPODArray & res, size_t input_rows_count) - { - for (size_t i = 0; i < input_rows_count; ++i) - res[i] = offsets[i] - offsets[i - 1]; - } - - [[noreturn]] static void uuid(const ColumnUUID::Container &, size_t &, PaddedPODArray &, size_t) - { - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Cannot apply function length to UUID argument"); - } - - [[noreturn]] static void ipv6(const ColumnIPv6::Container &, size_t &, PaddedPODArray &, size_t) - { - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Cannot apply function length to IPv6 argument"); - } - - [[noreturn]] static void ipv4(const ColumnIPv4::Container &, size_t &, PaddedPODArray &, size_t) - { - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Cannot apply function length to IPv4 argument"); - } -}; - - -struct NameLength -{ - static constexpr auto name = "length"; -}; - -using FunctionLength = FunctionStringOrArrayToT; REGISTER_FUNCTION(Length) { diff --git a/src/Interpreters/ArrayJoinAction.cpp b/src/Interpreters/ArrayJoinAction.cpp index df7a0b48057..802d38b0c03 100644 --- a/src/Interpreters/ArrayJoinAction.cpp +++ b/src/Interpreters/ArrayJoinAction.cpp @@ -6,6 +6,9 @@ #include #include #include +#include +#include +#include #include #include @@ -59,26 +62,27 @@ ColumnWithTypeAndName convertArrayJoinColumn(const ColumnWithTypeAndName & src_c return array_col; } -ArrayJoinAction::ArrayJoinAction(const NameSet & array_joined_columns_, bool array_join_is_left, ContextPtr context) +ArrayJoinAction::ArrayJoinAction(const NameSet & array_joined_columns_, bool is_left_, bool is_unaligned_, size_t max_block_size_) : columns(array_joined_columns_) - , is_left(array_join_is_left) - , is_unaligned(context->getSettingsRef().enable_unaligned_array_join) - , max_block_size(context->getSettingsRef().max_block_size) + , is_left(is_left_) + , is_unaligned(is_unaligned_) + , max_block_size(max_block_size_) + // , is_unaligned(context->getSettingsRef().enable_unaligned_array_join) + // , max_block_size(context->getSettingsRef().max_block_size) { if (columns.empty()) throw Exception(ErrorCodes::LOGICAL_ERROR, "No arrays to join"); if (is_unaligned) { - function_length = FunctionFactory::instance().get("length", context); - function_greatest = FunctionFactory::instance().get("greatest", context); - function_array_resize = FunctionFactory::instance().get("arrayResize", context); + function_length = std::make_unique(FunctionLength::createImpl()); + function_array_resize = std::make_unique(FunctionArrayResize::createImpl()); } else if (is_left) - function_builder = FunctionFactory::instance().get("emptyArrayToSingle", context); + function_builder = std::make_unique(FunctionEmptyArrayToSingle::createImpl()); } -void ArrayJoinAction::prepare(ColumnsWithTypeAndName & sample) const +void ArrayJoinAction::prepare(const NameSet & columns, ColumnsWithTypeAndName & sample) { for (auto & current : sample) { @@ -103,6 +107,35 @@ ArrayJoinResultIteratorPtr ArrayJoinAction::execute(Block block) return std::make_unique(this, std::move(block)); } +static void updateMaxLength(ColumnUInt64 & max_length, UInt64 length) +{ + for (auto & value : max_length.getData()) + value = std::max(value, length); +} + +static void updateMaxLength(ColumnUInt64 & max_length, const IColumn & length) +{ + if (const auto * length_const = typeid_cast(&length)) + { + updateMaxLength(max_length, length_const->getUInt(0)); + return; + } + + const auto * length_uint64 = typeid_cast(&length); + if (!length_uint64) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected UInt64 for array length, got {}", length.getName()); + + auto & max_lenght_data = max_length.getData(); + const auto & length_data = length_uint64->getData(); + size_t num_rows = max_lenght_data.size(); + if (num_rows != length_data.size()) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Different columns sizes in ARRAY JOIN: {} and {}", num_rows, length_data.size()); + + for (size_t row = 0; row < num_rows; ++row) + max_lenght_data[row] = std::max(max_lenght_data[row], length_data[row]); +} ArrayJoinResultIterator::ArrayJoinResultIterator(const ArrayJoinAction * array_join_, Block block_) : array_join(array_join_), block(std::move(block_)), total_rows(block.rows()), current_row(0) @@ -111,7 +144,6 @@ ArrayJoinResultIterator::ArrayJoinResultIterator(const ArrayJoinAction * array_j bool is_unaligned = array_join->is_unaligned; bool is_left = array_join->is_left; const auto & function_length = array_join->function_length; - const auto & function_greatest = array_join->function_greatest; const auto & function_array_resize = array_join->function_array_resize; const auto & function_builder = array_join->function_builder; @@ -125,11 +157,7 @@ ArrayJoinResultIterator::ArrayJoinResultIterator(const ArrayJoinAction * array_j /// Resize all array joined columns to the longest one, (at least 1 if LEFT ARRAY JOIN), padded with default values. auto rows = block.rows(); auto uint64 = std::make_shared(); - ColumnWithTypeAndName column_of_max_length{{}, uint64, {}}; - if (is_left) - column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 1u), uint64, {}); - else - column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 0u), uint64, {}); + auto max_length = ColumnUInt64::create(rows, (is_left ? 1u : 0u)); for (const auto & name : columns) { @@ -138,11 +166,10 @@ ArrayJoinResultIterator::ArrayJoinResultIterator(const ArrayJoinAction * array_j ColumnWithTypeAndName array_col = convertArrayJoinColumn(src_col); ColumnsWithTypeAndName tmp_block{array_col}; //, {{}, uint64, {}}}; auto len_col = function_length->build(tmp_block)->execute(tmp_block, uint64, rows); - - ColumnsWithTypeAndName tmp_block2{column_of_max_length, {len_col, uint64, {}}}; - column_of_max_length.column = function_greatest->build(tmp_block2)->execute(tmp_block2, uint64, rows); + updateMaxLength(*max_length, *len_col); } + ColumnWithTypeAndName column_of_max_length{std::move(max_length), uint64, {}}; for (const auto & name : columns) { auto & src_col = block.getByName(name); diff --git a/src/Interpreters/ArrayJoinAction.h b/src/Interpreters/ArrayJoinAction.h index 603f22ef245..287eabaac65 100644 --- a/src/Interpreters/ArrayJoinAction.h +++ b/src/Interpreters/ArrayJoinAction.h @@ -33,14 +33,14 @@ public: /// For unaligned [LEFT] ARRAY JOIN FunctionOverloadResolverPtr function_length; - FunctionOverloadResolverPtr function_greatest; + //FunctionOverloadResolverPtr function_greatest; FunctionOverloadResolverPtr function_array_resize; /// For LEFT ARRAY JOIN. FunctionOverloadResolverPtr function_builder; - ArrayJoinAction(const NameSet & array_joined_columns_, bool array_join_is_left, ContextPtr context); - void prepare(ColumnsWithTypeAndName & sample) const; + ArrayJoinAction(const NameSet & array_joined_columns_, bool is_left_, bool is_unaligned_, size_t max_block_size_); + static void prepare(const NameSet & columns, ColumnsWithTypeAndName & sample); ArrayJoinResultIteratorPtr execute(Block block); }; diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index 7063b2162a0..3315f4a67b2 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -215,7 +215,7 @@ NamesAndTypesList ExpressionAnalyzer::getColumnsAfterArrayJoin(ActionsDAG & acti auto array_join = addMultipleArrayJoinAction(actions, is_array_join_left); auto sample_columns = actions.getResultColumns(); - array_join->prepare(sample_columns); + ArrayJoinAction::prepare(array_join->columns, sample_columns); actions = ActionsDAG(sample_columns); NamesAndTypesList new_columns_after_array_join; @@ -905,7 +905,8 @@ ArrayJoinActionPtr ExpressionAnalyzer::addMultipleArrayJoinAction(ActionsDAG & a result_columns.insert(result_source.first); } - return std::make_shared(result_columns, array_join_is_left, getContext()); + const auto & query_settings = getContext()->getSettingsRef(); + return std::make_shared(result_columns, array_join_is_left, query_settings.enable_unaligned_array_join, query_settings.max_block_size); } ArrayJoinActionPtr SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & chain, ActionsAndProjectInputsFlagPtr & before_array_join, bool only_types) diff --git a/src/Interpreters/InterpreterSelectQuery.cpp b/src/Interpreters/InterpreterSelectQuery.cpp index 0c79f4310ce..34e7d7422fd 100644 --- a/src/Interpreters/InterpreterSelectQuery.cpp +++ b/src/Interpreters/InterpreterSelectQuery.cpp @@ -86,6 +86,7 @@ #include #include #include +#include #include #include #include @@ -1655,7 +1656,12 @@ void InterpreterSelectQuery::executeImpl(QueryPlan & query_plan, std::optional

(query_plan.getCurrentDataStream(), expressions.array_join); + = std::make_unique( + query_plan.getCurrentDataStream(), + expressions.array_join->columns, + expressions.array_join->is_left, + expressions.array_join->is_unaligned, + expressions.array_join->max_block_size); array_join_step->setStepDescription("ARRAY JOIN"); query_plan.addStep(std::move(array_join_step)); diff --git a/src/Planner/PlannerJoinTree.cpp b/src/Planner/PlannerJoinTree.cpp index 212b57588ae..b98a17a4ce7 100644 --- a/src/Planner/PlannerJoinTree.cpp +++ b/src/Planner/PlannerJoinTree.cpp @@ -1720,8 +1720,14 @@ JoinTreeQueryPlan buildQueryPlanForArrayJoinNode(const QueryTreeNodePtr & array_ drop_unused_columns_before_array_join_transform_step->setStepDescription("DROP unused columns before ARRAY JOIN"); plan.addStep(std::move(drop_unused_columns_before_array_join_transform_step)); - auto array_join_action = std::make_shared(array_join_column_names, array_join_node.isLeft(), planner_context->getQueryContext()); - auto array_join_step = std::make_unique(plan.getCurrentDataStream(), std::move(array_join_action)); + const auto & settings = planner_context->getQueryContext()->getSettingsRef(); + auto array_join_step = std::make_unique( + plan.getCurrentDataStream(), + std::move(array_join_column_names), + array_join_node.isLeft(), + settings.enable_unaligned_array_join, + settings.max_block_size); + array_join_step->setStepDescription("ARRAY JOIN"); plan.addStep(std::move(array_join_step)); diff --git a/src/Processors/QueryPlan/ArrayJoinStep.cpp b/src/Processors/QueryPlan/ArrayJoinStep.cpp index 23a0a756f0d..aa721e138cf 100644 --- a/src/Processors/QueryPlan/ArrayJoinStep.cpp +++ b/src/Processors/QueryPlan/ArrayJoinStep.cpp @@ -24,23 +24,27 @@ static ITransformingStep::Traits getTraits() }; } -ArrayJoinStep::ArrayJoinStep(const DataStream & input_stream_, ArrayJoinActionPtr array_join_) +ArrayJoinStep::ArrayJoinStep(const DataStream & input_stream_, NameSet columns_, bool is_left_, bool is_unaligned_, size_t max_block_size_) : ITransformingStep( input_stream_, - ArrayJoinTransform::transformHeader(input_stream_.header, array_join_), + ArrayJoinTransform::transformHeader(input_stream_.header, columns_), getTraits()) - , array_join(std::move(array_join_)) + , columns(std::move(columns_)) + , is_left(is_left_) + , is_unaligned(is_unaligned_) + , max_block_size(max_block_size_) { } void ArrayJoinStep::updateOutputStream() { output_stream = createOutputStream( - input_streams.front(), ArrayJoinTransform::transformHeader(input_streams.front().header, array_join), getDataStreamTraits()); + input_streams.front(), ArrayJoinTransform::transformHeader(input_streams.front().header, columns), getDataStreamTraits()); } void ArrayJoinStep::transformPipeline(QueryPipelineBuilder & pipeline, const BuildQueryPipelineSettings &) { + auto array_join = std::make_shared(columns, is_left, is_unaligned, max_block_size); pipeline.addSimpleTransform([&](const Block & header, QueryPipelineBuilder::StreamType stream_type) { bool on_totals = stream_type == QueryPipelineBuilder::StreamType::Totals; @@ -53,8 +57,8 @@ void ArrayJoinStep::describeActions(FormatSettings & settings) const String prefix(settings.offset, ' '); bool first = true; - settings.out << prefix << (array_join->is_left ? "LEFT " : "") << "ARRAY JOIN "; - for (const auto & column : array_join->columns) + settings.out << prefix << (is_left ? "LEFT " : "") << "ARRAY JOIN "; + for (const auto & column : columns) { if (!first) settings.out << ", "; @@ -68,10 +72,10 @@ void ArrayJoinStep::describeActions(FormatSettings & settings) const void ArrayJoinStep::describeActions(JSONBuilder::JSONMap & map) const { - map.add("Left", array_join->is_left); + map.add("Left", is_left); auto columns_array = std::make_unique(); - for (const auto & column : array_join->columns) + for (const auto & column : columns) columns_array->add(column); map.add("Columns", std::move(columns_array)); diff --git a/src/Processors/QueryPlan/ArrayJoinStep.h b/src/Processors/QueryPlan/ArrayJoinStep.h index 2d9b2ebd0c8..3f2eacc3159 100644 --- a/src/Processors/QueryPlan/ArrayJoinStep.h +++ b/src/Processors/QueryPlan/ArrayJoinStep.h @@ -10,7 +10,7 @@ using ArrayJoinActionPtr = std::shared_ptr; class ArrayJoinStep : public ITransformingStep { public: - explicit ArrayJoinStep(const DataStream & input_stream_, ArrayJoinActionPtr array_join_); + ArrayJoinStep(const DataStream & input_stream_, NameSet columns_, bool is_left_, bool is_unaligned_, size_t max_block_size_); String getName() const override { return "ArrayJoin"; } void transformPipeline(QueryPipelineBuilder & pipeline, const BuildQueryPipelineSettings &) override; @@ -18,12 +18,16 @@ public: void describeActions(JSONBuilder::JSONMap & map) const override; void describeActions(FormatSettings & settings) const override; - const ArrayJoinActionPtr & arrayJoin() const { return array_join; } + const NameSet & getColumns() const { return columns; } + bool isLeft() const { return is_left; } private: void updateOutputStream() override; - ArrayJoinActionPtr array_join; + NameSet columns; + bool is_left = false; + bool is_unaligned = false; + size_t max_block_size = DEFAULT_BLOCK_SIZE; }; } diff --git a/src/Processors/QueryPlan/Optimizations/filterPushDown.cpp b/src/Processors/QueryPlan/Optimizations/filterPushDown.cpp index b71326ff75b..e4a292394f3 100644 --- a/src/Processors/QueryPlan/Optimizations/filterPushDown.cpp +++ b/src/Processors/QueryPlan/Optimizations/filterPushDown.cpp @@ -520,8 +520,7 @@ size_t tryPushDownFilter(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes if (auto * array_join = typeid_cast(child.get())) { - const auto & array_join_actions = array_join->arrayJoin(); - const auto & keys = array_join_actions->columns; + const auto & keys = array_join->getColumns(); const auto & array_join_header = array_join->getInputStreams().front().header; Names allowed_inputs; diff --git a/src/Processors/QueryPlan/Optimizations/liftUpArrayJoin.cpp b/src/Processors/QueryPlan/Optimizations/liftUpArrayJoin.cpp index 0d4f2330119..8866bb99cbe 100644 --- a/src/Processors/QueryPlan/Optimizations/liftUpArrayJoin.cpp +++ b/src/Processors/QueryPlan/Optimizations/liftUpArrayJoin.cpp @@ -24,11 +24,11 @@ size_t tryLiftUpArrayJoin(QueryPlan::Node * parent_node, QueryPlan::Nodes & node if (!(expression_step || filter_step) || !array_join_step) return 0; - const auto & array_join = array_join_step->arrayJoin(); + const auto & array_join_columns = array_join_step->getColumns(); const auto & expression = expression_step ? expression_step->getExpression() : filter_step->getExpression(); - auto split_actions = expression.splitActionsBeforeArrayJoin(array_join->columns); + auto split_actions = expression.splitActionsBeforeArrayJoin(array_join_columns); /// No actions can be moved before ARRAY JOIN. if (split_actions.first.trivial()) diff --git a/src/Processors/QueryPlan/Optimizations/optimizeReadInOrder.cpp b/src/Processors/QueryPlan/Optimizations/optimizeReadInOrder.cpp index 5df7d7b4e82..f517a5cd10d 100644 --- a/src/Processors/QueryPlan/Optimizations/optimizeReadInOrder.cpp +++ b/src/Processors/QueryPlan/Optimizations/optimizeReadInOrder.cpp @@ -230,10 +230,10 @@ void buildSortingDAG(QueryPlan::Node & node, std::optional & dag, Fi { /// Should ignore limit because ARRAY JOIN can reduce the number of rows in case of empty array. /// But in case of LEFT ARRAY JOIN the result number of rows is always bigger. - if (!array_join->arrayJoin()->is_left) + if (!array_join->isLeft()) limit = 0; - const auto & array_joined_columns = array_join->arrayJoin()->columns; + const auto & array_joined_columns = array_join->getColumns(); if (dag) { diff --git a/src/Processors/Transforms/ArrayJoinTransform.cpp b/src/Processors/Transforms/ArrayJoinTransform.cpp index 1304434d74e..bd436cbe408 100644 --- a/src/Processors/Transforms/ArrayJoinTransform.cpp +++ b/src/Processors/Transforms/ArrayJoinTransform.cpp @@ -10,10 +10,10 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; } -Block ArrayJoinTransform::transformHeader(Block header, const ArrayJoinActionPtr & array_join) +Block ArrayJoinTransform::transformHeader(Block header, const NameSet & array_join_columns) { auto columns = header.getColumnsWithTypeAndName(); - array_join->prepare(columns); + ArrayJoinAction::prepare(array_join_columns, columns); Block res{std::move(columns)}; res.setColumns(res.mutateColumns()); return res; @@ -23,7 +23,7 @@ ArrayJoinTransform::ArrayJoinTransform( const Block & header_, ArrayJoinActionPtr array_join_, bool /*on_totals_*/) - : IInflatingTransform(header_, transformHeader(header_, array_join_)) + : IInflatingTransform(header_, transformHeader(header_, array_join_->columns)) , array_join(std::move(array_join_)) { /// TODO diff --git a/src/Processors/Transforms/ArrayJoinTransform.h b/src/Processors/Transforms/ArrayJoinTransform.h index 4219135982d..386b9d6616b 100644 --- a/src/Processors/Transforms/ArrayJoinTransform.h +++ b/src/Processors/Transforms/ArrayJoinTransform.h @@ -22,7 +22,7 @@ public: String getName() const override { return "ArrayJoinTransform"; } - static Block transformHeader(Block header, const ArrayJoinActionPtr & array_join); + static Block transformHeader(Block header, const NameSet & array_join_columns); protected: void consume(Chunk chunk) override;