Merge pull request #69298 from ClickHouse/array-join-step-refactoring

Refactor ArrayJoin step.
This commit is contained in:
Nikolai Kochetov 2024-09-17 08:26:09 +00:00 committed by GitHub
commit e7eaa01bb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 360 additions and 261 deletions

View File

@ -27,7 +27,8 @@ class FunctionStringOrArrayToT : public IFunction
{
public:
static constexpr auto name = Name::name;
static FunctionPtr create(ContextPtr)
static FunctionPtr create(ContextPtr) { return createImpl(); }
static FunctionPtr createImpl()
{
return std::make_shared<FunctionStringOrArrayToT>();
}

View File

@ -1,4 +1,4 @@
#include <Functions/IFunction.h>
#include <Functions/array/arrayResize.h>
#include <Functions/FunctionFactory.h>
#include <Functions/GatherUtils/GatherUtils.h>
#include <DataTypes/DataTypeArray.h>
@ -21,117 +21,99 @@ namespace ErrorCodes
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
class FunctionArrayResize : public IFunction
DataTypePtr FunctionArrayResize::getReturnTypeImpl(const DataTypes & arguments) const
{
public:
static constexpr auto name = "arrayResize";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayResize>(); }
const size_t number_of_arguments = arguments.size();
String getName() const override { return name; }
if (number_of_arguments < 2 || number_of_arguments > 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Number of arguments for function {} doesn't match: passed {}, should be 2 or 3",
getName(), number_of_arguments);
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
if (arguments[0]->onlyNull())
return arguments[0];
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
const auto * array_type = typeid_cast<const DataTypeArray *>(arguments[0].get());
if (!array_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"First argument for function {} must be an array but it has type {}.",
getName(), arguments[0]->getName());
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
if (WhichDataType(array_type->getNestedType()).isNothing())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} cannot resize {}", getName(), array_type->getName());
if (!isInteger(removeNullable(arguments[1])) && !arguments[1]->onlyNull())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Argument {} for function {} must be integer but it has type {}.",
toString(1), getName(), arguments[1]->getName());
if (number_of_arguments == 2)
return arguments[0];
else /* if (number_of_arguments == 3) */
return std::make_shared<DataTypeArray>(getLeastSupertype(DataTypes{array_type->getNestedType(), arguments[2]}));
}
ColumnPtr FunctionArrayResize::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, size_t input_rows_count) const
{
if (return_type->onlyNull())
return return_type->createColumnConstWithDefaultValue(input_rows_count);
auto result_column = return_type->createColumn();
auto array_column = arguments[0].column;
auto size_column = arguments[1].column;
if (!arguments[0].type->equals(*return_type))
array_column = castColumn(arguments[0], return_type);
const DataTypePtr & return_nested_type = typeid_cast<const DataTypeArray &>(*return_type).getNestedType();
size_t size = array_column->size();
ColumnPtr appended_column;
if (arguments.size() == 3)
{
const size_t number_of_arguments = arguments.size();
appended_column = arguments[2].column;
if (!arguments[2].type->equals(*return_nested_type))
appended_column = castColumn(arguments[2], return_nested_type);
}
else
appended_column = return_nested_type->createColumnConstWithDefaultValue(size);
if (number_of_arguments < 2 || number_of_arguments > 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Number of arguments for function {} doesn't match: passed {}, should be 2 or 3",
getName(), number_of_arguments);
std::unique_ptr<GatherUtils::IArraySource> array_source;
std::unique_ptr<GatherUtils::IValueSource> value_source;
if (arguments[0]->onlyNull())
return arguments[0];
bool is_const = false;
const auto * array_type = typeid_cast<const DataTypeArray *>(arguments[0].get());
if (!array_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"First argument for function {} must be an array but it has type {}.",
getName(), arguments[0]->getName());
if (WhichDataType(array_type->getNestedType()).isNothing())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} cannot resize {}", getName(), array_type->getName());
if (!isInteger(removeNullable(arguments[1])) && !arguments[1]->onlyNull())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Argument {} for function {} must be integer but it has type {}.",
toString(1), getName(), arguments[1]->getName());
if (number_of_arguments == 2)
return arguments[0];
else /* if (number_of_arguments == 3) */
return std::make_shared<DataTypeArray>(getLeastSupertype(DataTypes{array_type->getNestedType(), arguments[2]}));
if (const auto * const_array_column = typeid_cast<const ColumnConst *>(array_column.get()))
{
is_const = true;
array_column = const_array_column->getDataColumnPtr();
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, size_t input_rows_count) const override
if (const auto * argument_column_array = typeid_cast<const ColumnArray *>(array_column.get()))
array_source = GatherUtils::createArraySource(*argument_column_array, is_const, size);
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "First arguments for function {} must be array.", getName());
bool is_appended_const = false;
if (const auto * const_appended_column = typeid_cast<const ColumnConst *>(appended_column.get()))
{
if (return_type->onlyNull())
return return_type->createColumnConstWithDefaultValue(input_rows_count);
auto result_column = return_type->createColumn();
auto array_column = arguments[0].column;
auto size_column = arguments[1].column;
if (!arguments[0].type->equals(*return_type))
array_column = castColumn(arguments[0], return_type);
const DataTypePtr & return_nested_type = typeid_cast<const DataTypeArray &>(*return_type).getNestedType();
size_t size = array_column->size();
ColumnPtr appended_column;
if (arguments.size() == 3)
{
appended_column = arguments[2].column;
if (!arguments[2].type->equals(*return_nested_type))
appended_column = castColumn(arguments[2], return_nested_type);
}
else
appended_column = return_nested_type->createColumnConstWithDefaultValue(size);
std::unique_ptr<GatherUtils::IArraySource> array_source;
std::unique_ptr<GatherUtils::IValueSource> value_source;
bool is_const = false;
if (const auto * const_array_column = typeid_cast<const ColumnConst *>(array_column.get()))
{
is_const = true;
array_column = const_array_column->getDataColumnPtr();
}
if (const auto * argument_column_array = typeid_cast<const ColumnArray *>(array_column.get()))
array_source = GatherUtils::createArraySource(*argument_column_array, is_const, size);
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "First arguments for function {} must be array.", getName());
bool is_appended_const = false;
if (const auto * const_appended_column = typeid_cast<const ColumnConst *>(appended_column.get()))
{
is_appended_const = true;
appended_column = const_appended_column->getDataColumnPtr();
}
value_source = GatherUtils::createValueSource(*appended_column, is_appended_const, size);
auto sink = GatherUtils::createArraySink(typeid_cast<ColumnArray &>(*result_column), size);
if (isColumnConst(*size_column))
GatherUtils::resizeConstantSize(*array_source, *value_source, *sink, size_column->getInt(0));
else
GatherUtils::resizeDynamicSize(*array_source, *value_source, *sink, *size_column);
return result_column;
is_appended_const = true;
appended_column = const_appended_column->getDataColumnPtr();
}
bool useDefaultImplementationForConstants() const override { return true; }
bool useDefaultImplementationForNulls() const override { return false; }
};
value_source = GatherUtils::createValueSource(*appended_column, is_appended_const, size);
auto sink = GatherUtils::createArraySink(typeid_cast<ColumnArray &>(*result_column), size);
if (isColumnConst(*size_column))
GatherUtils::resizeConstantSize(*array_source, *value_source, *sink, size_column->getInt(0));
else
GatherUtils::resizeDynamicSize(*array_source, *value_source, *sink, *size_column);
return result_column;
}
REGISTER_FUNCTION(ArrayResize)
{

View File

@ -0,0 +1,28 @@
#pragma once
#include <Functions/IFunction.h>
#include <Interpreters/Context_fwd.h>
namespace DB
{
class FunctionArrayResize : public IFunction
{
public:
static constexpr auto name = "arrayResize";
static FunctionPtr createImpl() { return std::make_shared<FunctionArrayResize>(); }
static FunctionPtr create(ContextPtr) { return createImpl(); }
String getName() const override { return name; }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, size_t input_rows_count) const override;
bool useDefaultImplementationForConstants() const override { return true; }
bool useDefaultImplementationForNulls() const override { return false; }
};
}

