Merge pull request #13015 from amosbird/jgmko

Refactor joinGet and implement multi-key lookup.
This commit is contained in:
Nikolai Kochetov 2020-09-13 22:39:48 +03:00 committed by GitHub
commit 109fd9d6d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 104 additions and 95 deletions

View File

@ -1,10 +1,10 @@
#include <Functions/FunctionJoinGet.h>
#include <Columns/ColumnString.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Interpreters/Context.h>
#include <Interpreters/HashJoin.h>
#include <Columns/ColumnString.h>
#include <Storages/StorageJoin.h>
@ -16,19 +16,35 @@ namespace ErrorCodes
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
template <bool or_null>
void ExecutableFunctionJoinGet<or_null>::execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t)
{
Block keys;
for (size_t i = 2; i < arguments.size(); ++i)
{
auto key = block.getByPosition(arguments[i]);
keys.insert(std::move(key));
}
block.getByPosition(result) = join->joinGet(keys, result_block);
}
template <bool or_null>
ExecutableFunctionImplPtr FunctionJoinGet<or_null>::prepare(const Block &, const ColumnNumbers &, size_t) const
{
return std::make_unique<ExecutableFunctionJoinGet<or_null>>(join, Block{{return_type->createColumn(), return_type, attr_name}});
}
static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & context)
{
if (arguments.size() != 3)
throw Exception{"Function joinGet takes 3 arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
String join_name;
if (const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get()))
{
join_name = name_col->getValue<String>();
}
else
throw Exception{"Illegal type " + arguments[0].type->getName() + " of first argument of function joinGet, expected a const string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
throw Exception(
"Illegal type " + arguments[0].type->getName() + " of first argument of function joinGet, expected a const string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
size_t dot = join_name.find('.');
String database_name;
@ -43,10 +59,12 @@ static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & co
++dot;
}
String table_name = join_name.substr(dot);
if (table_name.empty())
throw Exception("joinGet does not allow empty table name", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
auto table = DatabaseCatalog::instance().getTable({database_name, table_name}, context);
auto storage_join = std::dynamic_pointer_cast<StorageJoin>(table);
if (!storage_join)
throw Exception{"Table " + join_name + " should have engine StorageJoin", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
throw Exception("Table " + join_name + " should have engine StorageJoin", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
String attr_name;
if (const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[1].column.get()))
@ -54,57 +72,30 @@ static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & co
attr_name = name_col->getValue<String>();
}
else
throw Exception{"Illegal type " + arguments[1].type->getName()
+ " of second argument of function joinGet, expected a const string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
throw Exception(
"Illegal type " + arguments[1].type->getName() + " of second argument of function joinGet, expected a const string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_pair(storage_join, attr_name);
}
template <bool or_null>
FunctionBaseImplPtr JoinGetOverloadResolver<or_null>::build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const
{
if (arguments.size() < 3)
throw Exception(
"Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size())
+ ", should be greater or equal to 3",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
auto [storage_join, attr_name] = getJoin(arguments, context);
auto join = storage_join->getJoin();
DataTypes data_types(arguments.size());
DataTypes data_types(arguments.size() - 2);
for (size_t i = 2; i < arguments.size(); ++i)
data_types[i - 2] = arguments[i].type;
auto return_type = join->joinGetCheckAndGetReturnType(data_types, attr_name, or_null);
auto table_lock = storage_join->lockForShare(context.getInitialQueryId(), context.getSettingsRef().lock_acquire_timeout);
for (size_t i = 0; i < arguments.size(); ++i)
data_types[i] = arguments[i].type;
auto return_type = join->joinGetReturnType(attr_name, or_null);
return std::make_unique<FunctionJoinGet<or_null>>(table_lock, storage_join, join, attr_name, data_types, return_type);
}
template <bool or_null>
DataTypePtr JoinGetOverloadResolver<or_null>::getReturnType(const ColumnsWithTypeAndName & arguments) const
{
auto [storage_join, attr_name] = getJoin(arguments, context);
auto join = storage_join->getJoin();
return join->joinGetReturnType(attr_name, or_null);
}
template <bool or_null>
void ExecutableFunctionJoinGet<or_null>::execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
{
auto ctn = block.getByPosition(arguments[2]);
if (isColumnConst(*ctn.column))
ctn.column = ctn.column->cloneResized(1);
ctn.name = ""; // make sure the key name never collide with the join columns
Block key_block = {ctn};
join->joinGet(key_block, attr_name, or_null);
auto & result_ctn = key_block.getByPosition(1);
if (isColumnConst(*ctn.column))
result_ctn.column = ColumnConst::create(result_ctn.column, input_rows_count);
block.getByPosition(result) = result_ctn;
}
template <bool or_null>
ExecutableFunctionImplPtr FunctionJoinGet<or_null>::prepare(const Block &, const ColumnNumbers &, size_t) const
{
return std::make_unique<ExecutableFunctionJoinGet<or_null>>(join, attr_name);
}
void registerFunctionJoinGet(FunctionFactory & factory)
{
// joinGet

View File

@ -13,14 +13,14 @@ template <bool or_null>
class ExecutableFunctionJoinGet final : public IExecutableFunctionImpl
{
public:
ExecutableFunctionJoinGet(HashJoinPtr join_, String attr_name_)
: join(std::move(join_)), attr_name(std::move(attr_name_)) {}
ExecutableFunctionJoinGet(HashJoinPtr join_, const Block & result_block_)
: join(std::move(join_)), result_block(result_block_) {}
static constexpr auto name = or_null ? "joinGetOrNull" : "joinGet";
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
bool useDefaultImplementationForLowCardinalityColumns() const override { return true; }
bool useDefaultImplementationForConstants() const override { return true; }
void execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
@ -28,7 +28,7 @@ public:
private:
HashJoinPtr join;
const String attr_name;
Block result_block;
};
template <bool or_null>
@ -77,13 +77,14 @@ public:
String getName() const override { return name; }
FunctionBaseImplPtr build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const override;
DataTypePtr getReturnType(const ColumnsWithTypeAndName & arguments) const override;
DataTypePtr getReturnType(const ColumnsWithTypeAndName &) const override { return {}; } // Not used
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForLowCardinalityColumns() const override { return true; }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0, 1}; }
private:
const Context & context;

View File

@ -41,6 +41,7 @@ namespace ErrorCodes
extern const int SYNTAX_ERROR;
extern const int SET_SIZE_LIMIT_EXCEEDED;
extern const int TYPE_MISMATCH;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
@ -1109,27 +1110,34 @@ void HashJoin::joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed)
block = block.cloneWithColumns(std::move(dst_columns));
}
static void checkTypeOfKey(const Block & block_left, const Block & block_right)
{
const auto & [c1, left_type_origin, left_name] = block_left.safeGetByPosition(0);
const auto & [c2, right_type_origin, right_name] = block_right.safeGetByPosition(0);
auto left_type = removeNullable(left_type_origin);
auto right_type = removeNullable(right_type_origin);
if (!left_type->equals(*right_type))
throw Exception("Type mismatch of columns to joinGet by: "
+ left_name + " " + left_type->getName() + " at left, "
+ right_name + " " + right_type->getName() + " at right",
ErrorCodes::TYPE_MISMATCH);
}
DataTypePtr HashJoin::joinGetReturnType(const String & column_name, bool or_null) const
DataTypePtr HashJoin::joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const
{
std::shared_lock lock(data->rwlock);
size_t num_keys = data_types.size();
if (right_table_keys.columns() != num_keys)
throw Exception(
"Number of arguments for function joinGet" + toString(or_null ? "OrNull" : "")
+ " doesn't match: passed, should be equal to " + toString(num_keys),
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (size_t i = 0; i < num_keys; ++i)
{
const auto & left_type_origin = data_types[i];
const auto & [c2, right_type_origin, right_name] = right_table_keys.safeGetByPosition(i);
auto left_type = removeNullable(left_type_origin);
auto right_type = removeNullable(right_type_origin);
if (!left_type->equals(*right_type))
throw Exception(
"Type mismatch in joinGet key " + toString(i) + ": found type " + left_type->getName() + ", while the needed type is "
+ right_type->getName(),
ErrorCodes::TYPE_MISMATCH);
}
if (!sample_block_with_columns_to_add.has(column_name))
throw Exception("StorageJoin doesn't contain column " + column_name, ErrorCodes::NO_SUCH_COLUMN_IN_TABLE);
auto elem = sample_block_with_columns_to_add.getByName(column_name);
if (or_null)
elem.type = makeNullable(elem.type);
@ -1138,34 +1146,33 @@ DataTypePtr HashJoin::joinGetReturnType(const String & column_name, bool or_null
template <typename Maps>
void HashJoin::joinGetImpl(Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const
ColumnWithTypeAndName HashJoin::joinGetImpl(const Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const
{
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::RightAny>(
block, {block.getByPosition(0).name}, block_with_columns_to_add, maps_);
// Assemble the key block with correct names.
Block keys;
for (size_t i = 0; i < block.columns(); ++i)
{
auto key = block.getByPosition(i);
key.name = key_names_right[i];
keys.insert(std::move(key));
}
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(
keys, key_names_right, block_with_columns_to_add, maps_);
return keys.getByPosition(keys.columns() - 1);
}
// TODO: support composite key
// TODO: return multiple columns as named tuple
// TODO: return array of values when strictness == ASTTableJoin::Strictness::All
void HashJoin::joinGet(Block & block, const String & column_name, bool or_null) const
ColumnWithTypeAndName HashJoin::joinGet(const Block & block, const Block & block_with_columns_to_add) const
{
std::shared_lock lock(data->rwlock);
if (key_names_right.size() != 1)
throw Exception("joinGet only supports StorageJoin containing exactly one key", ErrorCodes::UNSUPPORTED_JOIN_KEYS);
checkTypeOfKey(block, right_table_keys);
auto elem = sample_block_with_columns_to_add.getByName(column_name);
if (or_null)
elem.type = makeNullable(elem.type);
elem.column = elem.type->createColumn();
if ((strictness == ASTTableJoin::Strictness::Any || strictness == ASTTableJoin::Strictness::RightAny) &&
kind == ASTTableJoin::Kind::Left)
{
joinGetImpl(block, {elem}, std::get<MapsOne>(data->maps));
return joinGetImpl(block, block_with_columns_to_add, std::get<MapsOne>(data->maps));
}
else
throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN);

View File

@ -162,11 +162,11 @@ public:
*/
void joinBlock(Block & block, ExtraBlockPtr & not_processed) override;
/// Infer the return type for joinGet function
DataTypePtr joinGetReturnType(const String & column_name, bool or_null) const;
/// Check joinGet arguments and infer the return type.
DataTypePtr joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const;
/// Used by joinGet function that turns StorageJoin into a dictionary
void joinGet(Block & block, const String & column_name, bool or_null) const;
/// Used by joinGet function that turns StorageJoin into a dictionary.
ColumnWithTypeAndName joinGet(const Block & block, const Block & block_with_columns_to_add) const;
/** Keep "totals" (separate part of dataset, see WITH TOTALS) to use later.
*/
@ -389,7 +389,7 @@ private:
void joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed) const;
template <typename Maps>
void joinGetImpl(Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const;
ColumnWithTypeAndName joinGetImpl(const Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const;
static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes);
};

View File

@ -28,7 +28,7 @@ inline bool functionIsLikeOperator(const std::string & name)
inline bool functionIsJoinGet(const std::string & name)
{
return name == "joinGet" || startsWith(name, "dictGet");
return startsWith(name, "joinGet");
}
inline bool functionIsDictGet(const std::string & name)

View File

@ -1,12 +1,12 @@
DROP TABLE IF EXISTS test_joinGet;
DROP TABLE IF EXISTS test_join_joinGet;
CREATE TABLE test_joinGet(id Int32, user_id Nullable(Int32)) Engine = Memory();
CREATE TABLE test_join_joinGet(user_id Int32, name String) Engine = Join(ANY, LEFT, user_id);
CREATE TABLE test_joinGet(user_id Nullable(Int32), name String) Engine = Join(ANY, LEFT, user_id);
INSERT INTO test_join_joinGet VALUES (2, 'a'), (6, 'b'), (10, 'c');
INSERT INTO test_joinGet VALUES (2, 'a'), (6, 'b'), (10, 'c'), (null, 'd');
SELECT 2 id, toNullable(toInt32(2)) user_id WHERE joinGet(test_join_joinGet, 'name', user_id) != '';
SELECT toNullable(toInt32(2)) user_id WHERE joinGet(test_joinGet, 'name', user_id) != '';
-- If the JOIN keys are Nullable fields, the rows where at least one of the keys has the value NULL are not joined.
SELECT cast(null AS Nullable(Int32)) user_id WHERE joinGet(test_joinGet, 'name', user_id) != '';
DROP TABLE test_joinGet;
DROP TABLE test_join_joinGet;

View File

@ -0,0 +1 @@
0.1

View File

@ -0,0 +1,9 @@
DROP TABLE IF EXISTS test_joinGet;
CREATE TABLE test_joinGet(a String, b String, c Float64) ENGINE = Join(any, left, a, b);
INSERT INTO test_joinGet VALUES ('ab', '1', 0.1), ('ab', '2', 0.2), ('cd', '3', 0.3);
SELECT joinGet(test_joinGet, 'c', 'ab', '1');
DROP TABLE test_joinGet;