Fix pcg deserialization (#24538)

* fix pcg deserialization

* Update 01156_pcg_deserialization.sh

* Update 01156_pcg_deserialization.sh

* Update 01156_pcg_deserialization.sh

* fix another bug

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: alexey-milovidov <milovidov@yandex-team.ru>
This commit is contained in:
tavplubix 2021-06-24 10:40:00 +03:00 committed by GitHub
parent a6d289c750
commit b1263c18ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 38 additions and 14 deletions

View File

@ -30,16 +30,16 @@ static IAggregateFunction * createWithNumericOrTimeType(const IDataType & argume
template <typename Trait, typename ... TArgs>
inline AggregateFunctionPtr createAggregateFunctionGroupArrayImpl(const DataTypePtr & argument_type, TArgs ... args)
inline AggregateFunctionPtr createAggregateFunctionGroupArrayImpl(const DataTypePtr & argument_type, const Array & parameters, TArgs ... args)
{
if (auto res = createWithNumericOrTimeType<GroupArrayNumericImpl, Trait>(*argument_type, argument_type, std::forward<TArgs>(args)...))
if (auto res = createWithNumericOrTimeType<GroupArrayNumericImpl, Trait>(*argument_type, argument_type, parameters, std::forward<TArgs>(args)...))
return AggregateFunctionPtr(res);
WhichDataType which(argument_type);
if (which.idx == TypeIndex::String)
return std::make_shared<GroupArrayGeneralImpl<GroupArrayNodeString, Trait>>(argument_type, std::forward<TArgs>(args)...);
return std::make_shared<GroupArrayGeneralImpl<GroupArrayNodeString, Trait>>(argument_type, parameters, std::forward<TArgs>(args)...);
return std::make_shared<GroupArrayGeneralImpl<GroupArrayNodeGeneral, Trait>>(argument_type, std::forward<TArgs>(args)...);
return std::make_shared<GroupArrayGeneralImpl<GroupArrayNodeGeneral, Trait>>(argument_type, parameters, std::forward<TArgs>(args)...);
// Link list implementation doesn't show noticeable performance improvement
// if (which.idx == TypeIndex::String)
@ -79,9 +79,9 @@ AggregateFunctionPtr createAggregateFunctionGroupArray(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!limit_size)
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait<false, Sampler::NONE>>(argument_types[0]);
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait<false, Sampler::NONE>>(argument_types[0], parameters);
else
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait<true, Sampler::NONE>>(argument_types[0], max_elems);
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait<true, Sampler::NONE>>(argument_types[0], parameters, max_elems);
}
AggregateFunctionPtr createAggregateFunctionGroupArraySample(
@ -114,7 +114,7 @@ AggregateFunctionPtr createAggregateFunctionGroupArraySample(
else
seed = thread_local_rng();
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait<true, Sampler::RNG>>(argument_types[0], max_elems, seed);
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait<true, Sampler::RNG>>(argument_types[0], parameters, max_elems, seed);
}
}

View File

@ -119,9 +119,9 @@ class GroupArrayNumericImpl final
public:
explicit GroupArrayNumericImpl(
const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
: IAggregateFunctionDataHelper<GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>, GroupArrayNumericImpl<T, Trait>>(
{data_type_}, {})
{data_type_}, parameters_)
, max_elems(max_elems_)
, seed(seed_)
{
@ -421,9 +421,9 @@ class GroupArrayGeneralImpl final
UInt64 seed;
public:
GroupArrayGeneralImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
GroupArrayGeneralImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
: IAggregateFunctionDataHelper<GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>, GroupArrayGeneralImpl<Node, Trait>>(
{data_type_}, {})
{data_type_}, parameters_)
, data_type(this->argument_types[0])
, max_elems(max_elems_)
, seed(seed_)
@ -696,8 +696,8 @@ class GroupArrayGeneralListImpl final
UInt64 max_elems;
public:
GroupArrayGeneralListImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, Trait>>({data_type_}, {})
GroupArrayGeneralListImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, Trait>>({data_type_}, parameters_)
, data_type(this->argument_types[0])
, max_elems(max_elems_)
{

View File

@ -1248,7 +1248,7 @@ bool loadAtPosition(ReadBuffer & in, Memory<> & memory, char * & current);
struct PcgDeserializer
{
static void deserializePcg32(const pcg32_fast & rng, ReadBuffer & buf)
static void deserializePcg32(pcg32_fast & rng, ReadBuffer & buf)
{
decltype(rng.state_) multiplier, increment, state;
readText(multiplier, buf);
@ -1261,6 +1261,8 @@ struct PcgDeserializer
throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect multiplier in pcg32: expected {}, got {}", rng.multiplier(), multiplier);
if (increment != rng.increment())
throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect increment in pcg32: expected {}, got {}", rng.increment(), increment);
rng.state_ = state;
}
};

View File

@ -0,0 +1,3 @@
5 5
5 5
5 5

View File

@ -0,0 +1,19 @@
#!/usr/bin/env bash
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CURDIR"/../shell_config.sh
declare -a engines=("Memory" "MergeTree order by n" "Log")
for engine in "${engines[@]}"
do
$CLICKHOUSE_CLIENT -q "drop table if exists t";
$CLICKHOUSE_CLIENT -q "create table t (n UInt8, a1 AggregateFunction(groupArraySample(1), UInt8)) engine=$engine"
$CLICKHOUSE_CLIENT -q "insert into t select number % 5 as n, groupArraySampleState(1)(toUInt8(number)) from numbers(10) group by n"
$CLICKHOUSE_CLIENT -q "select * from t format TSV" | $CLICKHOUSE_CLIENT -q "insert into t format TSV"
$CLICKHOUSE_CLIENT -q "select countDistinct(n), countDistinct(a1) from t"
$CLICKHOUSE_CLIENT -q "drop table t";
done