View File

@ -1,4 +1,4 @@
#include <Functions/IFunction.h>
#include <Functions/array/emptyArrayToSingle.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeArray.h>
@ -20,35 +20,6 @@ namespace ErrorCodes
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/** emptyArrayToSingle(arr) - replace empty arrays with arrays of one element with a default value.
*/
class FunctionEmptyArrayToSingle : public IFunction
{
public:
static constexpr auto name = "emptyArrayToSingle";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionEmptyArrayToSingle>(); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 1; }
bool useDefaultImplementationForConstants() const override { return true; }
bool useDefaultImplementationForLowCardinalityColumns() const override { return false; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get());
if (!array_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument for function {} must be array.", getName());
return arguments[0];
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
};
namespace
{
namespace FunctionEmptyArrayToSingleImpl
@ -366,6 +337,14 @@ namespace
}
}
DataTypePtr FunctionEmptyArrayToSingle::getReturnTypeImpl(const DataTypes & arguments) const
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get());
if (!array_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument for function {} must be array.", getName());
return arguments[0];
}
ColumnPtr FunctionEmptyArrayToSingle::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
{

View File

@ -0,0 +1,29 @@
#pragma once
#include <Functions/IFunction.h>
#include <Interpreters/Context_fwd.h>
namespace DB
{
/** emptyArrayToSingle(arr) - replace empty arrays with arrays of one element with a default value.
*/
class FunctionEmptyArrayToSingle : public IFunction
{
public:
static constexpr auto name = "emptyArrayToSingle";
static FunctionPtr createImpl() { return std::make_shared<FunctionEmptyArrayToSingle>(); }
static FunctionPtr create(ContextPtr) { return createImpl(); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 1; }
bool useDefaultImplementationForConstants() const override { return true; }
bool useDefaultImplementationForLowCardinalityColumns() const override { return false; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
};
}

View File

@ -1,65 +1,7 @@
#include <DataTypes/DataTypeString.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionStringOrArrayToT.h>
#include <Functions/array/length.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/** Calculates the length of a string in bytes.
*/
struct LengthImpl
{
static constexpr auto is_fixed_to_constant = true;
static void vector(const ColumnString::Chars & /*data*/, const ColumnString::Offsets & offsets, PaddedPODArray<UInt64> & res, size_t input_rows_count)
{
for (size_t i = 0; i < input_rows_count; ++i)
res[i] = offsets[i] - 1 - offsets[i - 1];
}
static void vectorFixedToConstant(const ColumnString::Chars & /*data*/, size_t n, UInt64 & res, size_t)
{
res = n;
}
static void vectorFixedToVector(const ColumnString::Chars & /*data*/, size_t /*n*/, PaddedPODArray<UInt64> & /*res*/, size_t)
{
}
static void array(const ColumnString::Offsets & offsets, PaddedPODArray<UInt64> & res, size_t input_rows_count)
{
for (size_t i = 0; i < input_rows_count; ++i)
res[i] = offsets[i] - offsets[i - 1];
}
[[noreturn]] static void uuid(const ColumnUUID::Container &, size_t &, PaddedPODArray<UInt64> &, size_t)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Cannot apply function length to UUID argument");
}
[[noreturn]] static void ipv6(const ColumnIPv6::Container &, size_t &, PaddedPODArray<UInt64> &, size_t)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Cannot apply function length to IPv6 argument");
}
[[noreturn]] static void ipv4(const ColumnIPv4::Container &, size_t &, PaddedPODArray<UInt64> &, size_t)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Cannot apply function length to IPv4 argument");
}
};
struct NameLength
{
static constexpr auto name = "length";
};
using FunctionLength = FunctionStringOrArrayToT<LengthImpl, NameLength, UInt64, false>;
REGISTER_FUNCTION(Length)
{

View File

@ -0,0 +1,66 @@
#pragma once
#include <DataTypes/DataTypeString.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionStringOrArrayToT.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/** Calculates the length of a string in bytes.
*/
struct LengthImpl
{
static constexpr auto is_fixed_to_constant = true;
static void vector(const ColumnString::Chars & /*data*/, const ColumnString::Offsets & offsets, PaddedPODArray<UInt64> & res, size_t input_rows_count)
{
for (size_t i = 0; i < input_rows_count; ++i)
res[i] = offsets[i] - 1 - offsets[i - 1];
}
static void vectorFixedToConstant(const ColumnString::Chars & /*data*/, size_t n, UInt64 & res, size_t)
{
res = n;
}
static void vectorFixedToVector(const ColumnString::Chars & /*data*/, size_t /*n*/, PaddedPODArray<UInt64> & /*res*/, size_t)
{
}
static void array(const ColumnString::Offsets & offsets, PaddedPODArray<UInt64> & res, size_t input_rows_count)
{
for (size_t i = 0; i < input_rows_count; ++i)
res[i] = offsets[i] - offsets[i - 1];
}
[[noreturn]] static void uuid(const ColumnUUID::Container &, size_t &, PaddedPODArray<UInt64> &, size_t)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Cannot apply function length to UUID argument");
}
[[noreturn]] static void ipv6(const ColumnIPv6::Container &, size_t &, PaddedPODArray<UInt64> &, size_t)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Cannot apply function length to IPv6 argument");
}
[[noreturn]] static void ipv4(const ColumnIPv4::Container &, size_t &, PaddedPODArray<UInt64> &, size_t)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Cannot apply function length to IPv4 argument");
}
};
struct NameLength
{
static constexpr auto name = "length";
};
using FunctionLength = FunctionStringOrArrayToT<LengthImpl, NameLength, UInt64, false>;
}

