Implement batch processing for aggregate functions with multiple nullable arguments (#41058)

* Implement batch processing for aggregate functions with multiple nullable arguments

* Fix broken perf test

* Improve filter handling in addBatchSinglePlace with nullable arguments

* Fix detecting the Null filter usage
This commit is contained in:
Raúl Marín 2022-09-15 23:51:38 +02:00 committed by GitHub
parent 6dac509739
commit c3ff66bd9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 290 additions and 0 deletions

View File

@ -278,6 +278,71 @@ public:
}
}
void addBatchSinglePlace(
size_t row_begin, size_t row_end, AggregateDataPtr __restrict place, const IColumn ** columns, Arena * arena, ssize_t) const final
{
std::unique_ptr<UInt8[]> final_null_flags = std::make_unique<UInt8[]>(row_end);
const size_t filter_column_num = number_of_arguments - 1;
if (is_nullable[filter_column_num])
{
const ColumnNullable * nullable_column = assert_cast<const ColumnNullable *>(columns[filter_column_num]);
const IColumn & filter_column = nullable_column->getNestedColumn();
const UInt8 * filter_null_map = nullable_column->getNullMapColumn().getData().data();
const UInt8 * filter_values = assert_cast<const ColumnUInt8 &>(filter_column).getData().data();
for (size_t i = row_begin; i < row_end; i++)
{
final_null_flags[i] = (null_is_skipped && filter_null_map[i]) || !filter_values[i];
}
}
else
{
const IColumn * filter_column = columns[filter_column_num];
const UInt8 * filter_values = assert_cast<const ColumnUInt8 *>(filter_column)->getData().data();
for (size_t i = row_begin; i < row_end; i++)
final_null_flags[i] = !filter_values[i];
}
const IColumn * nested_columns[number_of_arguments];
for (size_t arg = 0; arg < number_of_arguments; arg++)
{
if (is_nullable[arg])
{
const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[arg]);
if (null_is_skipped && (arg != filter_column_num))
{
const ColumnUInt8 & nullmap_column = nullable_col.getNullMapColumn();
const UInt8 * col_null_map = nullmap_column.getData().data();
for (size_t r = row_begin; r < row_end; r++)
{
final_null_flags[r] |= col_null_map[r];
}
}
nested_columns[arg] = &nullable_col.getNestedColumn();
}
else
nested_columns[arg] = columns[arg];
}
bool at_least_one = false;
for (size_t i = row_begin; i < row_end; i++)
{
if (!final_null_flags[i])
{
at_least_one = true;
break;
}
}
if (at_least_one)
{
this->setFlag(place);
this->nested_function->addBatchSinglePlaceNotNull(
row_begin, row_end, this->nestedPlace(place), nested_columns, final_null_flags.get(), arena, -1);
}
}
#if USE_EMBEDDED_COMPILER
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override

View File

@ -414,6 +414,109 @@ public:
this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
}
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** columns,
Arena * arena,
ssize_t if_argument_pos) const final
{
/// We are going to merge all the flags into a single one to be able to call the nested batching functions
std::vector<const UInt8 *> nullable_filters;
const IColumn * nested_columns[number_of_arguments];
std::unique_ptr<UInt8[]> final_flags = nullptr;
const UInt8 * final_flags_ptr = nullptr;
if (if_argument_pos >= 0)
{
final_flags = std::make_unique<UInt8[]>(row_end);
final_flags_ptr = final_flags.get();
bool included_elements = 0;
const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
for (size_t i = row_begin; i < row_end; i++)
{
final_flags[i] = !flags.data()[i];
included_elements += !!flags.data()[i];
}
if (included_elements == 0)
return;
if (included_elements != (row_end - row_begin))
{
nullable_filters.push_back(final_flags_ptr);
}
}
for (size_t i = 0; i < number_of_arguments; ++i)
{
if (is_nullable[i])
{
const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[i]);
nested_columns[i] = &nullable_col.getNestedColumn();
if constexpr (null_is_skipped)
{
const ColumnUInt8 & nullmap_column = nullable_col.getNullMapColumn();
nullable_filters.push_back(nullmap_column.getData().data());
}
}
else
{
nested_columns[i] = columns[i];
}
}
bool found_one = false;
chassert(nullable_filters.size() > 0); /// We work under the assumption that we reach this because one argument was NULL
if (nullable_filters.size() == 1)
{
/// We can avoid making copies of the only filter but we still need to check that there is data to be added
final_flags_ptr = nullable_filters[0];
for (size_t i = row_begin; i < row_end; i++)
{
if (!final_flags_ptr[i])
{
found_one = true;
break;
}
}
}
else
{
if (!final_flags)
{
final_flags = std::make_unique<UInt8[]>(row_end);
final_flags_ptr = final_flags.get();
}
const size_t filter_start = nullable_filters[0] == final_flags_ptr ? 1 : 0;
for (size_t filter = filter_start; filter < nullable_filters.size(); filter++)
{
for (size_t i = row_begin; i < row_end; i++)
final_flags[i] |= nullable_filters[filter][i];
}
for (size_t i = row_begin; i < row_end; i++)
{
if (!final_flags_ptr[i])
{
found_one = true;
break;
}
}
}
if (!found_one)
return; // Nothing to do and nothing to mark
this->setFlag(place);
this->nested_function->addBatchSinglePlaceNotNull(
row_begin, row_end, this->nestedPlace(place), nested_columns, final_flags_ptr, arena, -1);
}
#if USE_EMBEDDED_COMPILER

