Merge pull request #72993 from ClickHouse/fix_parallel_hash_with_additional_filter

Fix `parallel_hash` join with additional filters in `ON` clause
This commit is contained in:
Nikita Taranov 2024-12-11 14:20:14 +00:00 committed by GitHub
commit 2bc96dc2da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 160 additions and 28 deletions

View File

@ -9,9 +9,6 @@
#include <Interpreters/TableJoin.h>
#include <Interpreters/castColumn.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
namespace DB
{
/// Inserting an element into a hash table of the form `key -> reference to a string`, which will then be used by JOIN.
@ -149,20 +146,22 @@ private:
template <bool need_filter>
static void setUsed(IColumn::Filter & filter [[maybe_unused]], size_t pos [[maybe_unused]]);
template <typename AddedColumns>
template <typename AddedColumns, typename Selector>
static ColumnPtr buildAdditionalFilter(
size_t left_start_row,
const Selector & selector,
const std::vector<const RowRef *> & selected_rows,
const std::vector<size_t> & row_replicate_offset,
AddedColumns & added_columns);
/// First to collect all matched rows refs by join keys, then filter out rows which are not true in additional filter expression.
template <typename KeyGetter, typename Map, typename AddedColumns>
template <typename KeyGetter, typename Map, typename AddedColumns, typename Selector>
static size_t joinRightColumnsWithAddtitionalFilter(
std::vector<KeyGetter> && key_getter_vector,
const std::vector<const Map *> & mapv,
AddedColumns & added_columns,
JoinStuff::JoinUsedFlags & used_flags [[maybe_unused]],
const Selector & selector,
bool need_filter [[maybe_unused]],
bool flag_per_row [[maybe_unused]]);

View File

@ -357,9 +357,15 @@ size_t HashJoinMethods<KIND, STRICTNESS, MapsTemplate>::joinRightColumnsSwitchMu
{
if (added_columns.additional_filter_expression)
{
bool mark_per_row_used = join_features.right || join_features.full || mapv.size() > 1;
const bool mark_per_row_used = join_features.right || join_features.full || mapv.size() > 1;
return joinRightColumnsWithAddtitionalFilter<KeyGetter, Map>(
std::forward<std::vector<KeyGetter>>(key_getter_vector), mapv, added_columns, used_flags, need_filter, mark_per_row_used);
std::forward<std::vector<KeyGetter>>(key_getter_vector),
mapv,
added_columns,
used_flags,
added_columns.src_block.getSelector(),
need_filter,
mark_per_row_used);
}
}
@ -548,9 +554,10 @@ void HashJoinMethods<KIND, STRICTNESS, MapsTemplate>::setUsed(IColumn::Filter &
}
template <JoinKind KIND, JoinStrictness STRICTNESS, typename MapsTemplate>
template <typename AddedColumns>
template <typename AddedColumns, typename Selector>
ColumnPtr HashJoinMethods<KIND, STRICTNESS, MapsTemplate>::buildAdditionalFilter(
size_t left_start_row,
const Selector & selector,
const std::vector<const RowRef *> & selected_rows,
const std::vector<size_t> & row_replicate_offset,
AddedColumns & added_columns)
@ -613,7 +620,7 @@ ColumnPtr HashJoinMethods<KIND, STRICTNESS, MapsTemplate>::buildAdditionalFilter
const size_t & left_offset = row_replicate_offset[i];
size_t rows = left_offset - prev_left_offset;
if (rows)
new_col->insertManyFrom(*src_col->column, left_start_row + i - 1, rows);
new_col->insertManyFrom(*src_col->column, selector[left_start_row + i - 1], rows);
prev_left_offset = left_offset;
}
executed_block.insert({std::move(new_col), src_col->type, col_name});
@ -664,17 +671,18 @@ ColumnPtr HashJoinMethods<KIND, STRICTNESS, MapsTemplate>::buildAdditionalFilter
}
template <JoinKind KIND, JoinStrictness STRICTNESS, typename MapsTemplate>
template <typename KeyGetter, typename Map, typename AddedColumns>
template <typename KeyGetter, typename Map, typename AddedColumns, typename Selector>
size_t HashJoinMethods<KIND, STRICTNESS, MapsTemplate>::joinRightColumnsWithAddtitionalFilter(
std::vector<KeyGetter> && key_getter_vector,
const std::vector<const Map *> & mapv,
AddedColumns & added_columns,
JoinStuff::JoinUsedFlags & used_flags [[maybe_unused]],
const Selector & selector,
bool need_filter [[maybe_unused]],
bool flag_per_row [[maybe_unused]])
{
constexpr JoinFeatures<KIND, STRICTNESS, MapsTemplate> join_features;
size_t left_block_rows = added_columns.rows_to_add;
const size_t left_block_rows = added_columns.src_block.rows();
if (need_filter)
added_columns.filter = IColumn::Filter(left_block_rows, 0);
@ -688,7 +696,7 @@ size_t HashJoinMethods<KIND, STRICTNESS, MapsTemplate>::joinRightColumnsWithAddt
using FindResult = typename KeyGetter::FindResult;
size_t max_joined_block_rows = added_columns.max_joined_block_rows;
size_t left_row_iter = 0;
size_t it = 0;
PreSelectedRows selected_rows;
selected_rows.reserve(left_block_rows);
std::vector<FindResult> find_results;
@ -705,8 +713,10 @@ size_t HashJoinMethods<KIND, STRICTNESS, MapsTemplate>::joinRightColumnsWithAddt
row_replicate_offset.push_back(0);
current_added_rows = 0;
selected_rows.clear();
for (; left_row_iter < left_block_rows; ++left_row_iter)
for (; it < left_block_rows; ++it)
{
size_t ind = selector[it];
if constexpr (join_features.need_replication)
{
if (unlikely(total_added_rows + current_added_rows >= max_joined_block_rows))
@ -719,13 +729,12 @@ size_t HashJoinMethods<KIND, STRICTNESS, MapsTemplate>::joinRightColumnsWithAddt
for (size_t join_clause_idx = 0; join_clause_idx < added_columns.join_on_keys.size(); ++join_clause_idx)
{
const auto & join_keys = added_columns.join_on_keys[join_clause_idx];
if (join_keys.null_map && (*join_keys.null_map)[left_row_iter])
if (join_keys.null_map && (*join_keys.null_map)[ind])
continue;
bool row_acceptable = !join_keys.isRowFiltered(left_row_iter);
auto find_result = row_acceptable
? key_getter_vector[join_clause_idx].findKey(*(mapv[join_clause_idx]), left_row_iter, *pool)
: FindResult();
bool row_acceptable = !join_keys.isRowFiltered(ind);
auto find_result
= row_acceptable ? key_getter_vector[join_clause_idx].findKey(*(mapv[join_clause_idx]), ind, *pool) : FindResult();
if (find_result.isFound())
{
@ -878,11 +887,11 @@ size_t HashJoinMethods<KIND, STRICTNESS, MapsTemplate>::joinRightColumnsWithAddt
}
};
while (left_row_iter < left_block_rows && !exceeded_max_block_rows)
while (it < left_block_rows && !exceeded_max_block_rows)
{
auto left_start_row = left_row_iter;
auto left_start_row = it;
collect_keys_matched_rows_refs();
if (selected_rows.size() != current_added_rows || row_replicate_offset.size() != left_row_iter - left_start_row + 1)
if (selected_rows.size() != current_added_rows || row_replicate_offset.size() != it - left_start_row + 1)
{
throw Exception(
ErrorCodes::LOGICAL_ERROR,
@ -891,10 +900,10 @@ size_t HashJoinMethods<KIND, STRICTNESS, MapsTemplate>::joinRightColumnsWithAddt
selected_rows.size(),
current_added_rows,
row_replicate_offset.size(),
left_row_iter,
it,
left_start_row);
}
auto filter_col = buildAdditionalFilter(left_start_row, selected_rows, row_replicate_offset, added_columns);
auto filter_col = buildAdditionalFilter(left_start_row, selector, selected_rows, row_replicate_offset, added_columns);
copy_final_matched_rows(left_start_row, filter_col);
if constexpr (join_features.need_replication)
@ -907,11 +916,11 @@ size_t HashJoinMethods<KIND, STRICTNESS, MapsTemplate>::joinRightColumnsWithAddt
if constexpr (join_features.need_replication)
{
added_columns.offsets_to_replicate->resize_assume_reserved(left_row_iter);
added_columns.filter.resize_assume_reserved(left_row_iter);
added_columns.offsets_to_replicate->resize_assume_reserved(it);
added_columns.filter.resize_assume_reserved(it);
}
added_columns.applyLazyDefaults();
return left_row_iter;
return it;
}
template <JoinKind KIND, JoinStrictness STRICTNESS, typename MapsTemplate>

View File

@ -876,6 +876,7 @@ std::shared_ptr<IJoin> chooseJoinAlgorithm(
{
if (table_join->getMixedJoinExpression()
&& !table_join->isEnabledAlgorithm(JoinAlgorithm::HASH)
&& !table_join->isEnabledAlgorithm(JoinAlgorithm::PARALLEL_HASH)
&& !table_join->isEnabledAlgorithm(JoinAlgorithm::GRACE_HASH))
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED,

View File

@ -735,6 +735,24 @@ key1 b 2 3 2 key1 C 3 4 5
key1 c 3 2 1 key1 D 4 1 6
SELECT * FROM (SELECT 1 AS a, 1 AS b, 1 AS c) AS t1 INNER ANY JOIN (SELECT 1 AS a, 1 AS b, 1 AS c) AS t2 ON t1.a = t2.a AND (t1.b > 0 OR t2.b > 0);
1 1 1 1 1 1
SET join_algorithm='parallel_hash';
SELECT t1.*, t2.* FROM t1 INNER ANY JOIN t2 ON (t1.a < t2.a OR lower(t1.attr) == lower(t2.attr)) AND t1.key = t2.key ORDER BY (t1.key, t1.attr, t2.key, t2.attr);
key1 a 1 1 2 key1 A 1 2 1
key1 b 2 3 2 key1 B 2 1 2
key1 c 3 2 1 key1 C 3 4 5
key1 d 4 7 2 key1 D 4 1 6
key4 f 2 3 4 key4 F 1 1 1
SELECT t1.*, t2.* from t1 INNER ANY JOIN t2 ON t1.key = t2.key and (t1.b + t2.b == t1.c + t2.c) ORDER BY (t1.key, t1.attr, t2.key, t2.attr);
key1 a 1 1 2 key1 A 1 2 1
key1 b 2 3 2 key1 B 2 1 2
key1 c 3 2 1 key1 B 2 1 2
key1 d 4 7 2 key1 D 4 1 6
SELECT t1.*, t2.* from t1 INNER ANY JOIN t2 ON t1.key = t2.key and (t1.a < t2.a) ORDER BY (t1.key, t1.attr, t2.key, t2.attr);
key1 a 1 1 2 key1 B 2 1 2
key1 b 2 3 2 key1 C 3 4 5
key1 c 3 2 1 key1 D 4 1 6
SELECT * FROM (SELECT 1 AS a, 1 AS b, 1 AS c) AS t1 INNER ANY JOIN (SELECT 1 AS a, 1 AS b, 1 AS c) AS t2 ON t1.a = t2.a AND (t1.b > 0 OR t2.b > 0);
1 1 1 1 1 1
SET join_algorithm='hash';
SELECT t1.* FROM t1 INNER ANY JOIN t2 ON t1.key = t2.key AND t1.a < t2.a OR t1.a = t2.a ORDER BY ALL;
key1 a 1 1 2

View File

@ -50,7 +50,7 @@ SELECT t1.*, t2.* FROM t1 {{ join_type }} JOIN t2 ON t1.key = t2.key AND t1.a <
{% endfor -%}
{% endfor -%}
{% for algorithm in ['hash', 'grace_hash'] -%}
{% for algorithm in ['hash', 'grace_hash', 'parallel_hash'] -%}
SET join_algorithm='{{ algorithm }}';
{% for join_type in ['INNER'] -%}
{% for join_strictness in ['ANY'] -%}
@ -74,7 +74,7 @@ SELECT t1.* FROM t1 {{ join_type }} {{ join_strictness }} JOIN t2 ON t1.key = t2
-- { echoOff }
-- test error messages
{% for algorithm in ['partial_merge', 'full_sorting_merge', 'parallel_hash', 'auto', 'direct'] -%}
{% for algorithm in ['partial_merge', 'full_sorting_merge', 'auto', 'direct'] -%}
SET join_algorithm='{{ algorithm }}';
{% for join_type in ['LEFT', 'RIGHT', 'FULL'] -%}
SELECT t1.*, t2.* FROM t1 {{ join_type }} JOIN t2 ON (t1.a < t2.a OR lower(t1.attr) == lower(t2.attr)) AND t1.key = t2.key ORDER BY (t1.key, t1.attr, t2.key, t2.attr); -- { serverError NOT_IMPLEMENTED }

View File

@ -0,0 +1,8 @@
---- HASH
1 10 alpha 1 5 ALPHA
2 15 beta 2 10 beta
3 20 gamma 0 0
---- PARALLEL HASH
1 10 alpha 1 5 ALPHA
2 15 beta 2 10 beta
3 20 gamma 0 0

View File

@ -0,0 +1,24 @@
-- Regression test for PR #72993: a LEFT JOIN with an additional (non-equi)
-- condition in the ON clause must produce identical results under the
-- `hash` and `parallel_hash` join algorithms.
CREATE TABLE t1 (
key UInt32,
a UInt32,
attr String
) ENGINE = MergeTree ORDER BY key;
CREATE TABLE t2 (
key UInt32,
a UInt32,
attr String
) ENGINE = MergeTree ORDER BY key;
-- key 3 exists only in t1 (exercises the non-matched LEFT JOIN path);
-- key 4 exists only in t2 (must not appear in the result).
INSERT INTO t1 (key, a, attr) VALUES (1, 10, 'alpha'), (2, 15, 'beta'), (3, 20, 'gamma');
INSERT INTO t2 (key, a, attr) VALUES (1, 5, 'ALPHA'), (2, 10, 'beta'), (4, 25, 'delta');
-- Mixed-condition ON clauses (equality + inequality/OR) require this setting.
SET allow_experimental_join_condition = 1;
SET enable_analyzer = 1;
-- Multiple threads so `parallel_hash` actually splits the build into shards.
SET max_threads = 16;
SELECT '---- HASH';
SELECT t1.*, t2.* FROM t1 LEFT JOIN t2 ON t1.key = t2.key AND (t1.key < t2.a OR t1.a % 2 = 0) ORDER BY ALL SETTINGS join_algorithm = 'hash';
SELECT '---- PARALLEL HASH';
-- Must match the `hash` output above (see the .reference file).
SELECT t1.*, t2.* FROM t1 LEFT JOIN t2 ON t1.key = t2.key AND (t1.key < t2.a OR t1.a % 2 = 0) ORDER BY ALL SETTINGS join_algorithm = 'parallel_hash';

View File

@ -0,0 +1,73 @@
#!/usr/bin/env bash
# Stress test for PR #72993: run the same LEFT JOIN (equi-key plus an
# additional OR condition in the ON clause) with join_algorithm='hash' and
# join_algorithm='parallel_hash' over a large sample of test.hits, and verify
# the two result sets are identical — the final EXCEPT query must return no
# rows. (This test requires the `test.hits` dataset to be available.)
set -e
CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CUR_DIR"/../shell_config.sh
# Number of rows sampled from test.hits for each join side.
ROWS=123456
# Seed derived from today's date so the hashed key distribution changes daily,
# giving broader coverage across CI runs while staying deterministic per day.
SEED=$(${CLICKHOUSE_CLIENT} -q "SELECT reinterpretAsUInt32(today())")
# NOTE: everything below is a single SQL script passed to clickhouse-client;
# t1/t2 are sorted in opposite directions so block boundaries differ between
# the two sides, which is what exercised the parallel_hash selector bug.
${CLICKHOUSE_CLIENT} --max_threads 16 --query="
CREATE TABLE t1 ENGINE = MergeTree ORDER BY tuple() AS
SELECT
sipHash64(CounterID, $SEED) AS CounterID,
EventDate,
sipHash64(WatchID, $SEED) AS WatchID,
sipHash64(UserID, $SEED) AS UserID,
URL
FROM test.hits
ORDER BY
CounterID ASC,
EventDate ASC
LIMIT $ROWS;
CREATE TABLE t2 ENGINE = MergeTree ORDER BY tuple() AS
SELECT
sipHash64(CounterID, $SEED) AS CounterID,
EventDate,
sipHash64(WatchID, $SEED) AS WatchID,
sipHash64(UserID, $SEED) AS UserID,
URL
FROM test.hits
ORDER BY
CounterID DESC,
EventDate DESC
LIMIT $ROWS;
set max_memory_usage = 0;
CREATE TABLE res_hash
ENGINE = MergeTree()
ORDER BY (CounterID, EventDate, WatchID, UserID, URL, t2.CounterID, t2.EventDate, t2.WatchID, t2.UserID, t2.URL)
AS SELECT
t1.*,
t2.*
FROM t1
LEFT JOIN t2 ON (t1.UserID = t2.UserID) AND ((t1.EventDate < t2.EventDate) OR (length(t1.URL) > length(t2.URL)))
ORDER BY ALL
LIMIT $ROWS
SETTINGS join_algorithm = 'hash';
CREATE TABLE res_parallel_hash
ENGINE = MergeTree()
ORDER BY (CounterID, EventDate, WatchID, UserID, URL, t2.CounterID, t2.EventDate, t2.WatchID, t2.UserID, t2.URL)
AS SELECT
t1.*,
t2.*
FROM t1
LEFT JOIN t2 ON (t1.UserID = t2.UserID) AND ((t1.EventDate < t2.EventDate) OR (length(t1.URL) > length(t2.URL)))
ORDER BY ALL
LIMIT $ROWS
SETTINGS join_algorithm = 'parallel_hash';
SELECT *
FROM (
SELECT * FROM res_hash ORDER BY ALL
EXCEPT
SELECT * FROM res_parallel_hash ORDER BY ALL
)
LIMIT 1;
"