Refactor joinGet and implement multi-key lookup.

This commit is contained in:
Amos Bird 2020-07-11 15:12:42 +08:00
parent ba7c33f806
commit 230938d3a3
No known key found for this signature in database
GPG Key ID: 80D430DCBECFEDB4
9 changed files with 104 additions and 95 deletions

View File

@ -1,10 +1,10 @@
#include <Functions/FunctionJoinGet.h> #include <Functions/FunctionJoinGet.h>
#include <Columns/ColumnString.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/HashJoin.h> #include <Interpreters/HashJoin.h>
#include <Columns/ColumnString.h>
#include <Storages/StorageJoin.h> #include <Storages/StorageJoin.h>
@ -16,19 +16,35 @@ namespace ErrorCodes
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
} }
/// Gathers every argument from the third one onwards into a block of keys,
/// asks the join for the matching column and stores it at the result position.
template <bool or_null>
void ExecutableFunctionJoinGet<or_null>::execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t)
{
    /// The first two arguments are not keys, so the scan starts at index 2.
    constexpr size_t first_key_argument = 2;

    Block key_block;
    for (size_t pos = first_key_argument; pos < arguments.size(); ++pos)
    {
        /// Copy the column description out of the source block, then hand the copy over.
        auto key_column = block.getByPosition(arguments[pos]);
        key_block.insert(std::move(key_column));
    }

    block.getByPosition(result) = join->joinGet(key_block, result_block);
}
/// Builds the executable function. It receives the join together with a
/// one-column block that has an empty column of the return type, named attr_name.
template <bool or_null>
ExecutableFunctionImplPtr FunctionJoinGet<or_null>::prepare(const Block &, const ColumnNumbers &, size_t) const
{
    Block sample_result{{return_type->createColumn(), return_type, attr_name}};

    ExecutableFunctionImplPtr executable = std::make_unique<ExecutableFunctionJoinGet<or_null>>(join, sample_result);
    return executable;
}
static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & context) static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & context)
{ {
if (arguments.size() != 3)
throw Exception{"Function joinGet takes 3 arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
String join_name; String join_name;
if (const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get())) if (const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get()))
{ {
join_name = name_col->getValue<String>(); join_name = name_col->getValue<String>();
} }
else else
throw Exception{"Illegal type " + arguments[0].type->getName() + " of first argument of function joinGet, expected a const string.", throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; "Illegal type " + arguments[0].type->getName() + " of first argument of function joinGet, expected a const string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
size_t dot = join_name.find('.'); size_t dot = join_name.find('.');
String database_name; String database_name;
@ -43,10 +59,12 @@ static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & co
++dot; ++dot;
} }
String table_name = join_name.substr(dot); String table_name = join_name.substr(dot);
if (table_name.empty())
throw Exception("joinGet does not allow empty table name", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
auto table = DatabaseCatalog::instance().getTable({database_name, table_name}, context); auto table = DatabaseCatalog::instance().getTable({database_name, table_name}, context);
auto storage_join = std::dynamic_pointer_cast<StorageJoin>(table); auto storage_join = std::dynamic_pointer_cast<StorageJoin>(table);
if (!storage_join) if (!storage_join)
throw Exception{"Table " + join_name + " should have engine StorageJoin", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; throw Exception("Table " + join_name + " should have engine StorageJoin", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
String attr_name; String attr_name;
if (const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[1].column.get())) if (const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[1].column.get()))
@ -54,57 +72,30 @@ static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & co
attr_name = name_col->getValue<String>(); attr_name = name_col->getValue<String>();
} }
else else
throw Exception{"Illegal type " + arguments[1].type->getName() throw Exception(
+ " of second argument of function joinGet, expected a const string.", "Illegal type " + arguments[1].type->getName() + " of second argument of function joinGet, expected a const string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_pair(storage_join, attr_name); return std::make_pair(storage_join, attr_name);
} }
template <bool or_null> template <bool or_null>
FunctionBaseImplPtr JoinGetOverloadResolver<or_null>::build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const FunctionBaseImplPtr JoinGetOverloadResolver<or_null>::build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const
{ {
if (arguments.size() < 3)
throw Exception(
"Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size())
+ ", should be greater or equal to 3",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
auto [storage_join, attr_name] = getJoin(arguments, context); auto [storage_join, attr_name] = getJoin(arguments, context);
auto join = storage_join->getJoin(); auto join = storage_join->getJoin();
DataTypes data_types(arguments.size()); DataTypes data_types(arguments.size() - 2);
for (size_t i = 2; i < arguments.size(); ++i)
data_types[i - 2] = arguments[i].type;
auto return_type = join->joinGetCheckAndGetReturnType(data_types, attr_name, or_null);
auto table_lock = storage_join->lockForShare(context.getInitialQueryId(), context.getSettingsRef().lock_acquire_timeout); auto table_lock = storage_join->lockForShare(context.getInitialQueryId(), context.getSettingsRef().lock_acquire_timeout);
for (size_t i = 0; i < arguments.size(); ++i)
data_types[i] = arguments[i].type;
auto return_type = join->joinGetReturnType(attr_name, or_null);
return std::make_unique<FunctionJoinGet<or_null>>(table_lock, storage_join, join, attr_name, data_types, return_type); return std::make_unique<FunctionJoinGet<or_null>>(table_lock, storage_join, join, attr_name, data_types, return_type);
} }
/// Resolves the join storage and attribute name from the arguments and returns
/// the type the join reports for that attribute.
template <bool or_null>
DataTypePtr JoinGetOverloadResolver<or_null>::getReturnType(const ColumnsWithTypeAndName & arguments) const
{
    auto [join_storage, attribute] = getJoin(arguments, context);
    auto hash_join = join_storage->getJoin();

    DataTypePtr resolved_type = hash_join->joinGetReturnType(attribute, or_null);
    return resolved_type;
}
/// Looks up the single key column (the third argument) in the join and writes
/// the column read back from the lookup block to the result position.
template <bool or_null>
void ExecutableFunctionJoinGet<or_null>::execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
{
    auto key_column = block.getByPosition(arguments[2]);

    /// A constant key is shrunk to one row before the lookup.
    if (isColumnConst(*key_column.column))
        key_column.column = key_column.column->cloneResized(1);

    /// An empty name makes sure the key never collides with the join columns.
    key_column.name = "";

    Block lookup_block = {key_column};
    join->joinGet(lookup_block, attr_name, or_null);

    /// The result is read from position 1, i.e. right after the key column.
    auto & looked_up = lookup_block.getByPosition(1);

    /// For a constant key the looked-up column is wrapped into a constant of input_rows_count rows.
    if (isColumnConst(*key_column.column))
        looked_up.column = ColumnConst::create(looked_up.column, input_rows_count);

    block.getByPosition(result) = looked_up;
}
/// Builds the executable function from the stored join and attribute name.
template <bool or_null>
ExecutableFunctionImplPtr FunctionJoinGet<or_null>::prepare(const Block &, const ColumnNumbers &, size_t) const
{
    ExecutableFunctionImplPtr executable = std::make_unique<ExecutableFunctionJoinGet<or_null>>(join, attr_name);
    return executable;
}
void registerFunctionJoinGet(FunctionFactory & factory) void registerFunctionJoinGet(FunctionFactory & factory)
{ {
// joinGet // joinGet

View File

@ -13,14 +13,14 @@ template <bool or_null>
class ExecutableFunctionJoinGet final : public IExecutableFunctionImpl class ExecutableFunctionJoinGet final : public IExecutableFunctionImpl
{ {
public: public:
ExecutableFunctionJoinGet(HashJoinPtr join_, String attr_name_) ExecutableFunctionJoinGet(HashJoinPtr join_, const Block & result_block_)
: join(std::move(join_)), attr_name(std::move(attr_name_)) {} : join(std::move(join_)), result_block(result_block_) {}
static constexpr auto name = or_null ? "joinGetOrNull" : "joinGet"; static constexpr auto name = or_null ? "joinGetOrNull" : "joinGet";
bool useDefaultImplementationForNulls() const override { return false; } bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
bool useDefaultImplementationForLowCardinalityColumns() const override { return true; } bool useDefaultImplementationForLowCardinalityColumns() const override { return true; }
bool useDefaultImplementationForConstants() const override { return true; }
void execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override; void execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
@ -28,7 +28,7 @@ public:
private: private:
HashJoinPtr join; HashJoinPtr join;
const String attr_name; Block result_block;
}; };
template <bool or_null> template <bool or_null>
@ -77,13 +77,14 @@ public:
String getName() const override { return name; } String getName() const override { return name; }
FunctionBaseImplPtr build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const override; FunctionBaseImplPtr build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const override;
DataTypePtr getReturnType(const ColumnsWithTypeAndName & arguments) const override; DataTypePtr getReturnType(const ColumnsWithTypeAndName &) const override { return {}; } // Not used
bool useDefaultImplementationForNulls() const override { return false; } bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForLowCardinalityColumns() const override { return true; } bool useDefaultImplementationForLowCardinalityColumns() const override { return true; }
bool isVariadic() const override { return true; } bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; } size_t getNumberOfArguments() const override { return 0; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0, 1}; }
private: private:
const Context & context; const Context & context;

View File

@ -42,6 +42,7 @@ namespace ErrorCodes
extern const int SYNTAX_ERROR; extern const int SYNTAX_ERROR;
extern const int SET_SIZE_LIMIT_EXCEEDED; extern const int SET_SIZE_LIMIT_EXCEEDED;
extern const int TYPE_MISMATCH; extern const int TYPE_MISMATCH;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
} }
namespace namespace
@ -1109,27 +1110,34 @@ void HashJoin::joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed)
block = block.cloneWithColumns(std::move(dst_columns)); block = block.cloneWithColumns(std::move(dst_columns));
} }
static void checkTypeOfKey(const Block & block_left, const Block & block_right)
{
const auto & [c1, left_type_origin, left_name] = block_left.safeGetByPosition(0);
const auto & [c2, right_type_origin, right_name] = block_right.safeGetByPosition(0);
auto left_type = removeNullable(left_type_origin);
auto right_type = removeNullable(right_type_origin);
if (!left_type->equals(*right_type)) DataTypePtr HashJoin::joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const
throw Exception("Type mismatch of columns to joinGet by: "
+ left_name + " " + left_type->getName() + " at left, "
+ right_name + " " + right_type->getName() + " at right",
ErrorCodes::TYPE_MISMATCH);
}
DataTypePtr HashJoin::joinGetReturnType(const String & column_name, bool or_null) const
{ {
std::shared_lock lock(data->rwlock); std::shared_lock lock(data->rwlock);
size_t num_keys = data_types.size();
if (right_table_keys.columns() != num_keys)
throw Exception(
"Number of arguments for function joinGet" + toString(or_null ? "OrNull" : "")
+ " doesn't match: passed, should be equal to " + toString(num_keys),
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (size_t i = 0; i < num_keys; ++i)
{
const auto & left_type_origin = data_types[i];
const auto & [c2, right_type_origin, right_name] = right_table_keys.safeGetByPosition(i);
auto left_type = removeNullable(left_type_origin);
auto right_type = removeNullable(right_type_origin);
if (!left_type->equals(*right_type))
throw Exception(
"Type mismatch in joinGet key " + toString(i) + ": found type " + left_type->getName() + ", while the needed type is "
+ right_type->getName(),
ErrorCodes::TYPE_MISMATCH);
}
if (!sample_block_with_columns_to_add.has(column_name)) if (!sample_block_with_columns_to_add.has(column_name))
throw Exception("StorageJoin doesn't contain column " + column_name, ErrorCodes::NO_SUCH_COLUMN_IN_TABLE); throw Exception("StorageJoin doesn't contain column " + column_name, ErrorCodes::NO_SUCH_COLUMN_IN_TABLE);
auto elem = sample_block_with_columns_to_add.getByName(column_name); auto elem = sample_block_with_columns_to_add.getByName(column_name);
if (or_null) if (or_null)
elem.type = makeNullable(elem.type); elem.type = makeNullable(elem.type);
@ -1138,34 +1146,33 @@ DataTypePtr HashJoin::joinGetReturnType(const String & column_name, bool or_null
template <typename Maps> template <typename Maps>
void HashJoin::joinGetImpl(Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const ColumnWithTypeAndName HashJoin::joinGetImpl(const Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const
{ {
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::RightAny>( // Assemble the key block with correct names.
block, {block.getByPosition(0).name}, block_with_columns_to_add, maps_); Block keys;
for (size_t i = 0; i < block.columns(); ++i)
{
auto key = block.getByPosition(i);
key.name = key_names_right[i];
keys.insert(std::move(key));
}
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(
keys, key_names_right, block_with_columns_to_add, maps_);
return keys.getByPosition(keys.columns() - 1);
} }
// TODO: support composite key
// TODO: return multiple columns as named tuple // TODO: return multiple columns as named tuple
// TODO: return array of values when strictness == ASTTableJoin::Strictness::All // TODO: return array of values when strictness == ASTTableJoin::Strictness::All
void HashJoin::joinGet(Block & block, const String & column_name, bool or_null) const ColumnWithTypeAndName HashJoin::joinGet(const Block & block, const Block & block_with_columns_to_add) const
{ {
std::shared_lock lock(data->rwlock); std::shared_lock lock(data->rwlock);
if (key_names_right.size() != 1)
throw Exception("joinGet only supports StorageJoin containing exactly one key", ErrorCodes::UNSUPPORTED_JOIN_KEYS);
checkTypeOfKey(block, right_table_keys);
auto elem = sample_block_with_columns_to_add.getByName(column_name);
if (or_null)
elem.type = makeNullable(elem.type);
elem.column = elem.type->createColumn();
if ((strictness == ASTTableJoin::Strictness::Any || strictness == ASTTableJoin::Strictness::RightAny) && if ((strictness == ASTTableJoin::Strictness::Any || strictness == ASTTableJoin::Strictness::RightAny) &&
kind == ASTTableJoin::Kind::Left) kind == ASTTableJoin::Kind::Left)
{ {
joinGetImpl(block, {elem}, std::get<MapsOne>(data->maps)); return joinGetImpl(block, block_with_columns_to_add, std::get<MapsOne>(data->maps));
} }
else else
throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN); throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN);

View File

@ -162,11 +162,11 @@ public:
*/ */
void joinBlock(Block & block, ExtraBlockPtr & not_processed) override; void joinBlock(Block & block, ExtraBlockPtr & not_processed) override;
/// Infer the return type for joinGet function /// Check joinGet arguments and infer the return type.
DataTypePtr joinGetReturnType(const String & column_name, bool or_null) const; DataTypePtr joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const;
/// Used by joinGet function that turns StorageJoin into a dictionary /// Used by joinGet function that turns StorageJoin into a dictionary.
void joinGet(Block & block, const String & column_name, bool or_null) const; ColumnWithTypeAndName joinGet(const Block & block, const Block & block_with_columns_to_add) const;
/** Keep "totals" (separate part of dataset, see WITH TOTALS) to use later. /** Keep "totals" (separate part of dataset, see WITH TOTALS) to use later.
*/ */
@ -383,7 +383,7 @@ private:
void joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed) const; void joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed) const;
template <typename Maps> template <typename Maps>
void joinGetImpl(Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const; ColumnWithTypeAndName joinGetImpl(const Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const;
static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes); static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes);
}; };

View File

@ -28,7 +28,7 @@ inline bool functionIsLikeOperator(const std::string & name)
inline bool functionIsJoinGet(const std::string & name) inline bool functionIsJoinGet(const std::string & name)
{ {
return name == "joinGet" || startsWith(name, "dictGet"); return startsWith(name, "joinGet");
} }
inline bool functionIsDictGet(const std::string & name) inline bool functionIsDictGet(const std::string & name)

View File

@ -1,12 +1,12 @@
DROP TABLE IF EXISTS test_joinGet; DROP TABLE IF EXISTS test_joinGet;
DROP TABLE IF EXISTS test_join_joinGet;
CREATE TABLE test_joinGet(id Int32, user_id Nullable(Int32)) Engine = Memory(); CREATE TABLE test_joinGet(user_id Nullable(Int32), name String) Engine = Join(ANY, LEFT, user_id);
CREATE TABLE test_join_joinGet(user_id Int32, name String) Engine = Join(ANY, LEFT, user_id);
INSERT INTO test_join_joinGet VALUES (2, 'a'), (6, 'b'), (10, 'c'); INSERT INTO test_joinGet VALUES (2, 'a'), (6, 'b'), (10, 'c'), (null, 'd');
SELECT 2 id, toNullable(toInt32(2)) user_id WHERE joinGet(test_join_joinGet, 'name', user_id) != ''; SELECT toNullable(toInt32(2)) user_id WHERE joinGet(test_joinGet, 'name', user_id) != '';
-- If the JOIN keys are Nullable fields, the rows where at least one of the keys has the value NULL are not joined.
SELECT cast(null AS Nullable(Int32)) user_id WHERE joinGet(test_joinGet, 'name', user_id) != '';
DROP TABLE test_joinGet; DROP TABLE test_joinGet;
DROP TABLE test_join_joinGet;

View File

@ -0,0 +1 @@
0.1

View File

@ -0,0 +1,9 @@
-- Exercises joinGet on a Join-engine table whose key consists of two columns (a, b).
DROP TABLE IF EXISTS test_joinGet;
-- The key is (a, b); c is the attribute fetched through joinGet.
CREATE TABLE test_joinGet(a String, b String, c Float64) ENGINE = Join(any, left, a, b);
-- Two rows share a = 'ab', so the lookup below has to use both key columns.
INSERT INTO test_joinGet VALUES ('ab', '1', 0.1), ('ab', '2', 0.2), ('cd', '3', 0.3);
-- Fetch c for the key ('ab', '1'); the row inserted for that key has c = 0.1.
SELECT joinGet(test_joinGet, 'c', 'ab', '1');
DROP TABLE test_joinGet;