improve checking for accuracy param in quantileGK

This commit is contained in:
taiyang-li 2023-02-16 18:27:50 +08:00
parent 966c5484e6
commit 384888421e
2 changed files with 30 additions and 9 deletions

View File

@ -26,6 +26,7 @@ namespace ErrorCodes
{ {
extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int LOGICAL_ERROR;
} }
template <typename> class QuantileTiming; template <typename> class QuantileTiming;
@ -70,7 +71,7 @@ private:
Float64 level = 0.5; Float64 level = 0.5;
/// Used when function name is "quantileGK" or "quantilesGK" /// Used when function name is "quantileGK" or "quantilesGK"
size_t accuracy = 10000; ssize_t accuracy = 10000;
DataTypePtr & argument_type; DataTypePtr & argument_type;
@ -83,10 +84,27 @@ public:
, argument_type(this->argument_types[0]) , argument_type(this->argument_types[0])
{ {
if (!returns_many && levels.size() > 1) if (!returns_many && levels.size() > 1)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require one parameter or less", getName()); throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require one level parameter or less", getName());
if constexpr (is_quantile_gk) if constexpr (is_quantile_gk)
accuracy = params[0].get<UInt64>(); {
const auto & accuracy_field = params[0];
if (!isInt64OrUInt64FieldType(accuracy_field.getType()))
throw Exception(
ErrorCodes::LOGICAL_ERROR, "Aggregate function {} require accuracy parameter with integer type", getName());
if (accuracy_field.getType() == Field::Types::Int64)
accuracy = accuracy_field.get<Int64>();
else
accuracy = accuracy_field.get<UInt64>();
if (accuracy <= 0)
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"Aggregate function {} require accuracy parameter with positive value but is {}",
getName(),
accuracy);
}
} }
String getName() const override { return Name::name; } String getName() const override { return Name::name; }

View File

@ -59,6 +59,7 @@ public:
} }
bool isCompressed() const { return compressed; } bool isCompressed() const { return compressed; }
void setCompressed() { compressed = true; }
void insert(T x) void insert(T x)
{ {
@ -115,6 +116,9 @@ public:
void compress() void compress()
{ {
if (compressed)
return;
withHeadBufferInserted(); withHeadBufferInserted();
doCompress(2 * relative_error * count); doCompress(2 * relative_error * count);
@ -124,9 +128,6 @@ public:
void merge(const GKSampler & other) void merge(const GKSampler & other)
{ {
if (!head_sampled.empty() || !other.head_sampled.empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Current buffer needs to be compressed before merge");
if (other.count == 0) if (other.count == 0)
return; return;
else if (count == 0) else if (count == 0)
@ -169,6 +170,8 @@ public:
// `max(g_ab + delta_ab) <= floor(2 * eps_ab * (n_a + n_b))` since // `max(g_ab + delta_ab) <= floor(2 * eps_ab * (n_a + n_b))` since
// `max(g_ab + delta_ab) <= floor(2 * eps_a * n_a) + floor(2 * eps_b * n_b)` // `max(g_ab + delta_ab) <= floor(2 * eps_a * n_a) + floor(2 * eps_b * n_b)`
// Finally, one can see how the `insert(x)` operation can be expressed as `merge([(x, 1, 0])` // Finally, one can see how the `insert(x)` operation can be expressed as `merge([(x, 1, 0])`
compress();
backup_sampled.clear(); backup_sampled.clear();
backup_sampled.reserve(sampled.size() + other.sampled.size()); backup_sampled.reserve(sampled.size() + other.sampled.size());
double merged_relative_error = std::max(relative_error, other.relative_error); double merged_relative_error = std::max(relative_error, other.relative_error);
@ -249,9 +252,6 @@ public:
readFloatBinary<double>(relative_error, buf); readFloatBinary<double>(relative_error, buf);
readIntBinary<size_t>(count, buf); readIntBinary<size_t>(count, buf);
/// Always compress before serialization
compressed = true;
size_t sampled_len = 0; size_t sampled_len = 0;
readIntBinary<size_t>(sampled_len, buf); readIntBinary<size_t>(sampled_len, buf);
sampled.resize(sampled_len); sampled.resize(sampled_len);
@ -427,6 +427,7 @@ public:
void serialize(WriteBuffer & buf) const void serialize(WriteBuffer & buf) const
{ {
/// Always compress before serialization
if (!data.isCompressed()) if (!data.isCompressed())
data.compress(); data.compress();
@ -436,6 +437,8 @@ public:
void deserialize(ReadBuffer & buf) void deserialize(ReadBuffer & buf)
{ {
data.read(buf); data.read(buf);
data.setCompressed();
} }
/// Get the value of the `level` quantile. The level must be between 0 and 1. /// Get the value of the `level` quantile. The level must be between 0 and 1.