mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-17 20:02:05 +00:00
Refactor joinGet and implement multi-key lookup.
This commit is contained in:
parent
ba7c33f806
commit
230938d3a3
@ -1,10 +1,10 @@
|
||||
#include <Functions/FunctionJoinGet.h>
|
||||
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Interpreters/HashJoin.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Storages/StorageJoin.h>
|
||||
|
||||
|
||||
@ -16,19 +16,35 @@ namespace ErrorCodes
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
}
|
||||
|
||||
template <bool or_null>
|
||||
void ExecutableFunctionJoinGet<or_null>::execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t)
|
||||
{
|
||||
Block keys;
|
||||
for (size_t i = 2; i < arguments.size(); ++i)
|
||||
{
|
||||
auto key = block.getByPosition(arguments[i]);
|
||||
keys.insert(std::move(key));
|
||||
}
|
||||
block.getByPosition(result) = join->joinGet(keys, result_block);
|
||||
}
|
||||
|
||||
template <bool or_null>
|
||||
ExecutableFunctionImplPtr FunctionJoinGet<or_null>::prepare(const Block &, const ColumnNumbers &, size_t) const
|
||||
{
|
||||
return std::make_unique<ExecutableFunctionJoinGet<or_null>>(join, Block{{return_type->createColumn(), return_type, attr_name}});
|
||||
}
|
||||
|
||||
static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & context)
|
||||
{
|
||||
if (arguments.size() != 3)
|
||||
throw Exception{"Function joinGet takes 3 arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
|
||||
|
||||
String join_name;
|
||||
if (const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get()))
|
||||
{
|
||||
join_name = name_col->getValue<String>();
|
||||
}
|
||||
else
|
||||
throw Exception{"Illegal type " + arguments[0].type->getName() + " of first argument of function joinGet, expected a const string.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
throw Exception(
|
||||
"Illegal type " + arguments[0].type->getName() + " of first argument of function joinGet, expected a const string.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
size_t dot = join_name.find('.');
|
||||
String database_name;
|
||||
@ -43,10 +59,12 @@ static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & co
|
||||
++dot;
|
||||
}
|
||||
String table_name = join_name.substr(dot);
|
||||
if (table_name.empty())
|
||||
throw Exception("joinGet does not allow empty table name", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
auto table = DatabaseCatalog::instance().getTable({database_name, table_name}, context);
|
||||
auto storage_join = std::dynamic_pointer_cast<StorageJoin>(table);
|
||||
if (!storage_join)
|
||||
throw Exception{"Table " + join_name + " should have engine StorageJoin", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
throw Exception("Table " + join_name + " should have engine StorageJoin", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
String attr_name;
|
||||
if (const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[1].column.get()))
|
||||
@ -54,57 +72,30 @@ static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & co
|
||||
attr_name = name_col->getValue<String>();
|
||||
}
|
||||
else
|
||||
throw Exception{"Illegal type " + arguments[1].type->getName()
|
||||
+ " of second argument of function joinGet, expected a const string.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
throw Exception(
|
||||
"Illegal type " + arguments[1].type->getName() + " of second argument of function joinGet, expected a const string.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
return std::make_pair(storage_join, attr_name);
|
||||
}
|
||||
|
||||
template <bool or_null>
|
||||
FunctionBaseImplPtr JoinGetOverloadResolver<or_null>::build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const
|
||||
{
|
||||
if (arguments.size() < 3)
|
||||
throw Exception(
|
||||
"Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size())
|
||||
+ ", should be greater or equal to 3",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
auto [storage_join, attr_name] = getJoin(arguments, context);
|
||||
auto join = storage_join->getJoin();
|
||||
DataTypes data_types(arguments.size());
|
||||
|
||||
DataTypes data_types(arguments.size() - 2);
|
||||
for (size_t i = 2; i < arguments.size(); ++i)
|
||||
data_types[i - 2] = arguments[i].type;
|
||||
auto return_type = join->joinGetCheckAndGetReturnType(data_types, attr_name, or_null);
|
||||
auto table_lock = storage_join->lockForShare(context.getInitialQueryId(), context.getSettingsRef().lock_acquire_timeout);
|
||||
for (size_t i = 0; i < arguments.size(); ++i)
|
||||
data_types[i] = arguments[i].type;
|
||||
|
||||
auto return_type = join->joinGetReturnType(attr_name, or_null);
|
||||
return std::make_unique<FunctionJoinGet<or_null>>(table_lock, storage_join, join, attr_name, data_types, return_type);
|
||||
}
|
||||
|
||||
template <bool or_null>
|
||||
DataTypePtr JoinGetOverloadResolver<or_null>::getReturnType(const ColumnsWithTypeAndName & arguments) const
|
||||
{
|
||||
auto [storage_join, attr_name] = getJoin(arguments, context);
|
||||
auto join = storage_join->getJoin();
|
||||
return join->joinGetReturnType(attr_name, or_null);
|
||||
}
|
||||
|
||||
|
||||
template <bool or_null>
|
||||
void ExecutableFunctionJoinGet<or_null>::execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
|
||||
{
|
||||
auto ctn = block.getByPosition(arguments[2]);
|
||||
if (isColumnConst(*ctn.column))
|
||||
ctn.column = ctn.column->cloneResized(1);
|
||||
ctn.name = ""; // make sure the key name never collide with the join columns
|
||||
Block key_block = {ctn};
|
||||
join->joinGet(key_block, attr_name, or_null);
|
||||
auto & result_ctn = key_block.getByPosition(1);
|
||||
if (isColumnConst(*ctn.column))
|
||||
result_ctn.column = ColumnConst::create(result_ctn.column, input_rows_count);
|
||||
block.getByPosition(result) = result_ctn;
|
||||
}
|
||||
|
||||
template <bool or_null>
|
||||
ExecutableFunctionImplPtr FunctionJoinGet<or_null>::prepare(const Block &, const ColumnNumbers &, size_t) const
|
||||
{
|
||||
return std::make_unique<ExecutableFunctionJoinGet<or_null>>(join, attr_name);
|
||||
}
|
||||
|
||||
void registerFunctionJoinGet(FunctionFactory & factory)
|
||||
{
|
||||
// joinGet
|
||||
|
@ -13,14 +13,14 @@ template <bool or_null>
|
||||
class ExecutableFunctionJoinGet final : public IExecutableFunctionImpl
|
||||
{
|
||||
public:
|
||||
ExecutableFunctionJoinGet(HashJoinPtr join_, String attr_name_)
|
||||
: join(std::move(join_)), attr_name(std::move(attr_name_)) {}
|
||||
ExecutableFunctionJoinGet(HashJoinPtr join_, const Block & result_block_)
|
||||
: join(std::move(join_)), result_block(result_block_) {}
|
||||
|
||||
static constexpr auto name = or_null ? "joinGetOrNull" : "joinGet";
|
||||
|
||||
bool useDefaultImplementationForNulls() const override { return false; }
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
bool useDefaultImplementationForLowCardinalityColumns() const override { return true; }
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
|
||||
void execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
|
||||
|
||||
@ -28,7 +28,7 @@ public:
|
||||
|
||||
private:
|
||||
HashJoinPtr join;
|
||||
const String attr_name;
|
||||
Block result_block;
|
||||
};
|
||||
|
||||
template <bool or_null>
|
||||
@ -77,13 +77,14 @@ public:
|
||||
String getName() const override { return name; }
|
||||
|
||||
FunctionBaseImplPtr build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const override;
|
||||
DataTypePtr getReturnType(const ColumnsWithTypeAndName & arguments) const override;
|
||||
DataTypePtr getReturnType(const ColumnsWithTypeAndName &) const override { return {}; } // Not used
|
||||
|
||||
bool useDefaultImplementationForNulls() const override { return false; }
|
||||
bool useDefaultImplementationForLowCardinalityColumns() const override { return true; }
|
||||
|
||||
bool isVariadic() const override { return true; }
|
||||
size_t getNumberOfArguments() const override { return 0; }
|
||||
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0, 1}; }
|
||||
|
||||
private:
|
||||
const Context & context;
|
||||
|
@ -42,6 +42,7 @@ namespace ErrorCodes
|
||||
extern const int SYNTAX_ERROR;
|
||||
extern const int SET_SIZE_LIMIT_EXCEEDED;
|
||||
extern const int TYPE_MISMATCH;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
}
|
||||
|
||||
namespace
|
||||
@ -1109,27 +1110,34 @@ void HashJoin::joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed)
|
||||
block = block.cloneWithColumns(std::move(dst_columns));
|
||||
}
|
||||
|
||||
static void checkTypeOfKey(const Block & block_left, const Block & block_right)
|
||||
{
|
||||
const auto & [c1, left_type_origin, left_name] = block_left.safeGetByPosition(0);
|
||||
const auto & [c2, right_type_origin, right_name] = block_right.safeGetByPosition(0);
|
||||
auto left_type = removeNullable(left_type_origin);
|
||||
auto right_type = removeNullable(right_type_origin);
|
||||
|
||||
if (!left_type->equals(*right_type))
|
||||
throw Exception("Type mismatch of columns to joinGet by: "
|
||||
+ left_name + " " + left_type->getName() + " at left, "
|
||||
+ right_name + " " + right_type->getName() + " at right",
|
||||
ErrorCodes::TYPE_MISMATCH);
|
||||
}
|
||||
|
||||
|
||||
DataTypePtr HashJoin::joinGetReturnType(const String & column_name, bool or_null) const
|
||||
DataTypePtr HashJoin::joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const
|
||||
{
|
||||
std::shared_lock lock(data->rwlock);
|
||||
|
||||
size_t num_keys = data_types.size();
|
||||
if (right_table_keys.columns() != num_keys)
|
||||
throw Exception(
|
||||
"Number of arguments for function joinGet" + toString(or_null ? "OrNull" : "")
|
||||
+ " doesn't match: passed, should be equal to " + toString(num_keys),
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
for (size_t i = 0; i < num_keys; ++i)
|
||||
{
|
||||
const auto & left_type_origin = data_types[i];
|
||||
const auto & [c2, right_type_origin, right_name] = right_table_keys.safeGetByPosition(i);
|
||||
auto left_type = removeNullable(left_type_origin);
|
||||
auto right_type = removeNullable(right_type_origin);
|
||||
if (!left_type->equals(*right_type))
|
||||
throw Exception(
|
||||
"Type mismatch in joinGet key " + toString(i) + ": found type " + left_type->getName() + ", while the needed type is "
|
||||
+ right_type->getName(),
|
||||
ErrorCodes::TYPE_MISMATCH);
|
||||
}
|
||||
|
||||
if (!sample_block_with_columns_to_add.has(column_name))
|
||||
throw Exception("StorageJoin doesn't contain column " + column_name, ErrorCodes::NO_SUCH_COLUMN_IN_TABLE);
|
||||
|
||||
auto elem = sample_block_with_columns_to_add.getByName(column_name);
|
||||
if (or_null)
|
||||
elem.type = makeNullable(elem.type);
|
||||
@ -1138,34 +1146,33 @@ DataTypePtr HashJoin::joinGetReturnType(const String & column_name, bool or_null
|
||||
|
||||
|
||||
template <typename Maps>
|
||||
void HashJoin::joinGetImpl(Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const
|
||||
ColumnWithTypeAndName HashJoin::joinGetImpl(const Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const
|
||||
{
|
||||
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::RightAny>(
|
||||
block, {block.getByPosition(0).name}, block_with_columns_to_add, maps_);
|
||||
// Assemble the key block with correct names.
|
||||
Block keys;
|
||||
for (size_t i = 0; i < block.columns(); ++i)
|
||||
{
|
||||
auto key = block.getByPosition(i);
|
||||
key.name = key_names_right[i];
|
||||
keys.insert(std::move(key));
|
||||
}
|
||||
|
||||
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(
|
||||
keys, key_names_right, block_with_columns_to_add, maps_);
|
||||
return keys.getByPosition(keys.columns() - 1);
|
||||
}
|
||||
|
||||
|
||||
// TODO: support composite key
|
||||
// TODO: return multiple columns as named tuple
|
||||
// TODO: return array of values when strictness == ASTTableJoin::Strictness::All
|
||||
void HashJoin::joinGet(Block & block, const String & column_name, bool or_null) const
|
||||
ColumnWithTypeAndName HashJoin::joinGet(const Block & block, const Block & block_with_columns_to_add) const
|
||||
{
|
||||
std::shared_lock lock(data->rwlock);
|
||||
|
||||
if (key_names_right.size() != 1)
|
||||
throw Exception("joinGet only supports StorageJoin containing exactly one key", ErrorCodes::UNSUPPORTED_JOIN_KEYS);
|
||||
|
||||
checkTypeOfKey(block, right_table_keys);
|
||||
|
||||
auto elem = sample_block_with_columns_to_add.getByName(column_name);
|
||||
if (or_null)
|
||||
elem.type = makeNullable(elem.type);
|
||||
elem.column = elem.type->createColumn();
|
||||
|
||||
if ((strictness == ASTTableJoin::Strictness::Any || strictness == ASTTableJoin::Strictness::RightAny) &&
|
||||
kind == ASTTableJoin::Kind::Left)
|
||||
{
|
||||
joinGetImpl(block, {elem}, std::get<MapsOne>(data->maps));
|
||||
return joinGetImpl(block, block_with_columns_to_add, std::get<MapsOne>(data->maps));
|
||||
}
|
||||
else
|
||||
throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN);
|
||||
|
@ -162,11 +162,11 @@ public:
|
||||
*/
|
||||
void joinBlock(Block & block, ExtraBlockPtr & not_processed) override;
|
||||
|
||||
/// Infer the return type for joinGet function
|
||||
DataTypePtr joinGetReturnType(const String & column_name, bool or_null) const;
|
||||
/// Check joinGet arguments and infer the return type.
|
||||
DataTypePtr joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const;
|
||||
|
||||
/// Used by joinGet function that turns StorageJoin into a dictionary
|
||||
void joinGet(Block & block, const String & column_name, bool or_null) const;
|
||||
/// Used by joinGet function that turns StorageJoin into a dictionary.
|
||||
ColumnWithTypeAndName joinGet(const Block & block, const Block & block_with_columns_to_add) const;
|
||||
|
||||
/** Keep "totals" (separate part of dataset, see WITH TOTALS) to use later.
|
||||
*/
|
||||
@ -383,7 +383,7 @@ private:
|
||||
void joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed) const;
|
||||
|
||||
template <typename Maps>
|
||||
void joinGetImpl(Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const;
|
||||
ColumnWithTypeAndName joinGetImpl(const Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const;
|
||||
|
||||
static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes);
|
||||
};
|
||||
|
@ -28,7 +28,7 @@ inline bool functionIsLikeOperator(const std::string & name)
|
||||
|
||||
inline bool functionIsJoinGet(const std::string & name)
|
||||
{
|
||||
return name == "joinGet" || startsWith(name, "dictGet");
|
||||
return startsWith(name, "joinGet");
|
||||
}
|
||||
|
||||
inline bool functionIsDictGet(const std::string & name)
|
||||
|
@ -1 +1 @@
|
||||
2 2
|
||||
2
|
||||
|
@ -1,12 +1,12 @@
|
||||
DROP TABLE IF EXISTS test_joinGet;
|
||||
DROP TABLE IF EXISTS test_join_joinGet;
|
||||
|
||||
CREATE TABLE test_joinGet(id Int32, user_id Nullable(Int32)) Engine = Memory();
|
||||
CREATE TABLE test_join_joinGet(user_id Int32, name String) Engine = Join(ANY, LEFT, user_id);
|
||||
CREATE TABLE test_joinGet(user_id Nullable(Int32), name String) Engine = Join(ANY, LEFT, user_id);
|
||||
|
||||
INSERT INTO test_join_joinGet VALUES (2, 'a'), (6, 'b'), (10, 'c');
|
||||
INSERT INTO test_joinGet VALUES (2, 'a'), (6, 'b'), (10, 'c'), (null, 'd');
|
||||
|
||||
SELECT 2 id, toNullable(toInt32(2)) user_id WHERE joinGet(test_join_joinGet, 'name', user_id) != '';
|
||||
SELECT toNullable(toInt32(2)) user_id WHERE joinGet(test_joinGet, 'name', user_id) != '';
|
||||
|
||||
-- If the JOIN keys are Nullable fields, the rows where at least one of the keys has the value NULL are not joined.
|
||||
SELECT cast(null AS Nullable(Int32)) user_id WHERE joinGet(test_joinGet, 'name', user_id) != '';
|
||||
|
||||
DROP TABLE test_joinGet;
|
||||
DROP TABLE test_join_joinGet;
|
||||
|
@ -0,0 +1 @@
|
||||
0.1
|
@ -0,0 +1,9 @@
|
||||
DROP TABLE IF EXISTS test_joinGet;
|
||||
|
||||
CREATE TABLE test_joinGet(a String, b String, c Float64) ENGINE = Join(any, left, a, b);
|
||||
|
||||
INSERT INTO test_joinGet VALUES ('ab', '1', 0.1), ('ab', '2', 0.2), ('cd', '3', 0.3);
|
||||
|
||||
SELECT joinGet(test_joinGet, 'c', 'ab', '1');
|
||||
|
||||
DROP TABLE test_joinGet;
|
Loading…
Reference in New Issue
Block a user