Merge pull request #41595 from Algunenano/fix_variadic_reading_random_data

Do not process rows in aggregations if any of the parameters is NULL
This commit is contained in:
Alexey Milovidov 2022-09-22 07:57:37 +03:00 committed by GitHub
commit bca4cc98c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 141 additions and 117 deletions

View File

@ -16,7 +16,7 @@ ClickHouse also supports:
## NULL Processing
During aggregation, all `NULL`s are skipped.
During aggregation, all `NULL` arguments are skipped. If the aggregate function has several arguments, it ignores any row in which one or more of them is `NULL`.
**Examples:**
@ -58,4 +58,17 @@ SELECT groupArray(y) FROM t_null_big
`groupArray` does not include `NULL` in the resulting array.
You can use [COALESCE](../../sql-reference/functions/functions-for-nulls.md#coalesce) to change NULL into a value that makes sense in your use case. For example, `avg(COALESCE(column, 0))` will use the column value in the aggregation, or zero if the value is NULL:
``` sql
SELECT
avg(y),
avg(coalesce(y, 0))
FROM t_null_big
```
``` text
┌─────────────avg(y)─┬─avg(coalesce(y, 0))─┐
│ 2.3333333333333335 │ 1.4 │
└────────────────────┴─────────────────────┘
```

View File

@ -218,10 +218,11 @@ public:
};
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
class AggregateFunctionIfNullVariadic final
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>
template <bool result_is_nullable, bool serialize_flag>
class AggregateFunctionIfNullVariadic final : public AggregateFunctionNullBase<
result_is_nullable,
serialize_flag,
AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag>>
{
public:
@ -259,7 +260,7 @@ public:
if (is_nullable[i])
{
const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[i]);
if (null_is_skipped && nullable_col.isNullAt(row_num))
if (nullable_col.isNullAt(row_num))
{
/// If at least one column has a null value in the current row,
/// we don't process this row.
@ -293,7 +294,7 @@ public:
for (size_t i = row_begin; i < row_end; i++)
{
final_null_flags[i] = (null_is_skipped && filter_null_map[i]) || !filter_values[i];
final_null_flags[i] = filter_null_map[i] || !filter_values[i];
}
}
else
@ -310,7 +311,7 @@ public:
if (is_nullable[arg])
{
const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[arg]);
if (null_is_skipped && (arg != filter_column_num))
if (arg != filter_column_num)
{
const ColumnUInt8 & nullmap_column = nullable_col.getNullMapColumn();
const UInt8 * col_null_map = nullmap_column.getData().data();
@ -368,9 +369,7 @@ public:
if (is_nullable[i])
{
auto * wrapped_value = b.CreateExtractValue(argument_value, {0});
if constexpr (null_is_skipped)
is_null_values[i] = b.CreateExtractValue(argument_value, {1});
is_null_values[i] = b.CreateExtractValue(argument_value, {1});
wrapped_values[i] = wrapped_value;
non_nullable_types[i] = removeNullable(arguments_types[i]);
@ -387,23 +386,20 @@ public:
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * join_block_after_null_checks = llvm::BasicBlock::Create(head->getContext(), "join_block_after_null_checks", head->getParent());
if constexpr (null_is_skipped)
auto * values_have_null_ptr = b.CreateAlloca(b.getInt1Ty());
b.CreateStore(b.getInt1(false), values_have_null_ptr);
for (auto * is_null_value : is_null_values)
{
auto * values_have_null_ptr = b.CreateAlloca(b.getInt1Ty());
b.CreateStore(b.getInt1(false), values_have_null_ptr);
if (!is_null_value)
continue;
for (auto * is_null_value : is_null_values)
{
if (!is_null_value)
continue;
auto * values_have_null = b.CreateLoad(b.getInt1Ty(), values_have_null_ptr);
b.CreateStore(b.CreateOr(values_have_null, is_null_value), values_have_null_ptr);
}
b.CreateCondBr(b.CreateLoad(b.getInt1Ty(), values_have_null_ptr), join_block, join_block_after_null_checks);
auto * values_have_null = b.CreateLoad(b.getInt1Ty(), values_have_null_ptr);
b.CreateStore(b.CreateOr(values_have_null, is_null_value), values_have_null_ptr);
}
b.CreateCondBr(b.CreateLoad(b.getInt1Ty(), values_have_null_ptr), join_block, join_block_after_null_checks);
b.SetInsertPoint(join_block_after_null_checks);
const auto & predicate_type = arguments_types[argument_values.size() - 1];
@ -433,8 +429,10 @@ public:
#endif
private:
using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>;
using Base = AggregateFunctionNullBase<
result_is_nullable,
serialize_flag,
AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag>>;
static constexpr size_t MAX_ARGS = 8;
size_t number_of_arguments = 0;
@ -473,14 +471,14 @@ AggregateFunctionPtr AggregateFunctionIf::getOwnNullAdapter(
{
if (return_type_is_nullable)
{
return std::make_shared<AggregateFunctionIfNullVariadic<true, true, true>>(nested_function, arguments, params);
return std::make_shared<AggregateFunctionIfNullVariadic<true, true>>(nested_function, arguments, params);
}
else
{
if (need_to_serialize_flag)
return std::make_shared<AggregateFunctionIfNullVariadic<false, true, true>>(nested_function, arguments, params);
return std::make_shared<AggregateFunctionIfNullVariadic<false, true>>(nested_function, arguments, params);
else
return std::make_shared<AggregateFunctionIfNullVariadic<false, false, true>>(nested_function, arguments, params);
return std::make_shared<AggregateFunctionIfNullVariadic<false, false>>(nested_function, arguments, params);
}
}
}

View File

@ -196,7 +196,7 @@ public:
const Array & params,
const AggregateFunctionProperties & /*properties*/) const override
{
return std::make_shared<AggregateFunctionNullVariadic<false, false, false>>(nested_function, arguments, params);
return std::make_shared<AggregateFunctionNullVariadic<false, false>>(nested_function, arguments, params);
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override

View File

@ -108,14 +108,14 @@ public:
{
if (return_type_is_nullable)
{
return std::make_shared<AggregateFunctionNullVariadic<true, true, true>>(nested_function, arguments, params);
return std::make_shared<AggregateFunctionNullVariadic<true, true>>(nested_function, arguments, params);
}
else
{
if (serialize_flag)
return std::make_shared<AggregateFunctionNullVariadic<false, true, true>>(nested_function, arguments, params);
return std::make_shared<AggregateFunctionNullVariadic<false, true>>(nested_function, arguments, params);
else
return std::make_shared<AggregateFunctionNullVariadic<false, true, false>>(nested_function, arguments, params);
return std::make_shared<AggregateFunctionNullVariadic<false, true>>(nested_function, arguments, params);
}
}
}

View File

@ -386,16 +386,17 @@ public:
};
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
class AggregateFunctionNullVariadic final
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>
template <bool result_is_nullable, bool serialize_flag>
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase<
result_is_nullable,
serialize_flag,
AggregateFunctionNullVariadic<result_is_nullable, serialize_flag>>
{
public:
AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>(std::move(nested_function_), arguments, params),
number_of_arguments(arguments.size())
: AggregateFunctionNullBase<result_is_nullable, serialize_flag, AggregateFunctionNullVariadic<result_is_nullable, serialize_flag>>(
std::move(nested_function_), arguments, params)
, number_of_arguments(arguments.size())
{
if (number_of_arguments == 1)
throw Exception("Logical error: single argument is passed to AggregateFunctionNullVariadic", ErrorCodes::LOGICAL_ERROR);
@ -418,7 +419,7 @@ public:
if (is_nullable[i])
{
const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[i]);
if (null_is_skipped && nullable_col.isNullAt(row_num))
if (nullable_col.isNullAt(row_num))
{
/// If at least one column has a null value in the current row,
/// we don't process this row.
@ -476,11 +477,8 @@ public:
{
const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[i]);
nested_columns[i] = &nullable_col.getNestedColumn();
if constexpr (null_is_skipped)
{
const ColumnUInt8 & nullmap_column = nullable_col.getNullMapColumn();
nullable_filters.push_back(nullmap_column.getData().data());
}
const ColumnUInt8 & nullmap_column = nullable_col.getNullMapColumn();
nullable_filters.push_back(nullmap_column.getData().data());
}
else
{
@ -488,14 +486,7 @@ public:
}
}
/// We can have 0 nullable filters if we don't skip nulls
if (nullable_filters.size() == 0)
{
this->setFlag(place);
this->nested_function->addBatchSinglePlace(row_begin, row_end, this->nestedPlace(place), nested_columns, arena, -1);
return;
}
chassert(nullable_filters.size() > 0);
bool found_one = false;
if (nullable_filters.size() == 1)
{
@ -567,9 +558,7 @@ public:
if (is_nullable[i])
{
auto * wrapped_value = b.CreateExtractValue(argument_value, {0});
if constexpr (null_is_skipped)
is_null_values[i] = b.CreateExtractValue(argument_value, {1});
is_null_values[i] = b.CreateExtractValue(argument_value, {1});
wrapped_values[i] = wrapped_value;
non_nullable_types[i] = removeNullable(arguments_types[i]);
@ -581,48 +570,39 @@ public:
}
}
if constexpr (null_is_skipped)
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
auto * values_have_null_ptr = b.CreateAlloca(b.getInt1Ty());
b.CreateStore(b.getInt1(false), values_have_null_ptr);
for (auto * is_null_value : is_null_values)
{
auto * head = b.GetInsertBlock();
if (!is_null_value)
continue;
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
auto * values_have_null_ptr = b.CreateAlloca(b.getInt1Ty());
b.CreateStore(b.getInt1(false), values_have_null_ptr);
for (auto * is_null_value : is_null_values)
{
if (!is_null_value)
continue;
auto * values_have_null = b.CreateLoad(b.getInt1Ty(), values_have_null_ptr);
b.CreateStore(b.CreateOr(values_have_null, is_null_value), values_have_null_ptr);
}
b.CreateCondBr(b.CreateLoad(b.getInt1Ty(), values_have_null_ptr), if_null, if_not_null);
b.SetInsertPoint(if_null);
b.CreateBr(join_block);
b.SetInsertPoint(if_not_null);
if constexpr (result_is_nullable)
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstInBoundsGEP1_64(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, arguments_types, wrapped_values);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
auto * values_have_null = b.CreateLoad(b.getInt1Ty(), values_have_null_ptr);
b.CreateStore(b.CreateOr(values_have_null, is_null_value), values_have_null_ptr);
}
else
{
b.CreateCondBr(b.CreateLoad(b.getInt1Ty(), values_have_null_ptr), if_null, if_not_null);
b.SetInsertPoint(if_null);
b.CreateBr(join_block);
b.SetInsertPoint(if_not_null);
if constexpr (result_is_nullable)
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstInBoundsGEP1_64(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, non_nullable_types, wrapped_values);
}
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstInBoundsGEP1_64(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, arguments_types, wrapped_values);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
#endif

View File

@ -199,17 +199,6 @@ public:
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
}
AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params,
const AggregateFunctionProperties &) const override
{
/// Even though some values are mapped to aggregating key, it could return nulls for the below case.
/// aggregated events: [A -> B -> C]
/// events to find: [C -> D]
/// [C -> D] is not matched to 'A -> B -> C' so that it returns null.
return std::make_shared<AggregateFunctionNullVariadic<false, false, true>>(nested_function, arguments, params);
}
void insert(Data & a, const Node * v, Arena * arena) const
{
++a.total_values;

View File

@ -252,13 +252,6 @@ public:
bool allocatesMemoryInArena() const override { return false; }
AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params,
const AggregateFunctionProperties & /*properties*/) const override
{
return std::make_shared<AggregateFunctionNullVariadic<false, false, false>>(nested_function, arguments, params);
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
{
bool has_event = false;

View File

@ -80,11 +80,11 @@ insert into funnel_test_non_null values (1, 1, 'a1', 'b1') (2, 1, 'a2', 'b2');
insert into funnel_test_non_null values (1, 2, 'a1', null) (2, 2, 'a2', null);
insert into funnel_test_non_null values (1, 3, null, null);
insert into funnel_test_non_null values (1, 4, null, 'b1') (2, 4, 'a2', null) (3, 4, null, 'b3');
select u, windowFunnel(86400)(dt, a = 'a1', a = 'a2') as s from funnel_test_non_null group by u order by u format JSONCompactEachRow;
select u, windowFunnel(86400)(dt, a = 'a1', b = 'b2') as s from funnel_test_non_null group by u order by u format JSONCompactEachRow;
select u, windowFunnel(86400)(dt, COALESCE(a, '') = 'a1', COALESCE(a, '') = 'a2') as s from funnel_test_non_null group by u order by u format JSONCompactEachRow;
select u, windowFunnel(86400)(dt, COALESCE(a, '') = 'a1', COALESCE(b, '') = 'b2') as s from funnel_test_non_null group by u order by u format JSONCompactEachRow;
select u, windowFunnel(86400)(dt, a is null and b is null) as s from funnel_test_non_null group by u order by u format JSONCompactEachRow;
select u, windowFunnel(86400)(dt, a is null, b = 'b3') as s from funnel_test_non_null group by u order by u format JSONCompactEachRow;
select u, windowFunnel(86400, 'strict_order')(dt, a is null, b = 'b3') as s from funnel_test_non_null group by u order by u format JSONCompactEachRow;
select u, windowFunnel(86400)(dt, a is null, COALESCE(b, '') = 'b3') as s from funnel_test_non_null group by u order by u format JSONCompactEachRow;
select u, windowFunnel(86400, 'strict_order')(dt, a is null, COALESCE(b, '') = 'b3') as s from funnel_test_non_null group by u order by u format JSONCompactEachRow;
drop table funnel_test_non_null;
create table funnel_test_strict_increase (timestamp UInt32, event UInt32) engine=Memory;

View File

@ -0,0 +1,31 @@
-- { echoOn }
SELECT argMax((n, n), n) t, toTypeName(t) FROM (SELECT if(number >= 100, NULL, number) AS n from numbers(10));
(9,9) Tuple(Nullable(UInt64), Nullable(UInt64))
SELECT argMax((n, n), n) t, toTypeName(t) FROM (SELECT if(number <= 100, NULL, number) AS n from numbers(10));
(NULL,NULL) Tuple(Nullable(UInt64), Nullable(UInt64))
SELECT argMax((n, n), n) t, toTypeName(t) FROM (SELECT if(number % 3 = 0, NULL, number) AS n from numbers(10));
(8,8) Tuple(Nullable(UInt64), Nullable(UInt64))
SELECT argMax((n, n), n) t, toTypeName(t) FROM (SELECT if(number >= 100, NULL, number::Int32) AS n from numbers(10));
(9,9) Tuple(Nullable(Int32), Nullable(Int32))
SELECT argMax((n, n), n) t, toTypeName(t) FROM (SELECT if(number <= 100, NULL, number::Int32) AS n from numbers(10));
(NULL,NULL) Tuple(Nullable(Int32), Nullable(Int32))
SELECT argMax((n, n), n) t, toTypeName(t) FROM (SELECT if(number % 3 = 0, NULL, number::Int32) AS n from numbers(10));
(8,8) Tuple(Nullable(Int32), Nullable(Int32))
SELECT argMin((n, n), n) t, toTypeName(t) FROM (SELECT if(number >= 100, NULL, number) AS n from numbers(5, 10));
(5,5) Tuple(Nullable(UInt64), Nullable(UInt64))
SELECT argMin((n, n), n) t, toTypeName(t) FROM (SELECT if(number <= 100, NULL, number) AS n from numbers(5, 10));
(NULL,NULL) Tuple(Nullable(UInt64), Nullable(UInt64))
SELECT argMin((n, n), n) t, toTypeName(t) FROM (SELECT if(number % 5 == 0, NULL, number) as n from numbers(5, 10));
(6,6) Tuple(Nullable(UInt64), Nullable(UInt64))
SELECT argMin((n, n), n) t, toTypeName(t) FROM (SELECT if(number >= 100, NULL, number::Int32) AS n from numbers(5, 10));
(5,5) Tuple(Nullable(Int32), Nullable(Int32))
SELECT argMin((n, n), n) t, toTypeName(t) FROM (SELECT if(number <= 100, NULL, number::Int32) AS n from numbers(5, 10));
(NULL,NULL) Tuple(Nullable(Int32), Nullable(Int32))
SELECT argMin((n, n), n) t, toTypeName(t) FROM (SELECT if(number % 5 == 0, NULL, number::Int32) as n from numbers(5, 10));
(6,6) Tuple(Nullable(Int32), Nullable(Int32))
SELECT argMaxIf((n, n), n, n > 100) t, toTypeName(t) FROM (SELECT if(number % 3 = 0, NULL, number) AS n from numbers(50));
(NULL,NULL) Tuple(Nullable(UInt64), Nullable(UInt64))
SELECT argMaxIf((n, n), n, n < 100) t, toTypeName(t) FROM (SELECT if(number % 3 = 0, NULL, number) AS n from numbers(50));
(49,49) Tuple(Nullable(UInt64), Nullable(UInt64))
SELECT argMaxIf((n, n), n, n % 5 == 0) t, toTypeName(t) FROM (SELECT if(number % 3 = 0, NULL, number) AS n from numbers(50));
(40,40) Tuple(Nullable(UInt64), Nullable(UInt64))

View File

@ -0,0 +1,20 @@
-- { echoOn }
SELECT argMax((n, n), n) t, toTypeName(t) FROM (SELECT if(number >= 100, NULL, number) AS n from numbers(10));
SELECT argMax((n, n), n) t, toTypeName(t) FROM (SELECT if(number <= 100, NULL, number) AS n from numbers(10));
SELECT argMax((n, n), n) t, toTypeName(t) FROM (SELECT if(number % 3 = 0, NULL, number) AS n from numbers(10));
SELECT argMax((n, n), n) t, toTypeName(t) FROM (SELECT if(number >= 100, NULL, number::Int32) AS n from numbers(10));
SELECT argMax((n, n), n) t, toTypeName(t) FROM (SELECT if(number <= 100, NULL, number::Int32) AS n from numbers(10));
SELECT argMax((n, n), n) t, toTypeName(t) FROM (SELECT if(number % 3 = 0, NULL, number::Int32) AS n from numbers(10));
SELECT argMin((n, n), n) t, toTypeName(t) FROM (SELECT if(number >= 100, NULL, number) AS n from numbers(5, 10));
SELECT argMin((n, n), n) t, toTypeName(t) FROM (SELECT if(number <= 100, NULL, number) AS n from numbers(5, 10));
SELECT argMin((n, n), n) t, toTypeName(t) FROM (SELECT if(number % 5 == 0, NULL, number) as n from numbers(5, 10));
SELECT argMin((n, n), n) t, toTypeName(t) FROM (SELECT if(number >= 100, NULL, number::Int32) AS n from numbers(5, 10));
SELECT argMin((n, n), n) t, toTypeName(t) FROM (SELECT if(number <= 100, NULL, number::Int32) AS n from numbers(5, 10));
SELECT argMin((n, n), n) t, toTypeName(t) FROM (SELECT if(number % 5 == 0, NULL, number::Int32) as n from numbers(5, 10));
SELECT argMaxIf((n, n), n, n > 100) t, toTypeName(t) FROM (SELECT if(number % 3 = 0, NULL, number) AS n from numbers(50));
SELECT argMaxIf((n, n), n, n < 100) t, toTypeName(t) FROM (SELECT if(number % 3 = 0, NULL, number) AS n from numbers(50));
SELECT argMaxIf((n, n), n, n % 5 == 0) t, toTypeName(t) FROM (SELECT if(number % 3 = 0, NULL, number) AS n from numbers(50));