From b1263c18ee15f0aa40c5f94dee11c7e1e0e27e74 Mon Sep 17 00:00:00 2001 From: tavplubix Date: Thu, 24 Jun 2021 10:40:00 +0300 Subject: [PATCH] Fix pcg deserialization (#24538) * fix pcg deserialization * Update 01156_pcg_deserialization.sh * Update 01156_pcg_deserialization.sh * Update 01156_pcg_deserialization.sh * fix another bug Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: alexey-milovidov --- .../AggregateFunctionGroupArray.cpp | 14 +++++++------- .../AggregateFunctionGroupArray.h | 12 ++++++------ src/IO/ReadHelpers.h | 4 +++- .../01156_pcg_deserialization.reference | 3 +++ .../0_stateless/01156_pcg_deserialization.sh | 19 +++++++++++++++++++ 5 files changed, 38 insertions(+), 14 deletions(-) create mode 100644 tests/queries/0_stateless/01156_pcg_deserialization.reference create mode 100755 tests/queries/0_stateless/01156_pcg_deserialization.sh diff --git a/src/AggregateFunctions/AggregateFunctionGroupArray.cpp b/src/AggregateFunctions/AggregateFunctionGroupArray.cpp index 73039dc4dec..5a9fd778277 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupArray.cpp +++ b/src/AggregateFunctions/AggregateFunctionGroupArray.cpp @@ -30,16 +30,16 @@ static IAggregateFunction * createWithNumericOrTimeType(const IDataType & argume template -inline AggregateFunctionPtr createAggregateFunctionGroupArrayImpl(const DataTypePtr & argument_type, TArgs ... args) +inline AggregateFunctionPtr createAggregateFunctionGroupArrayImpl(const DataTypePtr & argument_type, const Array & parameters, TArgs ... args) { - if (auto res = createWithNumericOrTimeType(*argument_type, argument_type, std::forward(args)...)) + if (auto res = createWithNumericOrTimeType(*argument_type, argument_type, parameters, std::forward(args)...)) return AggregateFunctionPtr(res); WhichDataType which(argument_type); if (which.idx == TypeIndex::String) - return std::make_shared>(argument_type, std::forward(args)...); + return std::make_shared>(argument_type, parameters, std::forward(args)...); - return std::make_shared>(argument_type, std::forward(args)...); + return std::make_shared>(argument_type, parameters, std::forward(args)...); // Link list implementation doesn't show noticeable performance improvement // if (which.idx == TypeIndex::String) @@ -79,9 +79,9 @@ AggregateFunctionPtr createAggregateFunctionGroupArray( ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); if (!limit_size) - return createAggregateFunctionGroupArrayImpl>(argument_types[0]); + return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters); else - return createAggregateFunctionGroupArrayImpl>(argument_types[0], max_elems); + return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters, max_elems); } AggregateFunctionPtr createAggregateFunctionGroupArraySample( @@ -114,7 +114,7 @@ AggregateFunctionPtr createAggregateFunctionGroupArraySample( else seed = thread_local_rng(); - return createAggregateFunctionGroupArrayImpl>(argument_types[0], max_elems, seed); + return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters, max_elems, seed); } } diff --git a/src/AggregateFunctions/AggregateFunctionGroupArray.h b/src/AggregateFunctions/AggregateFunctionGroupArray.h index 06292992a2f..a78ce89ce5a 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupArray.h +++ b/src/AggregateFunctions/AggregateFunctionGroupArray.h @@ -119,9 +119,9 @@ class GroupArrayNumericImpl final public: explicit GroupArrayNumericImpl( - const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits::max(), UInt64 seed_ = 123456) + const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits::max(), UInt64 seed_ = 123456) : IAggregateFunctionDataHelper, GroupArrayNumericImpl>( - {data_type_}, {}) + {data_type_}, parameters_) , max_elems(max_elems_) , seed(seed_) { @@ -421,9 +421,9 @@ class GroupArrayGeneralImpl final UInt64 seed; public: - GroupArrayGeneralImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits::max(), UInt64 seed_ = 123456) + GroupArrayGeneralImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits::max(), UInt64 seed_ = 123456) : IAggregateFunctionDataHelper, GroupArrayGeneralImpl>( - {data_type_}, {}) + {data_type_}, parameters_) , data_type(this->argument_types[0]) , max_elems(max_elems_) , seed(seed_) @@ -696,8 +696,8 @@ class GroupArrayGeneralListImpl final UInt64 max_elems; public: - GroupArrayGeneralListImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits::max()) - : IAggregateFunctionDataHelper, GroupArrayGeneralListImpl>({data_type_}, {}) + GroupArrayGeneralListImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits::max()) + : IAggregateFunctionDataHelper, GroupArrayGeneralListImpl>({data_type_}, parameters_) , data_type(this->argument_types[0]) , max_elems(max_elems_) { diff --git a/src/IO/ReadHelpers.h b/src/IO/ReadHelpers.h index a772d4ccd69..ffcfeea3827 100644 --- a/src/IO/ReadHelpers.h +++ b/src/IO/ReadHelpers.h @@ -1248,7 +1248,7 @@ bool loadAtPosition(ReadBuffer & in, Memory<> & memory, char * & current); struct PcgDeserializer { - static void deserializePcg32(const pcg32_fast & rng, ReadBuffer & buf) + static void deserializePcg32(pcg32_fast & rng, ReadBuffer & buf) { decltype(rng.state_) multiplier, increment, state; readText(multiplier, buf); @@ -1261,6 +1261,8 @@ struct PcgDeserializer throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect multiplier in pcg32: expected {}, got {}", rng.multiplier(), multiplier); if (increment != rng.increment()) throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect increment in pcg32: expected {}, got {}", rng.increment(), increment); + + rng.state_ = state; } }; diff --git a/tests/queries/0_stateless/01156_pcg_deserialization.reference b/tests/queries/0_stateless/01156_pcg_deserialization.reference new file mode 100644 index 00000000000..e43b7ca3ceb --- /dev/null +++ b/tests/queries/0_stateless/01156_pcg_deserialization.reference @@ -0,0 +1,3 @@ +5 5 +5 5 +5 5 diff --git a/tests/queries/0_stateless/01156_pcg_deserialization.sh b/tests/queries/0_stateless/01156_pcg_deserialization.sh new file mode 100755 index 00000000000..9c8ac29f32e --- /dev/null +++ b/tests/queries/0_stateless/01156_pcg_deserialization.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +declare -a engines=("Memory" "MergeTree order by n" "Log") + +for engine in "${engines[@]}" +do + $CLICKHOUSE_CLIENT -q "drop table if exists t"; + $CLICKHOUSE_CLIENT -q "create table t (n UInt8, a1 AggregateFunction(groupArraySample(1), UInt8)) engine=$engine" + $CLICKHOUSE_CLIENT -q "insert into t select number % 5 as n, groupArraySampleState(1)(toUInt8(number)) from numbers(10) group by n" + + $CLICKHOUSE_CLIENT -q "select * from t format TSV" | $CLICKHOUSE_CLIENT -q "insert into t format TSV" + $CLICKHOUSE_CLIENT -q "select countDistinct(n), countDistinct(a1) from t" + + $CLICKHOUSE_CLIENT -q "drop table t"; +done