Perserve constness in ExpressionActionsChain::JoinStep

This commit is contained in:
vdimir 2022-05-19 18:47:26 +00:00
parent 2995b69f4a
commit 23a85d3406
No known key found for this signature in database
GPG Key ID: 6EE4CE2BEDC51862
6 changed files with 63 additions and 16 deletions

View File

@ -1034,7 +1034,7 @@ void ExpressionActionsChain::ArrayJoinStep::finalize(const NameSet & required_ou
ExpressionActionsChain::JoinStep::JoinStep(
std::shared_ptr<TableJoin> analyzed_join_,
JoinPtr join_,
ColumnsWithTypeAndName required_columns_)
const ColumnsWithTypeAndName & required_columns_)
: Step({})
, analyzed_join(std::move(analyzed_join_))
, join(std::move(join_))
@ -1042,11 +1042,8 @@ ExpressionActionsChain::JoinStep::JoinStep(
for (const auto & column : required_columns_)
required_columns.emplace_back(column.name, column.type);
NamesAndTypesList result_names_and_types = required_columns;
analyzed_join->addJoinedColumnsAndCorrectTypes(result_names_and_types, true);
for (const auto & [name, type] : result_names_and_types)
/// `column` is `nullptr` because we don't care on constness here, it may be changed in join
result_columns.emplace_back(nullptr, type, name);
result_columns = required_columns_;
analyzed_join->addJoinedColumnsAndCorrectTypes(result_columns, true);
}
void ExpressionActionsChain::JoinStep::finalize(const NameSet & required_output_)

View File

@ -233,7 +233,7 @@ struct ExpressionActionsChain : WithContext
NamesAndTypesList required_columns;
ColumnsWithTypeAndName result_columns;
JoinStep(std::shared_ptr<TableJoin> analyzed_join_, JoinPtr join_, ColumnsWithTypeAndName required_columns_);
JoinStep(std::shared_ptr<TableJoin> analyzed_join_, JoinPtr join_, const ColumnsWithTypeAndName & required_columns_);
NamesAndTypesList getRequiredColumns() const override { return required_columns; }
ColumnsWithTypeAndName getResultColumns() const override { return result_columns; }
void finalize(const NameSet & required_output_) override;

View File

@ -27,6 +27,7 @@
#include <Common/logger_useful.h>
#include <algorithm>
#include <string>
#include <type_traits>
#include <vector>
@ -328,6 +329,21 @@ NamesAndTypesList TableJoin::correctedColumnsAddedByJoin() const
void TableJoin::addJoinedColumnsAndCorrectTypes(NamesAndTypesList & left_columns, bool correct_nullability)
{
addJoinedColumnsAndCorrectTypesImpl(left_columns, correct_nullability);
}
void TableJoin::addJoinedColumnsAndCorrectTypes(ColumnsWithTypeAndName & left_columns, bool correct_nullability)
{
addJoinedColumnsAndCorrectTypesImpl(left_columns, correct_nullability);
}
template <typename TColumns>
void TableJoin::addJoinedColumnsAndCorrectTypesImpl(TColumns & left_columns, bool correct_nullability)
{
static_assert(std::is_same_v<typename TColumns::value_type, ColumnWithTypeAndName> ||
std::is_same_v<typename TColumns::value_type, NameAndTypePair>);
constexpr bool has_column = std::is_same_v<typename TColumns::value_type, ColumnWithTypeAndName>;
for (auto & col : left_columns)
{
if (hasUsing())
@ -342,15 +358,26 @@ void TableJoin::addJoinedColumnsAndCorrectTypes(NamesAndTypesList & left_columns
inferJoinKeyCommonType(left_columns, columns_from_joined_table, !isSpecialStorage());
if (auto it = left_type_map.find(col.name); it != left_type_map.end())
{
col.type = it->second;
if constexpr (has_column)
col.column = nullptr;
}
}
if (correct_nullability && leftBecomeNullable(col.type))
{
col.type = JoinCommon::convertTypeToNullable(col.type);
if constexpr (has_column)
col.column = nullptr;
}
}
for (const auto & col : correctedColumnsAddedByJoin())
left_columns.emplace_back(col.name, col.type);
if constexpr (has_column)
left_columns.emplace_back(nullptr, col.type, col.name);
else
left_columns.emplace_back(col.name, col.type);
}
bool TableJoin::sameStrictnessAndKind(ASTTableJoin::Strictness strictness_, ASTTableJoin::Kind kind_) const