View File

@ -2028,8 +2028,9 @@ ActionsDAG::SplitResult ActionsDAG::split(std::unordered_set<const Node *> split
return {std::move(first_actions), std::move(second_actions), std::move(split_nodes_mapping)};
}
ActionsDAG::SplitResult ActionsDAG::splitActionsBeforeArrayJoin(const NameSet & array_joined_columns) const
ActionsDAG::SplitResult ActionsDAG::splitActionsBeforeArrayJoin(const Names & array_joined_columns) const
{
std::unordered_set<std::string_view> array_joined_columns_set(array_joined_columns.begin(), array_joined_columns.end());
struct Frame
{
const Node * node = nullptr;
@ -2072,7 +2073,7 @@ ActionsDAG::SplitResult ActionsDAG::splitActionsBeforeArrayJoin(const NameSet &
if (cur.next_child_to_visit == cur.node->children.size())
{
bool depend_on_array_join = false;
if (cur.node->type == ActionType::INPUT && array_joined_columns.contains(cur.node->result_name))
if (cur.node->type == ActionType::INPUT && array_joined_columns_set.contains(cur.node->result_name))
depend_on_array_join = true;
for (const auto * child : cur.node->children)

View File

@ -340,7 +340,7 @@ public:
SplitResult split(std::unordered_set<const Node *> split_nodes, bool create_split_nodes_mapping = false, bool avoid_duplicate_inputs = false) const;
/// Splits actions into two parts. Returned first half may be swapped with ARRAY JOIN.
SplitResult splitActionsBeforeArrayJoin(const NameSet & array_joined_columns) const;
SplitResult splitActionsBeforeArrayJoin(const Names & array_joined_columns) const;
/// Splits actions into two parts. First part has minimal size sufficient for calculation of column_name.
/// Outputs of initial actions must contain column_name.

View File

@ -0,0 +1,13 @@
#pragma once
#include <Core/Names.h>
namespace DB
{
struct ArrayJoin
{
Names columns;
bool is_left = false;
};
}

View File

@ -6,6 +6,9 @@
#include <Columns/ColumnMap.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/array/length.h>
#include <Functions/array/arrayResize.h>
#include <Functions/array/emptyArrayToSingle.h>
#include <Interpreters/Context.h>
#include <Interpreters/ArrayJoinAction.h>
@ -59,26 +62,31 @@ ColumnWithTypeAndName convertArrayJoinColumn(const ColumnWithTypeAndName & src_c
return array_col;
}
ArrayJoinAction::ArrayJoinAction(const NameSet & array_joined_columns_, bool array_join_is_left, ContextPtr context)
: columns(array_joined_columns_)
, is_left(array_join_is_left)
, is_unaligned(context->getSettingsRef().enable_unaligned_array_join)
, max_block_size(context->getSettingsRef().max_block_size)
ArrayJoinAction::ArrayJoinAction(const Names & columns_, bool is_left_, bool is_unaligned_, size_t max_block_size_)
: columns(columns_.begin(), columns_.end())
, is_left(is_left_)
, is_unaligned(is_unaligned_)
, max_block_size(max_block_size_)
{
if (columns.empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "No arrays to join");
if (is_unaligned)
{
function_length = FunctionFactory::instance().get("length", context);
function_greatest = FunctionFactory::instance().get("greatest", context);
function_array_resize = FunctionFactory::instance().get("arrayResize", context);
function_length = std::make_unique<FunctionToOverloadResolverAdaptor>(FunctionLength::createImpl());
function_array_resize = std::make_unique<FunctionToOverloadResolverAdaptor>(FunctionArrayResize::createImpl());
}
else if (is_left)
function_builder = FunctionFactory::instance().get("emptyArrayToSingle", context);
function_builder = std::make_unique<FunctionToOverloadResolverAdaptor>(FunctionEmptyArrayToSingle::createImpl());
}
void ArrayJoinAction::prepare(ColumnsWithTypeAndName & sample) const
void ArrayJoinAction::prepare(const Names & columns, ColumnsWithTypeAndName & sample)
{
NameSet columns_set(columns.begin(), columns.end());
prepare(columns_set, sample);
}
void ArrayJoinAction::prepare(const NameSet & columns, ColumnsWithTypeAndName & sample)
{
for (auto & current : sample)
{
@ -103,6 +111,35 @@ ArrayJoinResultIteratorPtr ArrayJoinAction::execute(Block block)
return std::make_unique<ArrayJoinResultIterator>(this, std::move(block));
}
static void updateMaxLength(ColumnUInt64 & max_length, UInt64 length)
{
for (auto & value : max_length.getData())
value = std::max(value, length);
}
static void updateMaxLength(ColumnUInt64 & max_length, const IColumn & length)
{
if (const auto * length_const = typeid_cast<const ColumnConst *>(&length))
{
updateMaxLength(max_length, length_const->getUInt(0));
return;
}
const auto * length_uint64 = typeid_cast<const ColumnUInt64 *>(&length);
if (!length_uint64)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected UInt64 for array length, got {}", length.getName());
auto & max_lenght_data = max_length.getData();
const auto & length_data = length_uint64->getData();
size_t num_rows = max_lenght_data.size();
if (num_rows != length_data.size())
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"Different columns sizes in ARRAY JOIN: {} and {}", num_rows, length_data.size());
for (size_t row = 0; row < num_rows; ++row)
max_lenght_data[row] = std::max(max_lenght_data[row], length_data[row]);
}
ArrayJoinResultIterator::ArrayJoinResultIterator(const ArrayJoinAction * array_join_, Block block_)
: array_join(array_join_), block(std::move(block_)), total_rows(block.rows()), current_row(0)
@ -111,7 +148,6 @@ ArrayJoinResultIterator::ArrayJoinResultIterator(const ArrayJoinAction * array_j
bool is_unaligned = array_join->is_unaligned;
bool is_left = array_join->is_left;
const auto & function_length = array_join->function_length;
const auto & function_greatest = array_join->function_greatest;
const auto & function_array_resize = array_join->function_array_resize;
const auto & function_builder = array_join->function_builder;
@ -125,11 +161,7 @@ ArrayJoinResultIterator::ArrayJoinResultIterator(const ArrayJoinAction * array_j
/// Resize all array joined columns to the longest one, (at least 1 if LEFT ARRAY JOIN), padded with default values.
auto rows = block.rows();
auto uint64 = std::make_shared<DataTypeUInt64>();
ColumnWithTypeAndName column_of_max_length{{}, uint64, {}};
if (is_left)
column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 1u), uint64, {});
else
column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 0u), uint64, {});
auto max_length = ColumnUInt64::create(rows, (is_left ? 1u : 0u));
for (const auto & name : columns)
{
@ -138,11 +170,10 @@ ArrayJoinResultIterator::ArrayJoinResultIterator(const ArrayJoinAction * array_j
ColumnWithTypeAndName array_col = convertArrayJoinColumn(src_col);
ColumnsWithTypeAndName tmp_block{array_col}; //, {{}, uint64, {}}};
auto len_col = function_length->build(tmp_block)->execute(tmp_block, uint64, rows);
ColumnsWithTypeAndName tmp_block2{column_of_max_length, {len_col, uint64, {}}};
column_of_max_length.column = function_greatest->build(tmp_block2)->execute(tmp_block2, uint64, rows);
updateMaxLength(*max_length, *len_col);
}
ColumnWithTypeAndName column_of_max_length{std::move(max_length), uint64, {}};
for (const auto & name : columns)
{
auto & src_col = block.getByName(name);

View File

@ -33,14 +33,14 @@ public:
/// For unaligned [LEFT] ARRAY JOIN
FunctionOverloadResolverPtr function_length;
FunctionOverloadResolverPtr function_greatest;
FunctionOverloadResolverPtr function_array_resize;
/// For LEFT ARRAY JOIN.
FunctionOverloadResolverPtr function_builder;
ArrayJoinAction(const NameSet & array_joined_columns_, bool array_join_is_left, ContextPtr context);
void prepare(ColumnsWithTypeAndName & sample) const;
ArrayJoinAction(const Names & columns_, bool is_left_, bool is_unaligned_, size_t max_block_size_);
static void prepare(const NameSet & columns, ColumnsWithTypeAndName & sample);
static void prepare(const Names & columns, ColumnsWithTypeAndName & sample);
ArrayJoinResultIteratorPtr execute(Block block);
};

View File

@ -1059,16 +1059,16 @@ std::string ExpressionActionsChain::dumpChain() const
return ss.str();
}
ExpressionActionsChain::ArrayJoinStep::ArrayJoinStep(ArrayJoinActionPtr array_join_, ColumnsWithTypeAndName required_columns_)
ExpressionActionsChain::ArrayJoinStep::ArrayJoinStep(const Names & array_join_columns_, ColumnsWithTypeAndName required_columns_)
: Step({})
, array_join(std::move(array_join_))
, array_join_columns(array_join_columns_.begin(), array_join_columns_.end())
, result_columns(std::move(required_columns_))
{
for (auto & column : result_columns)
{
required_columns.emplace_back(NameAndTypePair(column.name, column.type));
if (array_join->columns.contains(column.name))
if (array_join_columns.contains(column.name))
{
const auto & array = getArrayJoinDataType(column.type);
column.type = array->getNestedType();
@ -1085,12 +1085,12 @@ void ExpressionActionsChain::ArrayJoinStep::finalize(const NameSet & required_ou
for (const auto & column : result_columns)
{
if (array_join->columns.contains(column.name) || required_output_.contains(column.name))
if (array_join_columns.contains(column.name) || required_output_.contains(column.name))
new_result_columns.emplace_back(column);
}
for (const auto & column : required_columns)
{
if (array_join->columns.contains(column.name) || required_output_.contains(column.name))
if (array_join_columns.contains(column.name) || required_output_.contains(column.name))
new_required_columns.emplace_back(column);
}

View File

@ -3,6 +3,7 @@
#include <Core/Block.h>
#include <Core/ColumnNumbers.h>
#include <Interpreters/ActionsDAG.h>
#include <Interpreters/ArrayJoin.h>
#include <Interpreters/ExpressionActionsSettings.h>
#include <variant>
@ -22,9 +23,6 @@ class TableJoin;
class IJoin;
using JoinPtr = std::shared_ptr<IJoin>;
class ArrayJoinAction;
using ArrayJoinActionPtr = std::shared_ptr<ArrayJoinAction>;
class ExpressionActions;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
@ -223,11 +221,11 @@ struct ExpressionActionsChain : WithContext
struct ArrayJoinStep : public Step
{
ArrayJoinActionPtr array_join;
const NameSet array_join_columns;
NamesAndTypesList required_columns;
ColumnsWithTypeAndName result_columns;
ArrayJoinStep(ArrayJoinActionPtr array_join_, ColumnsWithTypeAndName required_columns_);
ArrayJoinStep(const Names & array_join_columns_, ColumnsWithTypeAndName required_columns_);
NamesAndTypesList getRequiredColumns() const override { return required_columns; }
ColumnsWithTypeAndName getResultColumns() const override { return result_columns; }

View File

@ -215,7 +215,7 @@ NamesAndTypesList ExpressionAnalyzer::getColumnsAfterArrayJoin(ActionsDAG & acti
auto array_join = addMultipleArrayJoinAction(actions, is_array_join_left);
auto sample_columns = actions.getResultColumns();
array_join->prepare(sample_columns);
ArrayJoinAction::prepare(array_join.columns, sample_columns);
actions = ActionsDAG(sample_columns);
NamesAndTypesList new_columns_after_array_join;
@ -889,9 +889,11 @@ const ASTSelectQuery * SelectQueryExpressionAnalyzer::getAggregatingQuery() cons
}
/// "Big" ARRAY JOIN.
ArrayJoinActionPtr ExpressionAnalyzer::addMultipleArrayJoinAction(ActionsDAG & actions, bool array_join_is_left) const
ArrayJoin ExpressionAnalyzer::addMultipleArrayJoinAction(ActionsDAG & actions, bool array_join_is_left) const
{
NameSet result_columns;
Names result_columns;
result_columns.reserve(syntax->array_join_result_to_source.size());
for (const auto & result_source : syntax->array_join_result_to_source)
{
/// Assign new names to columns, if needed.
@ -902,19 +904,19 @@ ArrayJoinActionPtr ExpressionAnalyzer::addMultipleArrayJoinAction(ActionsDAG & a
}
/// Make ARRAY JOIN (replace arrays with their insides) for the columns in these new names.
result_columns.insert(result_source.first);
result_columns.push_back(result_source.first);
}
return std::make_shared<ArrayJoinAction>(result_columns, array_join_is_left, getContext());
return {std::move(result_columns), array_join_is_left};
}
ArrayJoinActionPtr SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & chain, ActionsAndProjectInputsFlagPtr & before_array_join, bool only_types)
std::optional<ArrayJoin> SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & chain, ActionsAndProjectInputsFlagPtr & before_array_join, bool only_types)
{
const auto * select_query = getSelectQuery();
auto [array_join_expression_list, is_array_join_left] = select_query->arrayJoinExpressionList();
if (!array_join_expression_list)
return nullptr;
return {};
ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns());
@ -923,7 +925,7 @@ ArrayJoinActionPtr SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActi
auto array_join = addMultipleArrayJoinAction(step.actions()->dag, is_array_join_left);
before_array_join = chain.getLastActions();
chain.steps.push_back(std::make_unique<ExpressionActionsChain::ArrayJoinStep>(array_join, step.getResultColumns()));
chain.steps.push_back(std::make_unique<ExpressionActionsChain::ArrayJoinStep>(array_join.columns, step.getResultColumns()));
chain.addStep();

View File

@ -174,7 +174,7 @@ protected:
/// Find global subqueries in the GLOBAL IN/JOIN sections. Fills in external_tables.
void initGlobalSubqueriesAndExternalTables(bool do_global, bool is_explain);
ArrayJoinActionPtr addMultipleArrayJoinAction(ActionsDAG & actions, bool is_left) const;
ArrayJoin addMultipleArrayJoinAction(ActionsDAG & actions, bool is_left) const;
void getRootActions(const ASTPtr & ast, bool no_makeset_for_subqueries, ActionsDAG & actions, bool only_consts = false);
@ -234,7 +234,7 @@ struct ExpressionAnalysisResult
bool use_grouping_set_key = false;
ActionsAndProjectInputsFlagPtr before_array_join;
ArrayJoinActionPtr array_join;
std::optional<ArrayJoin> array_join;
ActionsAndProjectInputsFlagPtr before_join;
ActionsAndProjectInputsFlagPtr converting_join_columns;
JoinPtr join;
@ -388,7 +388,7 @@ private:
*/
/// Before aggregation:
ArrayJoinActionPtr appendArrayJoin(ExpressionActionsChain & chain, ActionsAndProjectInputsFlagPtr & before_array_join, bool only_types);
std::optional<ArrayJoin> appendArrayJoin(ExpressionActionsChain & chain, ActionsAndProjectInputsFlagPtr & before_array_join, bool only_types);
bool appendJoinLeftKeys(ExpressionActionsChain & chain, bool only_types);
JoinPtr appendJoin(ExpressionActionsChain & chain, ActionsAndProjectInputsFlagPtr & converting_join_columns);

View File

@ -86,6 +86,7 @@
#include <Core/Settings.h>
#include <Core/ServerSettings.h>
#include <Interpreters/Aggregator.h>
#include <Interpreters/ArrayJoinAction.h>
#include <Interpreters/HashTablesStatistics.h>
#include <Interpreters/IJoin.h>
#include <QueryPipeline/SizeLimits.h>
@ -1676,7 +1677,11 @@ void InterpreterSelectQuery::executeImpl(QueryPlan & query_plan, std::optional<P
if (expressions.array_join)
{
QueryPlanStepPtr array_join_step
= std::make_unique<ArrayJoinStep>(query_plan.getCurrentDataStream(), expressions.array_join);
= std::make_unique<ArrayJoinStep>(
query_plan.getCurrentDataStream(),
*expressions.array_join,
settings.enable_unaligned_array_join,
settings.max_block_size);
array_join_step->setStepDescription("ARRAY JOIN");
query_plan.addStep(std::move(array_join_step));

View File

@ -1674,11 +1674,12 @@ JoinTreeQueryPlan buildQueryPlanForArrayJoinNode(const QueryTreeNodePtr & array_
PlannerActionsVisitor actions_visitor(planner_context);
std::unordered_set<std::string> array_join_expressions_output_nodes;
NameSet array_join_column_names;
Names array_join_column_names;
array_join_column_names.reserve(array_join_node.getJoinExpressions().getNodes().size());
for (auto & array_join_expression : array_join_node.getJoinExpressions().getNodes())
{
const auto & array_join_column_identifier = planner_context->getColumnNodeIdentifierOrThrow(array_join_expression);
array_join_column_names.insert(array_join_column_identifier);
array_join_column_names.push_back(array_join_column_identifier);
auto & array_join_expression_column = array_join_expression->as<ColumnNode &>();
auto expression_dag_index_nodes = actions_visitor.visit(array_join_action_dag, array_join_expression_column.getExpressionOrThrow());
@ -1727,8 +1728,13 @@ JoinTreeQueryPlan buildQueryPlanForArrayJoinNode(const QueryTreeNodePtr & array_
drop_unused_columns_before_array_join_transform_step->setStepDescription("DROP unused columns before ARRAY JOIN");
plan.addStep(std::move(drop_unused_columns_before_array_join_transform_step));
auto array_join_action = std::make_shared<ArrayJoinAction>(array_join_column_names, array_join_node.isLeft(), planner_context->getQueryContext());
auto array_join_step = std::make_unique<ArrayJoinStep>(plan.getCurrentDataStream(), std::move(array_join_action));
const auto & settings = planner_context->getQueryContext()->getSettingsRef();
auto array_join_step = std::make_unique<ArrayJoinStep>(
plan.getCurrentDataStream(),
ArrayJoin{std::move(array_join_column_names), array_join_node.isLeft()},
settings.enable_unaligned_array_join,
settings.max_block_size);
array_join_step->setStepDescription("ARRAY JOIN");
plan.addStep(std::move(array_join_step));

View File

@ -24,27 +24,30 @@ static ITransformingStep::Traits getTraits()
};
}
ArrayJoinStep::ArrayJoinStep(const DataStream & input_stream_, ArrayJoinActionPtr array_join_)
ArrayJoinStep::ArrayJoinStep(const DataStream & input_stream_, ArrayJoin array_join_, bool is_unaligned_, size_t max_block_size_)
: ITransformingStep(
input_stream_,
ArrayJoinTransform::transformHeader(input_stream_.header, array_join_),
ArrayJoinTransform::transformHeader(input_stream_.header, array_join_.columns),
getTraits())
, array_join(std::move(array_join_))
, is_unaligned(is_unaligned_)
, max_block_size(max_block_size_)
{
}
void ArrayJoinStep::updateOutputStream()
{
output_stream = createOutputStream(
input_streams.front(), ArrayJoinTransform::transformHeader(input_streams.front().header, array_join), getDataStreamTraits());
input_streams.front(), ArrayJoinTransform::transformHeader(input_streams.front().header, array_join.columns), getDataStreamTraits());
}
void ArrayJoinStep::transformPipeline(QueryPipelineBuilder & pipeline, const BuildQueryPipelineSettings &)
{
auto array_join_actions = std::make_shared<ArrayJoinAction>(array_join.columns, array_join.is_left, is_unaligned, max_block_size);
pipeline.addSimpleTransform([&](const Block & header, QueryPipelineBuilder::StreamType stream_type)
{
bool on_totals = stream_type == QueryPipelineBuilder::StreamType::Totals;
return std::make_shared<ArrayJoinTransform>(header, array_join, on_totals);
return std::make_shared<ArrayJoinTransform>(header, array_join_actions, on_totals);
});
}
@ -53,8 +56,8 @@ void ArrayJoinStep::describeActions(FormatSettings & settings) const
String prefix(settings.offset, ' ');
bool first = true;
settings.out << prefix << (array_join->is_left ? "LEFT " : "") << "ARRAY JOIN ";
for (const auto & column : array_join->columns)
settings.out << prefix << (array_join.is_left ? "LEFT " : "") << "ARRAY JOIN ";
for (const auto & column : array_join.columns)
{
if (!first)
settings.out << ", ";
@ -68,10 +71,10 @@ void ArrayJoinStep::describeActions(FormatSettings & settings) const
void ArrayJoinStep::describeActions(JSONBuilder::JSONMap & map) const
{
map.add("Left", array_join->is_left);
map.add("Left", array_join.is_left);
auto columns_array = std::make_unique<JSONBuilder::JSONArray>();
for (const auto & column : array_join->columns)
for (const auto & column : array_join.columns)
columns_array->add(column);
map.add("Columns", std::move(columns_array));

View File

@ -1,5 +1,6 @@
#pragma once
#include <Processors/QueryPlan/ITransformingStep.h>
#include <Interpreters/ArrayJoin.h>
namespace DB
{
@ -10,7 +11,7 @@ using ArrayJoinActionPtr = std::shared_ptr<ArrayJoinAction>;
class ArrayJoinStep : public ITransformingStep
{
public:
explicit ArrayJoinStep(const DataStream & input_stream_, ArrayJoinActionPtr array_join_);
ArrayJoinStep(const DataStream & input_stream_, ArrayJoin array_join_, bool is_unaligned_, size_t max_block_size_);
String getName() const override { return "ArrayJoin"; }
void transformPipeline(QueryPipelineBuilder & pipeline, const BuildQueryPipelineSettings &) override;
@ -18,12 +19,15 @@ public:
void describeActions(JSONBuilder::JSONMap & map) const override;
void describeActions(FormatSettings & settings) const override;
const ArrayJoinActionPtr & arrayJoin() const { return array_join; }
const Names & getColumns() const { return array_join.columns; }
bool isLeft() const { return array_join.is_left; }
private:
void updateOutputStream() override;
ArrayJoinActionPtr array_join;
ArrayJoin array_join;
bool is_unaligned = false;
size_t max_block_size = DEFAULT_BLOCK_SIZE;
};
}

View File

@ -520,13 +520,14 @@ size_t tryPushDownFilter(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes
if (auto * array_join = typeid_cast<ArrayJoinStep *>(child.get()))
{
const auto & array_join_actions = array_join->arrayJoin();
const auto & keys = array_join_actions->columns;
const auto & keys = array_join->getColumns();
std::unordered_set<std::string_view> keys_set(keys.begin(), keys.end());
const auto & array_join_header = array_join->getInputStreams().front().header;
Names allowed_inputs;
for (const auto & column : array_join_header)
if (!keys.contains(column.name))
if (!keys_set.contains(column.name))
allowed_inputs.push_back(column.name);
if (auto updated_steps = tryAddNewFilterStep(parent_node, nodes, allowed_inputs))

View File

@ -24,11 +24,11 @@ size_t tryLiftUpArrayJoin(QueryPlan::Node * parent_node, QueryPlan::Nodes & node
if (!(expression_step || filter_step) || !array_join_step)
return 0;
const auto & array_join = array_join_step->arrayJoin();
const auto & array_join_columns = array_join_step->getColumns();
const auto & expression = expression_step ? expression_step->getExpression()
: filter_step->getExpression();
auto split_actions = expression.splitActionsBeforeArrayJoin(array_join->columns);
auto split_actions = expression.splitActionsBeforeArrayJoin(array_join_columns);
/// No actions can be moved before ARRAY JOIN.
if (split_actions.first.trivial())

View File

@ -231,13 +231,15 @@ void buildSortingDAG(QueryPlan::Node & node, std::optional<ActionsDAG> & dag, Fi
{
/// Should ignore limit because ARRAY JOIN can reduce the number of rows in case of empty array.
/// But in case of LEFT ARRAY JOIN the result number of rows is always bigger.
if (!array_join->arrayJoin()->is_left)
if (!array_join->isLeft())
limit = 0;
const auto & array_joined_columns = array_join->arrayJoin()->columns;
const auto & array_joined_columns = array_join->getColumns();
if (dag)
{
std::unordered_set<std::string_view> keys_set(array_joined_columns.begin(), array_joined_columns.end());
/// Remove array joined columns from outputs.
/// Types are changed after ARRAY JOIN, and we can't use this columns anyway.
ActionsDAG::NodeRawConstPtrs outputs;
@ -245,7 +247,7 @@ void buildSortingDAG(QueryPlan::Node & node, std::optional<ActionsDAG> & dag, Fi
for (const auto & output : dag->getOutputs())
{
if (!array_joined_columns.contains(output->result_name))
if (!keys_set.contains(output->result_name))
outputs.push_back(output);
}

View File

@ -10,20 +10,26 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}
Block ArrayJoinTransform::transformHeader(Block header, const ArrayJoinActionPtr & array_join)
template <typename Container>
Block transformHeaderImpl(Block header, const Container & array_join_columns)
{
auto columns = header.getColumnsWithTypeAndName();
array_join->prepare(columns);
ArrayJoinAction::prepare(array_join_columns, columns);
Block res{std::move(columns)};
res.setColumns(res.mutateColumns());
return res;
}
Block ArrayJoinTransform::transformHeader(Block header, const Names & array_join_columns)
{
return transformHeaderImpl(std::move(header), array_join_columns);
}
ArrayJoinTransform::ArrayJoinTransform(
const Block & header_,
ArrayJoinActionPtr array_join_,
bool /*on_totals_*/)
: IInflatingTransform(header_, transformHeader(header_, array_join_))
: IInflatingTransform(header_, transformHeaderImpl(header_, array_join_->columns))
, array_join(std::move(array_join_))
{
/// TODO

View File

@ -22,7 +22,7 @@ public:
String getName() const override { return "ArrayJoinTransform"; }
static Block transformHeader(Block header, const ArrayJoinActionPtr & array_join);
static Block transformHeader(Block header, const Names & array_join_columns);
protected:
void consume(Chunk chunk) override;