diff --git a/src/AggregateFunctions/AggregateFunctionArray.cpp b/src/AggregateFunctions/AggregateFunctionArray.cpp index 5ec41fbdd82..982180ab50c 100644 --- a/src/AggregateFunctions/AggregateFunctionArray.cpp +++ b/src/AggregateFunctions/AggregateFunctionArray.cpp @@ -43,9 +43,9 @@ public: const AggregateFunctionPtr & nested_function, const AggregateFunctionProperties &, const DataTypes & arguments, - const Array &) const override + const Array & params) const override { - return std::make_shared(nested_function, arguments); + return std::make_shared(nested_function, arguments, params); } }; diff --git a/src/AggregateFunctions/AggregateFunctionArray.h b/src/AggregateFunctions/AggregateFunctionArray.h index f1005e2e43a..e6f2b46c67e 100644 --- a/src/AggregateFunctions/AggregateFunctionArray.h +++ b/src/AggregateFunctions/AggregateFunctionArray.h @@ -29,10 +29,11 @@ private: size_t num_arguments; public: - AggregateFunctionArray(AggregateFunctionPtr nested_, const DataTypes & arguments) - : IAggregateFunctionHelper(arguments, {}) + AggregateFunctionArray(AggregateFunctionPtr nested_, const DataTypes & arguments, const Array & params_) + : IAggregateFunctionHelper(arguments, params_) , nested_func(nested_), num_arguments(arguments.size()) { + assert(parameters == nested_func->getParameters()); for (const auto & type : arguments) if (!isArray(type)) throw Exception("All arguments for aggregate function " + getName() + " must be arrays", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); diff --git a/src/AggregateFunctions/AggregateFunctionDistinct.cpp b/src/AggregateFunctions/AggregateFunctionDistinct.cpp index d5e4d421bb1..f224768991b 100644 --- a/src/AggregateFunctions/AggregateFunctionDistinct.cpp +++ b/src/AggregateFunctions/AggregateFunctionDistinct.cpp @@ -34,14 +34,14 @@ public: const AggregateFunctionPtr & nested_function, const AggregateFunctionProperties &, const DataTypes & arguments, - const Array &) const override + const Array & params) const override { AggregateFunctionPtr res; if (arguments.size() == 1) { res.reset(createWithNumericType< AggregateFunctionDistinct, - AggregateFunctionDistinctSingleNumericData>(*arguments[0], nested_function, arguments)); + AggregateFunctionDistinctSingleNumericData>(*arguments[0], nested_function, arguments, params)); if (res) return res; @@ -49,14 +49,14 @@ public: if (arguments[0]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion()) return std::make_shared< AggregateFunctionDistinct< - AggregateFunctionDistinctSingleGenericData>>(nested_function, arguments); + AggregateFunctionDistinctSingleGenericData>>(nested_function, arguments, params); else return std::make_shared< AggregateFunctionDistinct< - AggregateFunctionDistinctSingleGenericData>>(nested_function, arguments); + AggregateFunctionDistinctSingleGenericData>>(nested_function, arguments, params); } - return std::make_shared>(nested_function, arguments); + return std::make_shared>(nested_function, arguments, params); } }; diff --git a/src/AggregateFunctions/AggregateFunctionDistinct.h b/src/AggregateFunctions/AggregateFunctionDistinct.h index 9b7853f8665..0f085423bb9 100644 --- a/src/AggregateFunctions/AggregateFunctionDistinct.h +++ b/src/AggregateFunctions/AggregateFunctionDistinct.h @@ -167,8 +167,8 @@ private: } public: - AggregateFunctionDistinct(AggregateFunctionPtr nested_func_, const DataTypes & arguments) - : IAggregateFunctionDataHelper(arguments, nested_func_->getParameters()) + AggregateFunctionDistinct(AggregateFunctionPtr nested_func_, const DataTypes & arguments, const Array & params_) + : IAggregateFunctionDataHelper(arguments, params_) , nested_func(nested_func_) , arguments_num(arguments.size()) {} diff --git a/src/AggregateFunctions/AggregateFunctionForEach.cpp b/src/AggregateFunctions/AggregateFunctionForEach.cpp index 7b09c7d95da..cf448d602bf 100644 --- a/src/AggregateFunctions/AggregateFunctionForEach.cpp +++ b/src/AggregateFunctions/AggregateFunctionForEach.cpp @@ -38,9 +38,9 @@ public: const AggregateFunctionPtr & nested_function, const AggregateFunctionProperties &, const DataTypes & arguments, - const Array &) const override + const Array & params) const override { - return std::make_shared(nested_function, arguments); + return std::make_shared(nested_function, arguments, params); } }; diff --git a/src/AggregateFunctions/AggregateFunctionForEach.h b/src/AggregateFunctions/AggregateFunctionForEach.h index 66209d8c0f5..084396b2405 100644 --- a/src/AggregateFunctions/AggregateFunctionForEach.h +++ b/src/AggregateFunctions/AggregateFunctionForEach.h @@ -105,8 +105,8 @@ private: } public: - AggregateFunctionForEach(AggregateFunctionPtr nested_, const DataTypes & arguments) - : IAggregateFunctionDataHelper(arguments, {}) + AggregateFunctionForEach(AggregateFunctionPtr nested_, const DataTypes & arguments, const Array & params_) + : IAggregateFunctionDataHelper(arguments, params_) , nested_func(nested_), num_arguments(arguments.size()) { nested_size_of_data = nested_func->sizeOfData(); diff --git a/src/AggregateFunctions/AggregateFunctionGroupUniqArray.cpp b/src/AggregateFunctions/AggregateFunctionGroupUniqArray.cpp index 646d0341343..7709357189c 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupUniqArray.cpp +++ b/src/AggregateFunctions/AggregateFunctionGroupUniqArray.cpp @@ -25,8 +25,8 @@ template class AggregateFunctionGroupUniqArrayDate : public AggregateFunctionGroupUniqArray { public: - explicit AggregateFunctionGroupUniqArrayDate(const DataTypePtr & argument_type, UInt64 max_elems_ = std::numeric_limits::max()) - : AggregateFunctionGroupUniqArray(argument_type, max_elems_) {} + explicit AggregateFunctionGroupUniqArrayDate(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits::max()) + : AggregateFunctionGroupUniqArray(argument_type, parameters_, max_elems_) {} DataTypePtr getReturnType() const override { return std::make_shared(std::make_shared()); } }; @@ -34,8 +34,8 @@ template class AggregateFunctionGroupUniqArrayDateTime : public AggregateFunctionGroupUniqArray { public: - explicit AggregateFunctionGroupUniqArrayDateTime(const DataTypePtr & argument_type, UInt64 max_elems_ = std::numeric_limits::max()) - : AggregateFunctionGroupUniqArray(argument_type, max_elems_) {} + explicit AggregateFunctionGroupUniqArrayDateTime(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits::max()) + : AggregateFunctionGroupUniqArray(argument_type, parameters_, max_elems_) {} DataTypePtr getReturnType() const override { return std::make_shared(std::make_shared()); } }; @@ -102,9 +102,9 @@ AggregateFunctionPtr createAggregateFunctionGroupUniqArray( ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); if (!limit_size) - return createAggregateFunctionGroupUniqArrayImpl(name, argument_types[0]); + return createAggregateFunctionGroupUniqArrayImpl(name, argument_types[0], parameters); else - return createAggregateFunctionGroupUniqArrayImpl(name, argument_types[0], max_elems); + return createAggregateFunctionGroupUniqArrayImpl(name, argument_types[0], parameters, max_elems); } } diff --git a/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h b/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h index ccba789483f..cec160ee21f 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h +++ b/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h @@ -48,9 +48,9 @@ private: using State = AggregateFunctionGroupUniqArrayData; public: - AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, UInt64 max_elems_ = std::numeric_limits::max()) + AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits::max()) : IAggregateFunctionDataHelper, - AggregateFunctionGroupUniqArray>({argument_type}, {}), + AggregateFunctionGroupUniqArray>({argument_type}, parameters_), max_elems(max_elems_) {} String getName() const override { return "groupUniqArray"; } @@ -152,8 +152,8 @@ class AggregateFunctionGroupUniqArrayGeneric using State = AggregateFunctionGroupUniqArrayGenericData; public: - AggregateFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type_, UInt64 max_elems_ = std::numeric_limits::max()) - : IAggregateFunctionDataHelper>({input_data_type_}, {}) + AggregateFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits::max()) + : IAggregateFunctionDataHelper>({input_data_type_}, parameters_) , input_data_type(this->argument_types[0]) , max_elems(max_elems_) {} diff --git a/src/AggregateFunctions/AggregateFunctionIf.cpp b/src/AggregateFunctions/AggregateFunctionIf.cpp index c074daf45be..d841fe8c06d 100644 --- a/src/AggregateFunctions/AggregateFunctionIf.cpp +++ b/src/AggregateFunctions/AggregateFunctionIf.cpp @@ -35,9 +35,9 @@ public: const AggregateFunctionPtr & nested_function, const AggregateFunctionProperties &, const DataTypes & arguments, - const Array &) const override + const Array & params) const override { - return std::make_shared(nested_function, arguments); + return std::make_shared(nested_function, arguments, params); } }; diff --git a/src/AggregateFunctions/AggregateFunctionIf.h b/src/AggregateFunctions/AggregateFunctionIf.h index 153c80e87b2..79999437ca1 100644 --- a/src/AggregateFunctions/AggregateFunctionIf.h +++ b/src/AggregateFunctions/AggregateFunctionIf.h @@ -37,8 +37,8 @@ private: size_t num_arguments; public: - AggregateFunctionIf(AggregateFunctionPtr nested, const DataTypes & types) - : IAggregateFunctionHelper(types, nested->getParameters()) + AggregateFunctionIf(AggregateFunctionPtr nested, const DataTypes & types, const Array & params_) + : IAggregateFunctionHelper(types, params_) , nested_func(nested), num_arguments(types.size()) { if (num_arguments == 0) diff --git a/src/AggregateFunctions/AggregateFunctionMerge.cpp b/src/AggregateFunctions/AggregateFunctionMerge.cpp index a19a21fd4a4..cdf399585f5 100644 --- a/src/AggregateFunctions/AggregateFunctionMerge.cpp +++ b/src/AggregateFunctions/AggregateFunctionMerge.cpp @@ -39,7 +39,7 @@ public: const AggregateFunctionPtr & nested_function, const AggregateFunctionProperties &, const DataTypes & arguments, - const Array &) const override + const Array & params) const override { const DataTypePtr & argument = arguments[0]; @@ -53,7 +53,7 @@ public: + ", because it corresponds to different aggregate function: " + function->getFunctionName() + " instead of " + nested_function->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - return std::make_shared(nested_function, argument); + return std::make_shared(nested_function, argument, params); } }; diff --git a/src/AggregateFunctions/AggregateFunctionMerge.h b/src/AggregateFunctions/AggregateFunctionMerge.h index 78e5c92c917..af9257d3c57 100644 --- a/src/AggregateFunctions/AggregateFunctionMerge.h +++ b/src/AggregateFunctions/AggregateFunctionMerge.h @@ -29,8 +29,8 @@ private: AggregateFunctionPtr nested_func; public: - AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument) - : IAggregateFunctionHelper({argument}, nested_->getParameters()) + AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument, const Array & params_) + : IAggregateFunctionHelper({argument}, params_) , nested_func(nested_) { const DataTypeAggregateFunction * data_type = typeid_cast(argument.get()); diff --git a/tests/queries/0_stateless/01156_pcg_deserialization.reference b/tests/queries/0_stateless/01156_pcg_deserialization.reference index e43b7ca3ceb..a41bc53d840 100644 --- a/tests/queries/0_stateless/01156_pcg_deserialization.reference +++ b/tests/queries/0_stateless/01156_pcg_deserialization.reference @@ -1,3 +1,6 @@ 5 5 5 5 5 5 +5 5 +5 5 +5 5 diff --git a/tests/queries/0_stateless/01156_pcg_deserialization.sh b/tests/queries/0_stateless/01156_pcg_deserialization.sh index 9c8ac29f32e..00ef86dce9c 100755 --- a/tests/queries/0_stateless/01156_pcg_deserialization.sh +++ b/tests/queries/0_stateless/01156_pcg_deserialization.sh @@ -4,16 +4,20 @@ CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) # shellcheck source=../shell_config.sh . "$CURDIR"/../shell_config.sh +declare -a functions=("groupArraySample" "groupUniqArray") declare -a engines=("Memory" "MergeTree order by n" "Log") -for engine in "${engines[@]}" +for func in "${functions[@]}" do - $CLICKHOUSE_CLIENT -q "drop table if exists t"; - $CLICKHOUSE_CLIENT -q "create table t (n UInt8, a1 AggregateFunction(groupArraySample(1), UInt8)) engine=$engine" - $CLICKHOUSE_CLIENT -q "insert into t select number % 5 as n, groupArraySampleState(1)(toUInt8(number)) from numbers(10) group by n" + for engine in "${engines[@]}" + do + $CLICKHOUSE_CLIENT -q "drop table if exists t"; + $CLICKHOUSE_CLIENT -q "create table t (n UInt8, a1 AggregateFunction($func(1), UInt8)) engine=$engine" + $CLICKHOUSE_CLIENT -q "insert into t select number % 5 as n, ${func}State(1)(toUInt8(number)) from numbers(10) group by n" - $CLICKHOUSE_CLIENT -q "select * from t format TSV" | $CLICKHOUSE_CLIENT -q "insert into t format TSV" - $CLICKHOUSE_CLIENT -q "select countDistinct(n), countDistinct(a1) from t" + $CLICKHOUSE_CLIENT -q "select * from t format TSV" | $CLICKHOUSE_CLIENT -q "insert into t format TSV" + $CLICKHOUSE_CLIENT -q "select countDistinct(n), countDistinct(a1) from t" - $CLICKHOUSE_CLIENT -q "drop table t"; + $CLICKHOUSE_CLIENT -q "drop table t"; + done done diff --git a/tests/queries/0_stateless/01159_combinators_with_parameters.reference b/tests/queries/0_stateless/01159_combinators_with_parameters.reference new file mode 100644 index 00000000000..cc0cb604bf3 --- /dev/null +++ b/tests/queries/0_stateless/01159_combinators_with_parameters.reference @@ -0,0 +1,20 @@ +AggregateFunction(topKArray(10), Array(String)) +AggregateFunction(topKDistinct(10), String) +AggregateFunction(topKForEach(10), Array(String)) +AggregateFunction(topKIf(10), String, UInt8) +AggregateFunction(topK(10), String) +AggregateFunction(topKOrNull(10), String) +AggregateFunction(topKOrDefault(10), String) +AggregateFunction(topKResample(10, 1, 2, 42), String, UInt64) +AggregateFunction(topK(10), String) +AggregateFunction(topKArrayResampleOrDefaultIf(10, 1, 2, 42), Array(String), UInt64, UInt8) +10 +10 +[10] +11 +10 +10 +10 +[1] +10 +[1] diff --git a/tests/queries/0_stateless/01159_combinators_with_parameters.sql b/tests/queries/0_stateless/01159_combinators_with_parameters.sql new file mode 100644 index 00000000000..69508d8e304 --- /dev/null +++ b/tests/queries/0_stateless/01159_combinators_with_parameters.sql @@ -0,0 +1,43 @@ +SELECT toTypeName(topKArrayState(10)([toString(number)])) FROM numbers(100); +SELECT toTypeName(topKDistinctState(10)(toString(number))) FROM numbers(100); +SELECT toTypeName(topKForEachState(10)([toString(number)])) FROM numbers(100); +SELECT toTypeName(topKIfState(10)(toString(number), number % 2)) FROM numbers(100); +SELECT toTypeName(topKMergeState(10)(state)) FROM (SELECT topKState(10)(toString(number)) as state FROM numbers(100)); +SELECT toTypeName(topKOrNullState(10)(toString(number))) FROM numbers(100); +SELECT toTypeName(topKOrDefaultState(10)(toString(number))) FROM numbers(100); +SELECT toTypeName(topKResampleState(10, 1, 2, 42)(toString(number), number)) FROM numbers(100); +SELECT toTypeName(topKState(10)(toString(number))) FROM numbers(100); +SELECT toTypeName(topKArrayResampleOrDefaultIfState(10, 1, 2, 42)([toString(number)], number, number % 2)) FROM numbers(100); + +CREATE TEMPORARY TABLE t0 AS SELECT quantileArrayState(0.10)([number]) FROM numbers(100); +CREATE TEMPORARY TABLE t1 AS SELECT quantileDistinctState(0.10)(number) FROM numbers(100); +CREATE TEMPORARY TABLE t2 AS SELECT quantileForEachState(0.10)([number]) FROM numbers(100); +CREATE TEMPORARY TABLE t3 AS SELECT quantileIfState(0.10)(number, number % 2) FROM numbers(100); +CREATE TEMPORARY TABLE t4 AS SELECT quantileMergeState(0.10)(state) FROM (SELECT quantileState(0.10)(number) as state FROM numbers(100)); +CREATE TEMPORARY TABLE t5 AS SELECT quantileOrNullState(0.10)(number) FROM numbers(100); +CREATE TEMPORARY TABLE t6 AS SELECT quantileOrDefaultState(0.10)(number) FROM numbers(100); +CREATE TEMPORARY TABLE t7 AS SELECT quantileResampleState(0.10, 1, 2, 42)(number, number) FROM numbers(100); +CREATE TEMPORARY TABLE t8 AS SELECT quantileState(0.10)(number) FROM numbers(100); +CREATE TEMPORARY TABLE t9 AS SELECT quantileArrayResampleOrDefaultIfState(0.10, 1, 2, 42)([number], number, number % 2) FROM numbers(100); + +INSERT INTO t0 SELECT quantileArrayState(0.10)([number]) FROM numbers(100); +INSERT INTO t1 SELECT quantileDistinctState(0.10)(number) FROM numbers(100); +INSERT INTO t2 SELECT quantileForEachState(0.10)([number]) FROM numbers(100); +INSERT INTO t3 SELECT quantileIfState(0.10)(number, number % 2) FROM numbers(100); +INSERT INTO t4 SELECT quantileMergeState(0.10)(state) FROM (SELECT quantileState(0.10)(number) as state FROM numbers(100)); +INSERT INTO t5 SELECT quantileOrNullState(0.10)(number) FROM numbers(100); +INSERT INTO t6 SELECT quantileOrDefaultState(0.10)(number) FROM numbers(100); +INSERT INTO t7 SELECT quantileResampleState(0.10, 1, 2, 42)(number, number) FROM numbers(100); +INSERT INTO t8 SELECT quantileState(0.10)(number) FROM numbers(100); +INSERT INTO t9 SELECT quantileArrayResampleOrDefaultIfState(0.10, 1, 2, 42)([number], number, number % 2) FROM numbers(100); + +SELECT round(quantileArrayMerge(0.10)((*,).1)) FROM t0; +SELECT round(quantileDistinctMerge(0.10)((*,).1)) FROM t1; +SELECT arrayMap(x -> round(x), quantileForEachMerge(0.10)((*,).1)) FROM t2; +SELECT round(quantileIfMerge(0.10)((*,).1)) FROM t3; +SELECT round(quantileMerge(0.10)((*,).1)) FROM t4; +SELECT round(quantileOrNullMerge(0.10)((*,).1)) FROM t5; +SELECT round(quantileOrDefaultMerge(0.10)((*,).1)) FROM t6; +SELECT arrayMap(x -> round(x), quantileResampleMerge(0.10, 1, 2, 42)((*,).1)) FROM t7; +SELECT round(quantileMerge(0.10)((*,).1)) FROM t8; +SELECT arrayMap(x -> round(x), quantileArrayResampleOrDefaultIfMerge(0.10, 1, 2, 42)((*,).1)) FROM t9;