Fix pcg deserialization (#24538)

* fix pcg deserialization

* Update 01156_pcg_deserialization.sh

* Update 01156_pcg_deserialization.sh

* Update 01156_pcg_deserialization.sh

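* fix another bug (pass the aggregate function parameters through to the groupArray implementations instead of an empty list)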

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: alexey-milovidov <milovidov@yandex-team.ru>
This commit is contained in:
tavplubix 2021-06-24 10:40:00 +03:00 committed by GitHub
parent a6d289c750
commit b1263c18ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 38 additions and 14 deletions

View File

@ -30,16 +30,16 @@ static IAggregateFunction * createWithNumericOrTimeType(const IDataType & argume
template <typename Trait, typename ... TArgs> template <typename Trait, typename ... TArgs>
inline AggregateFunctionPtr createAggregateFunctionGroupArrayImpl(const DataTypePtr & argument_type, TArgs ... args) inline AggregateFunctionPtr createAggregateFunctionGroupArrayImpl(const DataTypePtr & argument_type, const Array & parameters, TArgs ... args)
{ {
if (auto res = createWithNumericOrTimeType<GroupArrayNumericImpl, Trait>(*argument_type, argument_type, std::forward<TArgs>(args)...)) if (auto res = createWithNumericOrTimeType<GroupArrayNumericImpl, Trait>(*argument_type, argument_type, parameters, std::forward<TArgs>(args)...))
return AggregateFunctionPtr(res); return AggregateFunctionPtr(res);
WhichDataType which(argument_type); WhichDataType which(argument_type);
if (which.idx == TypeIndex::String) if (which.idx == TypeIndex::String)
return std::make_shared<GroupArrayGeneralImpl<GroupArrayNodeString, Trait>>(argument_type, std::forward<TArgs>(args)...); return std::make_shared<GroupArrayGeneralImpl<GroupArrayNodeString, Trait>>(argument_type, parameters, std::forward<TArgs>(args)...);
return std::make_shared<GroupArrayGeneralImpl<GroupArrayNodeGeneral, Trait>>(argument_type, std::forward<TArgs>(args)...); return std::make_shared<GroupArrayGeneralImpl<GroupArrayNodeGeneral, Trait>>(argument_type, parameters, std::forward<TArgs>(args)...);
// Link list implementation doesn't show noticeable performance improvement // Link list implementation doesn't show noticeable performance improvement
// if (which.idx == TypeIndex::String) // if (which.idx == TypeIndex::String)
@ -79,9 +79,9 @@ AggregateFunctionPtr createAggregateFunctionGroupArray(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!limit_size) if (!limit_size)
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait<false, Sampler::NONE>>(argument_types[0]); return createAggregateFunctionGroupArrayImpl<GroupArrayTrait<false, Sampler::NONE>>(argument_types[0], parameters);
else else
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait<true, Sampler::NONE>>(argument_types[0], max_elems); return createAggregateFunctionGroupArrayImpl<GroupArrayTrait<true, Sampler::NONE>>(argument_types[0], parameters, max_elems);
} }
AggregateFunctionPtr createAggregateFunctionGroupArraySample( AggregateFunctionPtr createAggregateFunctionGroupArraySample(
@ -114,7 +114,7 @@ AggregateFunctionPtr createAggregateFunctionGroupArraySample(
else else
seed = thread_local_rng(); seed = thread_local_rng();
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait<true, Sampler::RNG>>(argument_types[0], max_elems, seed); return createAggregateFunctionGroupArrayImpl<GroupArrayTrait<true, Sampler::RNG>>(argument_types[0], parameters, max_elems, seed);
} }
} }

View File

@ -119,9 +119,9 @@ class GroupArrayNumericImpl final
public: public:
explicit GroupArrayNumericImpl( explicit GroupArrayNumericImpl(
const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456) const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
: IAggregateFunctionDataHelper<GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>, GroupArrayNumericImpl<T, Trait>>( : IAggregateFunctionDataHelper<GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>, GroupArrayNumericImpl<T, Trait>>(
{data_type_}, {}) {data_type_}, parameters_)
, max_elems(max_elems_) , max_elems(max_elems_)
, seed(seed_) , seed(seed_)
{ {
@ -421,9 +421,9 @@ class GroupArrayGeneralImpl final
UInt64 seed; UInt64 seed;
public: public:
GroupArrayGeneralImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456) GroupArrayGeneralImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
: IAggregateFunctionDataHelper<GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>, GroupArrayGeneralImpl<Node, Trait>>( : IAggregateFunctionDataHelper<GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>, GroupArrayGeneralImpl<Node, Trait>>(
{data_type_}, {}) {data_type_}, parameters_)
, data_type(this->argument_types[0]) , data_type(this->argument_types[0])
, max_elems(max_elems_) , max_elems(max_elems_)
, seed(seed_) , seed(seed_)
@ -696,8 +696,8 @@ class GroupArrayGeneralListImpl final
UInt64 max_elems; UInt64 max_elems;
public: public:
GroupArrayGeneralListImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max()) GroupArrayGeneralListImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, Trait>>({data_type_}, {}) : IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, Trait>>({data_type_}, parameters_)
, data_type(this->argument_types[0]) , data_type(this->argument_types[0])
, max_elems(max_elems_) , max_elems(max_elems_)
{ {

View File

@ -1248,7 +1248,7 @@ bool loadAtPosition(ReadBuffer & in, Memory<> & memory, char * & current);
struct PcgDeserializer struct PcgDeserializer
{ {
static void deserializePcg32(const pcg32_fast & rng, ReadBuffer & buf) static void deserializePcg32(pcg32_fast & rng, ReadBuffer & buf)
{ {
decltype(rng.state_) multiplier, increment, state; decltype(rng.state_) multiplier, increment, state;
readText(multiplier, buf); readText(multiplier, buf);
@ -1261,6 +1261,8 @@ struct PcgDeserializer
throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect multiplier in pcg32: expected {}, got {}", rng.multiplier(), multiplier); throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect multiplier in pcg32: expected {}, got {}", rng.multiplier(), multiplier);
if (increment != rng.increment()) if (increment != rng.increment())
throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect increment in pcg32: expected {}, got {}", rng.increment(), increment); throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect increment in pcg32: expected {}, got {}", rng.increment(), increment);
rng.state_ = state;
} }
}; };

View File

@ -0,0 +1,3 @@
5 5
5 5
5 5

View File

@ -0,0 +1,19 @@
#!/usr/bin/env bash
# For each table engine below: store groupArraySample(1) aggregate states,
# dump the table as TSV text, insert that text back into the same table,
# and print the number of distinct keys and distinct states that result.

CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CURDIR"/../shell_config.sh

# Runs a single query with the client command held in CLICKHOUSE_CLIENT
# (presumably defined by shell_config.sh -- confirm). The variable is left
# unquoted on purpose so that a command with arguments is word-split.
# Standard input is inherited, so the helper can sit on either side of a pipe.
function run_query()
{
    $CLICKHOUSE_CLIENT -q "$1"
}

engines=("Memory" "MergeTree order by n" "Log")

for engine in "${engines[@]}"
do
    run_query "drop table if exists t"
    run_query "create table t (n UInt8, a1 AggregateFunction(groupArraySample(1), UInt8)) engine=$engine"

    # Ten numbers grouped by (number % 5): one sampled state per key.
    run_query "insert into t select number % 5 as n, groupArraySampleState(1)(toUInt8(number)) from numbers(10) group by n"

    # Text round trip: serialize every row to TSV and parse it back in.
    run_query "select * from t format TSV" | run_query "insert into t format TSV"

    run_query "select countDistinct(n), countDistinct(a1) from t"
    run_query "drop table t"
done