View File

@ -32,5 +32,21 @@
<query>SELECT avgWeighted(num_u, num) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeighted(num_u, num_u) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeighted(num_f, num_f) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeighted(toNullable(num_f), num_f) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeighted(num_f, toNullable(num_f)) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeighted(toNullable(num_f), toNullable(num_f)) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeightedIf(num_f, num_f, num % 10) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeightedIf(toNullable(num_f), num_f, num % 10) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeightedIf(num_f, toNullable(num_f), num % 10) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeightedIf(toNullable(num_f), toNullable(num_f), num % 10) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeightedIf(num_f, num_f, toNullable(num) % 10) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeightedIf(toNullable(num_f), num_f, toNullable(num) % 10) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeightedIf(num_f, toNullable(num_f), toNullable(num) % 10) FROM perf_avg FORMAT Null</query>
<query>SELECT avgWeightedIf(toNullable(num_f), toNullable(num_f), toNullable(num) % 10) FROM perf_avg FORMAT Null</query>
<drop_query>DROP TABLE IF EXISTS perf_avg</drop_query>
</test>

View File

@ -0,0 +1,65 @@
-- { echo }
SELECT avgWeighted(number, number) t, toTypeName(t) FROM numbers(1);
nan Float64
SELECT avgWeighted(number, number + 1) t, toTypeName(t) FROM numbers(0);
nan Float64
SELECT avgWeighted(toNullable(number), number) t, toTypeName(t) FROM numbers(1);
nan Nullable(Float64)
SELECT avgWeighted(if(number < 10000, NULL, number), number) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeighted(if(number < 50, NULL, number), number) t, toTypeName(t) FROM numbers(100);
77.29530201342281 Nullable(Float64)
SELECT avgWeighted(number, if(number < 10000, NULL, number)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeighted(number, if(number < 50, NULL, number)) t, toTypeName(t) FROM numbers(100);
77.29530201342281 Nullable(Float64)
SELECT avgWeighted(toNullable(number), if(number < 10000, NULL, number)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeighted(toNullable(number), if(number < 50, NULL, number)) t, toTypeName(t) FROM numbers(100);
77.29530201342281 Nullable(Float64)
SELECT avgWeighted(if(number < 10000, NULL, number), toNullable(number)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeighted(if(number < 50, NULL, number), toNullable(number)) t, toTypeName(t) FROM numbers(100);
77.29530201342281 Nullable(Float64)
SELECT avgWeighted(if(number < 10000, NULL, number), if(number < 10000, NULL, number)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeighted(if(number < 50, NULL, number), if(number < 10000, NULL, number)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeighted(if(number < 10000, NULL, number), if(number < 50, NULL, number)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeighted(if(number < 50, NULL, number), if(number < 50, NULL, number)) t, toTypeName(t) FROM numbers(100);
77.29530201342281 Nullable(Float64)
SELECT avgWeightedIf(number, number, number % 10) t, toTypeName(t) FROM numbers(100);
66.63333333333334 Float64
SELECT avgWeightedIf(number, number, toNullable(number % 10)) t, toTypeName(t) FROM numbers(100);
66.63333333333334 Float64
SELECT avgWeightedIf(number, number, if(number < 10000, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
nan Float64
SELECT avgWeightedIf(number, number, if(number < 50, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
77.75555555555556 Float64
SELECT avgWeightedIf(number, number, if(number < 0, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
66.63333333333334 Float64
SELECT avgWeightedIf(if(number < 10000, NULL, number), if(number < 10000, NULL, number), if(number < 10000, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeightedIf(if(number < 50, NULL, number), if(number < 10000, NULL, number), if(number < 10000, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeightedIf(if(number < 10000, NULL, number), if(number < 50, NULL, number), if(number < 10000, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeightedIf(if(number < 50, NULL, number), if(number < 50, NULL, number), if(number < 10000, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeightedIf(if(number < 10000, NULL, number), if(number < 10000, NULL, number), if(number < 50, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeightedIf(if(number < 50, NULL, number), if(number < 10000, NULL, number), if(number < 50, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeightedIf(if(number < 10000, NULL, number), if(number < 50, NULL, number), if(number < 50, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeightedIf(if(number < 50, NULL, number), if(number < 50, NULL, number), if(number < 50, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
77.75555555555556 Nullable(Float64)
SELECT avgWeightedIf(if(number < 10000, NULL, number), if(number < 10000, NULL, number), if(number < 0, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeightedIf(if(number < 50, NULL, number), if(number < 10000, NULL, number), if(number < 0, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeightedIf(if(number < 10000, NULL, number), if(number < 50, NULL, number), if(number < 0, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
\N Nullable(Float64)
SELECT avgWeightedIf(if(number < 50, NULL, number), if(number < 50, NULL, number), if(number < 0, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
77.75555555555556 Nullable(Float64)

View File

@ -0,0 +1,41 @@
-- { echo }
SELECT avgWeighted(number, number) t, toTypeName(t) FROM numbers(1);
SELECT avgWeighted(number, number + 1) t, toTypeName(t) FROM numbers(0);
SELECT avgWeighted(toNullable(number), number) t, toTypeName(t) FROM numbers(1);
SELECT avgWeighted(if(number < 10000, NULL, number), number) t, toTypeName(t) FROM numbers(100);
SELECT avgWeighted(if(number < 50, NULL, number), number) t, toTypeName(t) FROM numbers(100);
SELECT avgWeighted(number, if(number < 10000, NULL, number)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeighted(number, if(number < 50, NULL, number)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeighted(toNullable(number), if(number < 10000, NULL, number)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeighted(toNullable(number), if(number < 50, NULL, number)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeighted(if(number < 10000, NULL, number), toNullable(number)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeighted(if(number < 50, NULL, number), toNullable(number)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeighted(if(number < 10000, NULL, number), if(number < 10000, NULL, number)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeighted(if(number < 50, NULL, number), if(number < 10000, NULL, number)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeighted(if(number < 10000, NULL, number), if(number < 50, NULL, number)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeighted(if(number < 50, NULL, number), if(number < 50, NULL, number)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(number, number, number % 10) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(number, number, toNullable(number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(number, number, if(number < 10000, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(number, number, if(number < 50, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(number, number, if(number < 0, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(if(number < 10000, NULL, number), if(number < 10000, NULL, number), if(number < 10000, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(if(number < 50, NULL, number), if(number < 10000, NULL, number), if(number < 10000, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(if(number < 10000, NULL, number), if(number < 50, NULL, number), if(number < 10000, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(if(number < 50, NULL, number), if(number < 50, NULL, number), if(number < 10000, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(if(number < 10000, NULL, number), if(number < 10000, NULL, number), if(number < 50, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(if(number < 50, NULL, number), if(number < 10000, NULL, number), if(number < 50, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(if(number < 10000, NULL, number), if(number < 50, NULL, number), if(number < 50, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(if(number < 50, NULL, number), if(number < 50, NULL, number), if(number < 50, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(if(number < 10000, NULL, number), if(number < 10000, NULL, number), if(number < 0, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(if(number < 50, NULL, number), if(number < 10000, NULL, number), if(number < 0, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(if(number < 10000, NULL, number), if(number < 50, NULL, number), if(number < 0, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);
SELECT avgWeightedIf(if(number < 50, NULL, number), if(number < 50, NULL, number), if(number < 0, NULL, number % 10)) t, toTypeName(t) FROM numbers(100);