From 230938d3a3082fbf241c9d873571231a69a5f450 Mon Sep 17 00:00:00 2001 From: Amos Bird Date: Sat, 11 Jul 2020 15:12:42 +0800 Subject: [PATCH] Refactor joinGet and implement multi-key lookup. --- src/Functions/FunctionJoinGet.cpp | 83 +++++++++---------- src/Functions/FunctionJoinGet.h | 11 +-- src/Interpreters/HashJoin.cpp | 69 ++++++++------- src/Interpreters/HashJoin.h | 10 +-- src/Interpreters/misc.h | 2 +- .../0_stateless/01080_join_get_null.reference | 2 +- .../0_stateless/01080_join_get_null.sql | 12 +-- .../01400_join_get_with_multi_keys.reference | 1 + .../01400_join_get_with_multi_keys.sql | 9 ++ 9 files changed, 104 insertions(+), 95 deletions(-) create mode 100644 tests/queries/0_stateless/01400_join_get_with_multi_keys.reference create mode 100644 tests/queries/0_stateless/01400_join_get_with_multi_keys.sql diff --git a/src/Functions/FunctionJoinGet.cpp b/src/Functions/FunctionJoinGet.cpp index a33b70684a5..1badc689c6a 100644 --- a/src/Functions/FunctionJoinGet.cpp +++ b/src/Functions/FunctionJoinGet.cpp @@ -1,10 +1,10 @@ #include +#include #include #include #include #include -#include #include @@ -16,19 +16,35 @@ namespace ErrorCodes extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } +template +void ExecutableFunctionJoinGet::execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t) +{ + Block keys; + for (size_t i = 2; i < arguments.size(); ++i) + { + auto key = block.getByPosition(arguments[i]); + keys.insert(std::move(key)); + } + block.getByPosition(result) = join->joinGet(keys, result_block); +} + +template +ExecutableFunctionImplPtr FunctionJoinGet::prepare(const Block &, const ColumnNumbers &, size_t) const +{ + return std::make_unique>(join, Block{{return_type->createColumn(), return_type, attr_name}}); +} + static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & context) { - if (arguments.size() != 3) - throw Exception{"Function joinGet takes 3 arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH}; - String join_name; if (const auto * name_col = checkAndGetColumnConst(arguments[0].column.get())) { join_name = name_col->getValue(); } else - throw Exception{"Illegal type " + arguments[0].type->getName() + " of first argument of function joinGet, expected a const string.", - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + throw Exception( + "Illegal type " + arguments[0].type->getName() + " of first argument of function joinGet, expected a const string.", + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); size_t dot = join_name.find('.'); String database_name; @@ -43,10 +59,12 @@ static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & co ++dot; } String table_name = join_name.substr(dot); + if (table_name.empty()) + throw Exception("joinGet does not allow empty table name", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); auto table = DatabaseCatalog::instance().getTable({database_name, table_name}, context); auto storage_join = std::dynamic_pointer_cast(table); if (!storage_join) - throw Exception{"Table " + join_name + " should have engine StorageJoin", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + throw Exception("Table " + join_name + " should have engine StorageJoin", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); String attr_name; if (const auto * name_col = checkAndGetColumnConst(arguments[1].column.get())) @@ -54,57 +72,30 @@ static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & co attr_name = name_col->getValue(); } else - throw Exception{"Illegal type " + arguments[1].type->getName() - + " of second argument of function joinGet, expected a const string.", - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + throw Exception( + "Illegal type " + arguments[1].type->getName() + " of second argument of function joinGet, expected a const string.", + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); return std::make_pair(storage_join, attr_name); } template FunctionBaseImplPtr JoinGetOverloadResolver::build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const { + if (arguments.size() < 3) + throw Exception( + "Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size()) + + ", should be greater or equal to 3", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); auto [storage_join, attr_name] = getJoin(arguments, context); auto join = storage_join->getJoin(); - DataTypes data_types(arguments.size()); - + DataTypes data_types(arguments.size() - 2); + for (size_t i = 2; i < arguments.size(); ++i) + data_types[i - 2] = arguments[i].type; + auto return_type = join->joinGetCheckAndGetReturnType(data_types, attr_name, or_null); auto table_lock = storage_join->lockForShare(context.getInitialQueryId(), context.getSettingsRef().lock_acquire_timeout); - for (size_t i = 0; i < arguments.size(); ++i) - data_types[i] = arguments[i].type; - - auto return_type = join->joinGetReturnType(attr_name, or_null); return std::make_unique>(table_lock, storage_join, join, attr_name, data_types, return_type); } -template -DataTypePtr JoinGetOverloadResolver::getReturnType(const ColumnsWithTypeAndName & arguments) const -{ - auto [storage_join, attr_name] = getJoin(arguments, context); - auto join = storage_join->getJoin(); - return join->joinGetReturnType(attr_name, or_null); -} - - -template -void ExecutableFunctionJoinGet::execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) -{ - auto ctn = block.getByPosition(arguments[2]); - if (isColumnConst(*ctn.column)) - ctn.column = ctn.column->cloneResized(1); - ctn.name = ""; // make sure the key name never collide with the join columns - Block key_block = {ctn}; - join->joinGet(key_block, attr_name, or_null); - auto & result_ctn = key_block.getByPosition(1); - if (isColumnConst(*ctn.column)) - result_ctn.column = ColumnConst::create(result_ctn.column, input_rows_count); - block.getByPosition(result) = result_ctn; -} - -template -ExecutableFunctionImplPtr FunctionJoinGet::prepare(const Block &, const ColumnNumbers &, size_t) const -{ - return std::make_unique>(join, attr_name); -} - void registerFunctionJoinGet(FunctionFactory & factory) { // joinGet diff --git a/src/Functions/FunctionJoinGet.h b/src/Functions/FunctionJoinGet.h index a82da589960..6b3b1202f60 100644 --- a/src/Functions/FunctionJoinGet.h +++ b/src/Functions/FunctionJoinGet.h @@ -13,14 +13,14 @@ template class ExecutableFunctionJoinGet final : public IExecutableFunctionImpl { public: - ExecutableFunctionJoinGet(HashJoinPtr join_, String attr_name_) - : join(std::move(join_)), attr_name(std::move(attr_name_)) {} + ExecutableFunctionJoinGet(HashJoinPtr join_, const Block & result_block_) + : join(std::move(join_)), result_block(result_block_) {} static constexpr auto name = or_null ? "joinGetOrNull" : "joinGet"; bool useDefaultImplementationForNulls() const override { return false; } - bool useDefaultImplementationForConstants() const override { return true; } bool useDefaultImplementationForLowCardinalityColumns() const override { return true; } + bool useDefaultImplementationForConstants() const override { return true; } void execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override; @@ -28,7 +28,7 @@ public: private: HashJoinPtr join; - const String attr_name; + Block result_block; }; template @@ -77,13 +77,14 @@ public: String getName() const override { return name; } FunctionBaseImplPtr build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const override; - DataTypePtr getReturnType(const ColumnsWithTypeAndName & arguments) const override; + DataTypePtr getReturnType(const ColumnsWithTypeAndName &) const override { return {}; } // Not used bool useDefaultImplementationForNulls() const override { return false; } bool useDefaultImplementationForLowCardinalityColumns() const override { return true; } bool isVariadic() const override { return true; } size_t getNumberOfArguments() const override { return 0; } + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0, 1}; } private: const Context & context; diff --git a/src/Interpreters/HashJoin.cpp b/src/Interpreters/HashJoin.cpp index 27294a57675..ffc806b9e88 100644 --- a/src/Interpreters/HashJoin.cpp +++ b/src/Interpreters/HashJoin.cpp @@ -42,6 +42,7 @@ namespace ErrorCodes extern const int SYNTAX_ERROR; extern const int SET_SIZE_LIMIT_EXCEEDED; extern const int TYPE_MISMATCH; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } namespace @@ -1109,27 +1110,34 @@ void HashJoin::joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed) block = block.cloneWithColumns(std::move(dst_columns)); } -static void checkTypeOfKey(const Block & block_left, const Block & block_right) -{ - const auto & [c1, left_type_origin, left_name] = block_left.safeGetByPosition(0); - const auto & [c2, right_type_origin, right_name] = block_right.safeGetByPosition(0); - auto left_type = removeNullable(left_type_origin); - auto right_type = removeNullable(right_type_origin); - if (!left_type->equals(*right_type)) - throw Exception("Type mismatch of columns to joinGet by: " - + left_name + " " + left_type->getName() + " at left, " - + right_name + " " + right_type->getName() + " at right", - ErrorCodes::TYPE_MISMATCH); -} - - -DataTypePtr HashJoin::joinGetReturnType(const String & column_name, bool or_null) const +DataTypePtr HashJoin::joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const { std::shared_lock lock(data->rwlock); + size_t num_keys = data_types.size(); + if (right_table_keys.columns() != num_keys) + throw Exception( + "Number of arguments for function joinGet" + toString(or_null ? "OrNull" : "") + + " doesn't match: passed, should be equal to " + toString(num_keys), + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + for (size_t i = 0; i < num_keys; ++i) + { + const auto & left_type_origin = data_types[i]; + const auto & [c2, right_type_origin, right_name] = right_table_keys.safeGetByPosition(i); + auto left_type = removeNullable(left_type_origin); + auto right_type = removeNullable(right_type_origin); + if (!left_type->equals(*right_type)) + throw Exception( + "Type mismatch in joinGet key " + toString(i) + ": found type " + left_type->getName() + ", while the needed type is " + + right_type->getName(), + ErrorCodes::TYPE_MISMATCH); + } + if (!sample_block_with_columns_to_add.has(column_name)) throw Exception("StorageJoin doesn't contain column " + column_name, ErrorCodes::NO_SUCH_COLUMN_IN_TABLE); + auto elem = sample_block_with_columns_to_add.getByName(column_name); if (or_null) elem.type = makeNullable(elem.type); @@ -1138,34 +1146,33 @@ DataTypePtr HashJoin::joinGetReturnType(const String & column_name, bool or_null template -void HashJoin::joinGetImpl(Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const +ColumnWithTypeAndName HashJoin::joinGetImpl(const Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const { - joinBlockImpl( - block, {block.getByPosition(0).name}, block_with_columns_to_add, maps_); + // Assemble the key block with correct names. + Block keys; + for (size_t i = 0; i < block.columns(); ++i) + { + auto key = block.getByPosition(i); + key.name = key_names_right[i]; + keys.insert(std::move(key)); + } + + joinBlockImpl( + keys, key_names_right, block_with_columns_to_add, maps_); + return keys.getByPosition(keys.columns() - 1); } -// TODO: support composite key // TODO: return multiple columns as named tuple // TODO: return array of values when strictness == ASTTableJoin::Strictness::All -void HashJoin::joinGet(Block & block, const String & column_name, bool or_null) const +ColumnWithTypeAndName HashJoin::joinGet(const Block & block, const Block & block_with_columns_to_add) const { std::shared_lock lock(data->rwlock); - if (key_names_right.size() != 1) - throw Exception("joinGet only supports StorageJoin containing exactly one key", ErrorCodes::UNSUPPORTED_JOIN_KEYS); - - checkTypeOfKey(block, right_table_keys); - - auto elem = sample_block_with_columns_to_add.getByName(column_name); - if (or_null) - elem.type = makeNullable(elem.type); - elem.column = elem.type->createColumn(); - if ((strictness == ASTTableJoin::Strictness::Any || strictness == ASTTableJoin::Strictness::RightAny) && kind == ASTTableJoin::Kind::Left) { - joinGetImpl(block, {elem}, std::get(data->maps)); + return joinGetImpl(block, block_with_columns_to_add, std::get(data->maps)); } else throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN); diff --git a/src/Interpreters/HashJoin.h b/src/Interpreters/HashJoin.h index 67d83d27a6d..025f41ac28f 100644 --- a/src/Interpreters/HashJoin.h +++ b/src/Interpreters/HashJoin.h @@ -162,11 +162,11 @@ public: */ void joinBlock(Block & block, ExtraBlockPtr & not_processed) override; - /// Infer the return type for joinGet function - DataTypePtr joinGetReturnType(const String & column_name, bool or_null) const; + /// Check joinGet arguments and infer the return type. + DataTypePtr joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const; - /// Used by joinGet function that turns StorageJoin into a dictionary - void joinGet(Block & block, const String & column_name, bool or_null) const; + /// Used by joinGet function that turns StorageJoin into a dictionary. + ColumnWithTypeAndName joinGet(const Block & block, const Block & block_with_columns_to_add) const; /** Keep "totals" (separate part of dataset, see WITH TOTALS) to use later. */ @@ -383,7 +383,7 @@ private: void joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed) const; template - void joinGetImpl(Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const; + ColumnWithTypeAndName joinGetImpl(const Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const; static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes); }; diff --git a/src/Interpreters/misc.h b/src/Interpreters/misc.h index 094dfbbbb81..cae2691ca1f 100644 --- a/src/Interpreters/misc.h +++ b/src/Interpreters/misc.h @@ -28,7 +28,7 @@ inline bool functionIsLikeOperator(const std::string & name) inline bool functionIsJoinGet(const std::string & name) { - return name == "joinGet" || startsWith(name, "dictGet"); + return startsWith(name, "joinGet"); } inline bool functionIsDictGet(const std::string & name) diff --git a/tests/queries/0_stateless/01080_join_get_null.reference b/tests/queries/0_stateless/01080_join_get_null.reference index bfde072a796..0cfbf08886f 100644 --- a/tests/queries/0_stateless/01080_join_get_null.reference +++ b/tests/queries/0_stateless/01080_join_get_null.reference @@ -1 +1 @@ -2 2 +2 diff --git a/tests/queries/0_stateless/01080_join_get_null.sql b/tests/queries/0_stateless/01080_join_get_null.sql index 71e7ddf8e75..9f782452d34 100644 --- a/tests/queries/0_stateless/01080_join_get_null.sql +++ b/tests/queries/0_stateless/01080_join_get_null.sql @@ -1,12 +1,12 @@ DROP TABLE IF EXISTS test_joinGet; -DROP TABLE IF EXISTS test_join_joinGet; -CREATE TABLE test_joinGet(id Int32, user_id Nullable(Int32)) Engine = Memory(); -CREATE TABLE test_join_joinGet(user_id Int32, name String) Engine = Join(ANY, LEFT, user_id); +CREATE TABLE test_joinGet(user_id Nullable(Int32), name String) Engine = Join(ANY, LEFT, user_id); -INSERT INTO test_join_joinGet VALUES (2, 'a'), (6, 'b'), (10, 'c'); +INSERT INTO test_joinGet VALUES (2, 'a'), (6, 'b'), (10, 'c'), (null, 'd'); -SELECT 2 id, toNullable(toInt32(2)) user_id WHERE joinGet(test_join_joinGet, 'name', user_id) != ''; +SELECT toNullable(toInt32(2)) user_id WHERE joinGet(test_joinGet, 'name', user_id) != ''; + +-- If the JOIN keys are Nullable fields, the rows where at least one of the keys has the value NULL are not joined. +SELECT cast(null AS Nullable(Int32)) user_id WHERE joinGet(test_joinGet, 'name', user_id) != ''; DROP TABLE test_joinGet; -DROP TABLE test_join_joinGet; diff --git a/tests/queries/0_stateless/01400_join_get_with_multi_keys.reference b/tests/queries/0_stateless/01400_join_get_with_multi_keys.reference new file mode 100644 index 00000000000..49d59571fbf --- /dev/null +++ b/tests/queries/0_stateless/01400_join_get_with_multi_keys.reference @@ -0,0 +1 @@ +0.1 diff --git a/tests/queries/0_stateless/01400_join_get_with_multi_keys.sql b/tests/queries/0_stateless/01400_join_get_with_multi_keys.sql new file mode 100644 index 00000000000..73068270762 --- /dev/null +++ b/tests/queries/0_stateless/01400_join_get_with_multi_keys.sql @@ -0,0 +1,9 @@ +DROP TABLE IF EXISTS test_joinGet; + +CREATE TABLE test_joinGet(a String, b String, c Float64) ENGINE = Join(any, left, a, b); + +INSERT INTO test_joinGet VALUES ('ab', '1', 0.1), ('ab', '2', 0.2), ('cd', '3', 0.3); + +SELECT joinGet(test_joinGet, 'c', 'ab', '1'); + +DROP TABLE test_joinGet;