From 4c87ec04b45e505c84aa7534f44cf86162b6a049 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 15 Nov 2015 06:11:24 +0300 Subject: [PATCH] dbms: quantileExact: initial implementation [#METR-18778]. --- .../AggregateFunctionGroupArray.h | 2 +- .../AggregateFunctionQuantile.h | 6 +- .../AggregateFunctionQuantileDeterministic.h | 6 +- .../AggregateFunctionQuantileExact.h | 226 ++++++++++++++++++ .../AggregateFunctionQuantileTiming.h | 2 +- .../AggregateFunctionFactory.cpp | 2 + .../AggregateFunctionQuantileExact.cpp | 66 +++++ 7 files changed, 302 insertions(+), 8 deletions(-) create mode 100644 dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileExact.h create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionQuantileExact.cpp diff --git a/dbms/include/DB/AggregateFunctions/AggregateFunctionGroupArray.h b/dbms/include/DB/AggregateFunctions/AggregateFunctionGroupArray.h index 4052fbfb320..435a15fb7b5 100644 --- a/dbms/include/DB/AggregateFunctions/AggregateFunctionGroupArray.h +++ b/dbms/include/DB/AggregateFunctions/AggregateFunctionGroupArray.h @@ -15,7 +15,7 @@ namespace DB struct AggregateFunctionGroupArrayData { - Array value; /// TODO Добавить MemoryTracker + Array value; /// TODO Добавить MemoryTracker /// TODO Оптимизация для распространённых типов. }; diff --git a/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantile.h b/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantile.h index 752fab20694..e46fb9828d8 100644 --- a/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantile.h +++ b/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantile.h @@ -114,7 +114,7 @@ class AggregateFunctionQuantiles final private: using Sample = typename AggregateFunctionQuantileData::Sample; - typedef std::vector Levels; + using Levels = std::vector; Levels levels; DataTypePtr type; @@ -185,14 +185,14 @@ public: ColumnFloat64::Container_t & data_to = static_cast(arr_to.getData()).getData(); for (size_t i = 0; i < size; ++i) - data_to.push_back(sample.quantileInterpolated(levels[i])); + data_to.push_back(sample.quantileInterpolated(levels[i])); } else { typename ColumnVector::Container_t & data_to = static_cast &>(arr_to.getData()).getData(); for (size_t i = 0; i < size; ++i) - data_to.push_back(sample.quantileInterpolated(levels[i])); + data_to.push_back(sample.quantileInterpolated(levels[i])); } } }; diff --git a/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileDeterministic.h b/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileDeterministic.h index 8ae543b4ab3..c575007bb30 100644 --- a/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileDeterministic.h +++ b/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileDeterministic.h @@ -123,7 +123,7 @@ class AggregateFunctionQuantilesDeterministic final private: using Sample = typename AggregateFunctionQuantileDeterministicData::Sample; - typedef std::vector Levels; + using Levels = std::vector; Levels levels; DataTypePtr type; @@ -199,14 +199,14 @@ public: ColumnFloat64::Container_t & data_to = static_cast(arr_to.getData()).getData(); for (size_t i = 0; i < size; ++i) - data_to.push_back(sample.quantileInterpolated(levels[i])); + data_to.push_back(sample.quantileInterpolated(levels[i])); } else { typename ColumnVector::Container_t & data_to = static_cast &>(arr_to.getData()).getData(); for (size_t i = 0; i < size; ++i) - data_to.push_back(sample.quantileInterpolated(levels[i])); + data_to.push_back(sample.quantileInterpolated(levels[i])); } } }; diff --git a/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileExact.h b/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileExact.h new file mode 100644 index 00000000000..734c3a4c213 --- /dev/null +++ b/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileExact.h @@ -0,0 +1,226 @@ +#pragma once + +#include + +#include + +#include +#include + +#include +#include + +#include + +#include + + +namespace DB +{ + + +/** В качестве состояния используется массив, в который складываются все значения. + * NOTE Если различных значений мало, то это не оптимально. + * Для 8 и 16-битных значений возможно, было бы лучше использовать lookup-таблицу. + */ +template +struct AggregateFunctionQuantileExactData +{ + using Array = PODArray; + Array array; +}; + + +/** Точно вычисляет квантиль. + * В качестве типа аргумента может быть только числовой тип (в том числе, дата и дата-с-временем). + * Тип результата совпадает с типом аргумента. + */ +template +class AggregateFunctionQuantileExact final + : public IUnaryAggregateFunction, AggregateFunctionQuantileExact> +{ +private: + double level; + DataTypePtr type; + +public: + AggregateFunctionQuantileExact(double level_ = 0.5) : level(level_) {} + + String getName() const override { return "quantileExact"; } + + DataTypePtr getReturnType() const override + { + return type; + } + + void setArgument(const DataTypePtr & argument) override + { + type = argument; + } + + void setParameters(const Array & params) override + { + if (params.size() != 1) + throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + level = apply_visitor(FieldVisitorConvertToNumber(), params[0]); + } + + void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const + { + this->data(place).array.push_back(static_cast &>(column).getData()[row_num]); + } + + void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override + { + this->data(place).array.insert(this->data(rhs).array.begin(), this->data(rhs).array.end()); + } + + void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override + { + const auto & array = this->data(place).array; + + size_t size = array.size(); + writeVarUInt(size, buf); + buf.write(reinterpret_cast(&array[0]), size * sizeof(array[0])); + } + + void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override + { + auto & array = this->data(place).array; + + size_t size = 0; + readVarUInt(size, buf); + size_t old_size = array.size(); + array.resize(old_size + size); + buf.read(reinterpret_cast(&array[old_size]), size * sizeof(array[0])); + } + + void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override + { + /// Сортировка массива не будет считаться нарушением константности. + auto & array = const_cast::Array &>(this->data(place).array); + + T quantile = T(); + + if (!array.empty()) + { + size_t n = level < 1 + ? level * array.size() + : (array.size() - 1); + + std::nth_element(array.begin(), array.begin() + n, array.end()); /// NOTE Можно придумать алгоритм radix-select. + + quantile = array[n]; + } + + static_cast &>(to).getData().push_back(quantile); + } +}; + + +/** То же самое, но позволяет вычислить сразу несколько квантилей. + * Для этого, принимает в качестве параметров несколько уровней. Пример: quantilesExact(0.5, 0.8, 0.9, 0.95)(ConnectTiming). + * Возвращает массив результатов. + */ +template +class AggregateFunctionQuantilesExact final + : public IUnaryAggregateFunction, AggregateFunctionQuantilesExact> +{ +private: + using Levels = std::vector; + Levels levels; + DataTypePtr type; + +public: + String getName() const override { return "quantilesExact"; } + + DataTypePtr getReturnType() const override + { + return new DataTypeArray(type); + } + + void setArgument(const DataTypePtr & argument) override + { + type = argument; + } + + void setParameters(const Array & params) override + { + if (params.empty()) + throw Exception("Aggregate function " + getName() + " requires at least one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + size_t size = params.size(); + levels.resize(size); + + for (size_t i = 0; i < size; ++i) + levels[i] = apply_visitor(FieldVisitorConvertToNumber(), params[i]); + } + + void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const + { + this->data(place).array.push_back(static_cast &>(column).getData()[row_num]); + } + + void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override + { + this->data(place).array.insert(this->data(rhs).array.begin(), this->data(rhs).array.end()); + } + + void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override + { + const auto & array = this->data(place).array; + + size_t size = array.size(); + writeVarUInt(size, buf); + buf.write(reinterpret_cast(&array[0]), size * sizeof(array[0])); + } + + void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override + { + auto & array = this->data(place).array; + + size_t size = 0; + readVarUInt(size, buf); + size_t old_size = array.size(); + array.resize(old_size + size); + buf.read(reinterpret_cast(&array[old_size]), size * sizeof(array[0])); + } + + void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override + { + /// Сортировка массива не будет считаться нарушением константности. + auto & array = const_cast::Array &>(this->data(place).array); + + ColumnArray & arr_to = static_cast(to); + ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets(); + + size_t num_levels = levels.size(); + offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + num_levels); + + typename ColumnVector::Container_t & data_to = static_cast &>(arr_to.getData()).getData(); + + if (!array.empty()) + { + size_t prev_n = 0; + for (const auto & level : levels) + { + size_t n = level < 1 + ? level * array.size() + : (array.size() - 1); + + std::nth_element(array.begin() + prev_n, array.begin() + n, array.end()); + + data_to.push_back(array[n]); + prev_n = n; + } + } + else + { + for (size_t i = 0; i < num_levels; ++i) + data_to.push_back(T()); + } + } +}; + +} diff --git a/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileTiming.h b/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileTiming.h index d6a18d77e28..4331a37cd8b 100644 --- a/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileTiming.h +++ b/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileTiming.h @@ -733,7 +733,7 @@ template class AggregateFunctionQuantilesTimingWeighted final : public IAggregateFunctionHelper { private: - typedef std::vector Levels; + using Levels = std::vector; Levels levels; public: diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp index eeeff670644..df361656df6 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -60,6 +60,7 @@ void registerAggregateFunctionCount(AggregateFunctionFactory & factory); void registerAggregateFunctionGroupArray(AggregateFunctionFactory & factory); void registerAggregateFunctionGroupUniqArray(AggregateFunctionFactory & factory); void registerAggregateFunctionsQuantile(AggregateFunctionFactory & factory); +void registerAggregateFunctionsQuantileExact(AggregateFunctionFactory & factory); void registerAggregateFunctionsQuantileDeterministic(AggregateFunctionFactory & factory); void registerAggregateFunctionsQuantileTiming(AggregateFunctionFactory & factory); void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory); @@ -90,6 +91,7 @@ AggregateFunctionFactory::AggregateFunctionFactory() registerAggregateFunctionSum(*this); registerAggregateFunctionsUniq(*this); registerAggregateFunctionUniqUpTo(*this); + registerAggregateFunctionsQuantileExact(*this); } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionQuantileExact.cpp b/dbms/src/AggregateFunctions/AggregateFunctionQuantileExact.cpp new file mode 100644 index 00000000000..b81a2ffa765 --- /dev/null +++ b/dbms/src/AggregateFunctions/AggregateFunctionQuantileExact.cpp @@ -0,0 +1,66 @@ +#include +#include +#include + +namespace DB +{ + +namespace +{ + +AggregateFunctionPtr createAggregateFunctionQuantileExact(const std::string & name, const DataTypes & argument_types) +{ + if (argument_types.size() != 1) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + const IDataType & argument_type = *argument_types[0]; + + if (typeid_cast(&argument_type)) return new AggregateFunctionQuantileExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantileExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantileExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantileExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantileExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantileExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantileExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantileExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantileExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantileExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantileExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantileExact; + else + throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); +} + + +AggregateFunctionPtr createAggregateFunctionQuantilesExact(const std::string & name, const DataTypes & argument_types) +{ + if (argument_types.size() != 1) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + const IDataType & argument_type = *argument_types[0]; + + if (typeid_cast(&argument_type)) return new AggregateFunctionQuantilesExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantilesExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantilesExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantilesExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantilesExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantilesExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantilesExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantilesExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantilesExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantilesExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantilesExact; + else if (typeid_cast(&argument_type)) return new AggregateFunctionQuantilesExact; + else + throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); +} + +} + +void registerAggregateFunctionsQuantileExact(AggregateFunctionFactory & factory) +{ + factory.registerFunction({"quantileExact", "medianExact"}, createAggregateFunctionQuantileExact); + factory.registerFunction({"quantilesExact"}, createAggregateFunctionQuantilesExact); +} + +}