From 384888421e5a809c9d755f6e2b94e632f1305a4a Mon Sep 17 00:00:00 2001 From: taiyang-li <654010905@qq.com> Date: Thu, 16 Feb 2023 18:27:50 +0800 Subject: [PATCH] improve checking for accuracy param in quantileGK --- .../AggregateFunctionQuantile.h | 24 ++++++++++++++++--- src/AggregateFunctions/QuantileGK.h | 15 +++++++----- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/AggregateFunctions/AggregateFunctionQuantile.h b/src/AggregateFunctions/AggregateFunctionQuantile.h index 6bc058d9e30..deb83d342b0 100644 --- a/src/AggregateFunctions/AggregateFunctionQuantile.h +++ b/src/AggregateFunctions/AggregateFunctionQuantile.h @@ -26,6 +26,7 @@ namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int LOGICAL_ERROR; } template class QuantileTiming; @@ -70,7 +71,7 @@ private: Float64 level = 0.5; /// Used when function name is "quantileGK" or "quantilesGK" - size_t accuracy = 10000; + ssize_t accuracy = 10000; DataTypePtr & argument_type; @@ -83,10 +84,27 @@ public: , argument_type(this->argument_types[0]) { if (!returns_many && levels.size() > 1) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require one parameter or less", getName()); + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require one level parameter or less", getName()); if constexpr (is_quantile_gk) - accuracy = params[0].get(); + { + const auto & accuracy_field = params[0]; + if (!isInt64OrUInt64FieldType(accuracy_field.getType())) + throw Exception( + ErrorCodes::LOGICAL_ERROR, "Aggregate function {} require accuracy parameter with integer type", getName()); + + if (accuracy_field.getType() == Field::Types::Int64) + accuracy = accuracy_field.get(); + else + accuracy = accuracy_field.get(); + + if (accuracy <= 0) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Aggregate function {} require accuracy parameter with positive value but is {}", + getName(), + accuracy); + } } String getName() const override { return Name::name; } diff --git a/src/AggregateFunctions/QuantileGK.h b/src/AggregateFunctions/QuantileGK.h index 1e3d361d29b..c75005d65db 100644 --- a/src/AggregateFunctions/QuantileGK.h +++ b/src/AggregateFunctions/QuantileGK.h @@ -59,6 +59,7 @@ public: } bool isCompressed() const { return compressed; } + void setCompressed() { compressed = true; } void insert(T x) { @@ -115,6 +116,9 @@ public: void compress() { + if (compressed) + return; + withHeadBufferInserted(); doCompress(2 * relative_error * count); @@ -124,9 +128,6 @@ public: void merge(const GKSampler & other) { - if (!head_sampled.empty() || !other.head_sampled.empty()) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Current buffer needs to be compressed before merge"); - if (other.count == 0) return; else if (count == 0) @@ -169,6 +170,8 @@ public: // `max(g_ab + delta_ab) <= floor(2 * eps_ab * (n_a + n_b))` since // `max(g_ab + delta_ab) <= floor(2 * eps_a * n_a) + floor(2 * eps_b * n_b)` // Finally, one can see how the `insert(x)` operation can be expressed as `merge([(x, 1, 0])` + compress(); + backup_sampled.clear(); backup_sampled.reserve(sampled.size() + other.sampled.size()); double merged_relative_error = std::max(relative_error, other.relative_error); @@ -249,9 +252,6 @@ public: readFloatBinary(relative_error, buf); readIntBinary(count, buf); - /// Always compress before serialization - compressed = true; - size_t sampled_len = 0; readIntBinary(sampled_len, buf); sampled.resize(sampled_len); @@ -427,6 +427,7 @@ public: void serialize(WriteBuffer & buf) const { + /// Always compress before serialization if (!data.isCompressed()) data.compress(); @@ -436,6 +437,8 @@ public: void deserialize(ReadBuffer & buf) { data.read(buf); + + data.setCompressed(); } /// Get the value of the `level` quantile. The level must be between 0 and 1.