diff --git a/src/Columns/ColumnAggregateFunction.cpp b/src/Columns/ColumnAggregateFunction.cpp index e26fe790a8e..4bc48c62eb4 100644 --- a/src/Columns/ColumnAggregateFunction.cpp +++ b/src/Columns/ColumnAggregateFunction.cpp @@ -330,7 +330,38 @@ ColumnPtr ColumnAggregateFunction::filter(const Filter & filter, ssize_t result_ void ColumnAggregateFunction::expand(const Filter & mask, bool inverted) { - expandDataByMask(data, mask, inverted); + ensureOwnership(); + Arena & arena = createOrGetArena(); + + if (mask.size() < data.size()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Mask size should be no less than data size."); + + ssize_t from = data.size() - 1; + ssize_t index = mask.size() - 1; + data.resize(mask.size()); + while (index >= 0) + { + if (!!mask[index] ^ inverted) + { + if (from < 0) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Too many bytes in mask"); + + /// Copy only if it makes sense. + if (index != from) + data[index] = data[from]; + --from; + } + else + { + data[index] = arena.alignedAlloc(func->sizeOfData(), func->alignOfData()); + func->create(data[index]); + } + + --index; + } + + if (from != -1) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Not enough bytes in mask"); } ColumnPtr ColumnAggregateFunction::permute(const Permutation & perm, size_t limit) const diff --git a/tests/queries/0_stateless/03210_variant_with_aggregate_function_type.reference b/tests/queries/0_stateless/03210_variant_with_aggregate_function_type.reference new file mode 100644 index 00000000000..105e8e7d8bd --- /dev/null +++ b/tests/queries/0_stateless/03210_variant_with_aggregate_function_type.reference @@ -0,0 +1,6 @@ + 500 +fail 500 + 499 +fail 500 + 500 499 +fail 500 500 diff --git a/tests/queries/0_stateless/03210_variant_with_aggregate_function_type.sql b/tests/queries/0_stateless/03210_variant_with_aggregate_function_type.sql new file mode 100644 index 00000000000..cb9cdb0b456 --- /dev/null +++ b/tests/queries/0_stateless/03210_variant_with_aggregate_function_type.sql @@ -0,0 +1,60 @@ +SET allow_experimental_variant_type = 1; + +DROP TABLE IF EXISTS source; +CREATE TABLE source +( + Name String, + Value Int64 + +) ENGINE = MergeTree ORDER BY (); + +INSERT INTO source SELECT ['fail', 'success'][number % 2] as Name, number AS Value FROM numbers(1000); + +DROP TABLE IF EXISTS test_agg_variant; +CREATE TABLE test_agg_variant +( + Name String, + Value Variant(AggregateFunction(uniqExact, Int64), AggregateFunction(avg, Int64)) +) +ENGINE = MergeTree +ORDER BY (Name); + +INSERT INTO test_agg_variant +SELECT + Name, + t AS Value +FROM +( + SELECT + Name, + arrayJoin([ + uniqExactState(Value)::Variant(AggregateFunction(uniqExact, Int64), AggregateFunction(avg, Int64)), + avgState(Value)::Variant(AggregateFunction(uniqExact, Int64), AggregateFunction(avg, Int64)) + ]) AS t + FROM source + GROUP BY Name +); + +SELECT + Name, + uniqExactMerge(Value.`AggregateFunction(uniqExact, Int64)`) AS Value +FROM test_agg_variant +GROUP BY Name; + +SELECT + Name, + avgMerge(Value.`AggregateFunction(avg, Int64)`) AS Value +FROM test_agg_variant +GROUP BY Name; + +SELECT + Name, + uniqExactMerge(Value.`AggregateFunction(uniqExact, Int64)`) AS ValueUniq, + avgMerge(Value.`AggregateFunction(avg, Int64)`) AS ValueAvg +FROM test_agg_variant +GROUP BY Name; + + +DROP TABLE test_agg_variant; +DROP TABLE source; +