View File

@ -254,7 +254,11 @@ public:
bool rightBecomeNullable(const DataTypePtr & column_type) const;
void addJoinedColumn(const NameAndTypePair & joined_column);
template <typename TColumns>
void addJoinedColumnsAndCorrectTypesImpl(TColumns & left_columns, bool correct_nullability);
void addJoinedColumnsAndCorrectTypes(NamesAndTypesList & left_columns, bool correct_nullability);
void addJoinedColumnsAndCorrectTypes(ColumnsWithTypeAndName & left_columns, bool correct_nullability);
/// Calculate converting actions, rename key columns in required
/// For `USING` join we will convert key columns inplace and affect into types in the result table

View File

@ -1 +1,7 @@
0
1970-01-01 00:00:00
0
2020-01-01 00:00:00
1

View File

@ -1,15 +1,28 @@
DROP TABLE IF EXISTS e;
-- https://github.com/ClickHouse/ClickHouse/issues/36891
CREATE TABLE e ( a UInt64, t DateTime ) ENGINE = MergeTree PARTITION BY toDate(t) ORDER BY tuple();
INSERT INTO e SELECT 1, toDateTime('2020-02-01 12:00:01') + INTERVAL number MONTH FROM numbers(10);
SELECT any('1')
FROM e JOIN ( SELECT 1 :: UInt32 AS key) AS da ON key = a
PREWHERE toString(a) = '1';
SELECT sumIf( 1, if( 1, toDateTime('2020-01-01 00:00:00', 'UTC'), toDateTime('1970-01-01 00:00:00', 'UTC')) > t )
FROM e JOIN ( SELECT 1 joinKey) AS da ON joinKey = a
WHERE t >= toDateTime('2021-07-19T13:00:00', 'UTC') AND t <= toDateTime('2021-07-19T13:59:59', 'UTC');
SELECT any( toDateTime('2020-01-01T00:00:00', 'UTC'))
FROM e JOIN ( SELECT 1 joinKey) AS da ON joinKey = a
PREWHERE t >= toDateTime('2021-07-19T13:00:00', 'UTC');
SELECT sumIf( 1, if( 1, toDateTime('2020-01-01 00:00:00', 'UTC'), toDateTime('1970-01-01 00:00:00', 'UTC')) > t )
FROM e JOIN ( SELECT 1 joinKey) AS da ON joinKey = a
WHERE t >= toDateTime('2020-01-01 00:00:00', 'UTC') AND t <= toDateTime('2021-07-19T13:59:59', 'UTC');
SELECT any(toDateTime('2020-01-01 00:00:00'))
FROM e JOIN ( SELECT 1 joinKey) AS da ON joinKey = a
PREWHERE t >= toDateTime('2020-01-01 00:00:00');
SELECT any('2020-01-01 00:00:00') FROM e JOIN ( SELECT 1 joinKey) AS da ON joinKey = a PREWHERE t = '2020-01-01 00:00:00';
SELECT any('x') FROM e JOIN ( SELECT 1 joinKey) AS da ON joinKey = a PREWHERE toString(a) = 'x';
SELECT any('1') FROM e JOIN ( SELECT 1 joinKey) AS da ON joinKey = a PREWHERE toString(a) = '1';
-- SELECT sumIf( 1, if( 1, toDateTime('2020-01-01 00:00:00', 'UTC'), toDateTime('1970-01-01 00:00:00', 'UTC')) > t )
-- FROM e JOIN ( SELECT 1 joinKey) AS da ON joinKey = a
-- WHERE t >= toDateTime('2021-07-19T13:00:00', 'UTC') AND t <= toDateTime('2021-07-19T13:59:59', 'UTC');