mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-27 18:12:02 +00:00
dbms: quantileExact: initial implementation [#METR-18778].
This commit is contained in:
parent
959ae7cd78
commit
4c87ec04b4
@ -15,7 +15,7 @@ namespace DB
|
||||
|
||||
struct AggregateFunctionGroupArrayData
|
||||
{
|
||||
Array value; /// TODO Добавить MemoryTracker
|
||||
Array value; /// TODO Добавить MemoryTracker /// TODO Оптимизация для распространённых типов.
|
||||
};
|
||||
|
||||
|
||||
|
@ -114,7 +114,7 @@ class AggregateFunctionQuantiles final
|
||||
private:
|
||||
using Sample = typename AggregateFunctionQuantileData<ArgumentFieldType>::Sample;
|
||||
|
||||
typedef std::vector<double> Levels;
|
||||
using Levels = std::vector<double>;
|
||||
Levels levels;
|
||||
DataTypePtr type;
|
||||
|
||||
@ -185,14 +185,14 @@ public:
|
||||
ColumnFloat64::Container_t & data_to = static_cast<ColumnFloat64 &>(arr_to.getData()).getData();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
data_to.push_back(sample.quantileInterpolated(levels[i]));
|
||||
data_to.push_back(sample.quantileInterpolated(levels[i]));
|
||||
}
|
||||
else
|
||||
{
|
||||
typename ColumnVector<ArgumentFieldType>::Container_t & data_to = static_cast<ColumnVector<ArgumentFieldType> &>(arr_to.getData()).getData();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
data_to.push_back(sample.quantileInterpolated(levels[i]));
|
||||
data_to.push_back(sample.quantileInterpolated(levels[i]));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -123,7 +123,7 @@ class AggregateFunctionQuantilesDeterministic final
|
||||
private:
|
||||
using Sample = typename AggregateFunctionQuantileDeterministicData<ArgumentFieldType>::Sample;
|
||||
|
||||
typedef std::vector<double> Levels;
|
||||
using Levels = std::vector<double>;
|
||||
Levels levels;
|
||||
DataTypePtr type;
|
||||
|
||||
@ -199,14 +199,14 @@ public:
|
||||
ColumnFloat64::Container_t & data_to = static_cast<ColumnFloat64 &>(arr_to.getData()).getData();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
data_to.push_back(sample.quantileInterpolated(levels[i]));
|
||||
data_to.push_back(sample.quantileInterpolated(levels[i]));
|
||||
}
|
||||
else
|
||||
{
|
||||
typename ColumnVector<ArgumentFieldType>::Container_t & data_to = static_cast<ColumnVector<ArgumentFieldType> &>(arr_to.getData()).getData();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
data_to.push_back(sample.quantileInterpolated(levels[i]));
|
||||
data_to.push_back(sample.quantileInterpolated(levels[i]));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -0,0 +1,226 @@
|
||||
#pragma once
|
||||
|
||||
#include <DB/Common/PODArray.h>
|
||||
|
||||
#include <DB/Core/FieldVisitors.h>
|
||||
|
||||
#include <DB/IO/WriteHelpers.h>
|
||||
#include <DB/IO/ReadHelpers.h>
|
||||
|
||||
#include <DB/DataTypes/DataTypesNumberFixed.h>
|
||||
#include <DB/DataTypes/DataTypeArray.h>
|
||||
|
||||
#include <DB/AggregateFunctions/IUnaryAggregateFunction.h>
|
||||
|
||||
#include <DB/Columns/ColumnArray.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
|
||||
/** В качестве состояния используется массив, в который складываются все значения.
|
||||
* NOTE Если различных значений мало, то это не оптимально.
|
||||
* Для 8 и 16-битных значений возможно, было бы лучше использовать lookup-таблицу.
|
||||
*/
|
||||
template <typename T>
|
||||
struct AggregateFunctionQuantileExactData
|
||||
{
|
||||
using Array = PODArray<T>;
|
||||
Array array;
|
||||
};
|
||||
|
||||
|
||||
/** Точно вычисляет квантиль.
|
||||
* В качестве типа аргумента может быть только числовой тип (в том числе, дата и дата-с-временем).
|
||||
* Тип результата совпадает с типом аргумента.
|
||||
*/
|
||||
template <typename T>
|
||||
class AggregateFunctionQuantileExact final
|
||||
: public IUnaryAggregateFunction<AggregateFunctionQuantileExactData<T>, AggregateFunctionQuantileExact<T>>
|
||||
{
|
||||
private:
|
||||
double level;
|
||||
DataTypePtr type;
|
||||
|
||||
public:
|
||||
AggregateFunctionQuantileExact(double level_ = 0.5) : level(level_) {}
|
||||
|
||||
String getName() const override { return "quantileExact"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return type;
|
||||
}
|
||||
|
||||
void setArgument(const DataTypePtr & argument) override
|
||||
{
|
||||
type = argument;
|
||||
}
|
||||
|
||||
void setParameters(const Array & params) override
|
||||
{
|
||||
if (params.size() != 1)
|
||||
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
level = apply_visitor(FieldVisitorConvertToNumber<Float64>(), params[0]);
|
||||
}
|
||||
|
||||
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const
|
||||
{
|
||||
this->data(place).array.push_back(static_cast<const ColumnVector<T> &>(column).getData()[row_num]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
|
||||
{
|
||||
this->data(place).array.insert(this->data(rhs).array.begin(), this->data(rhs).array.end());
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
{
|
||||
const auto & array = this->data(place).array;
|
||||
|
||||
size_t size = array.size();
|
||||
writeVarUInt(size, buf);
|
||||
buf.write(reinterpret_cast<const char *>(&array[0]), size * sizeof(array[0]));
|
||||
}
|
||||
|
||||
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
|
||||
{
|
||||
auto & array = this->data(place).array;
|
||||
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
size_t old_size = array.size();
|
||||
array.resize(old_size + size);
|
||||
buf.read(reinterpret_cast<char *>(&array[old_size]), size * sizeof(array[0]));
|
||||
}
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
/// Сортировка массива не будет считаться нарушением константности.
|
||||
auto & array = const_cast<typename AggregateFunctionQuantileExactData<T>::Array &>(this->data(place).array);
|
||||
|
||||
T quantile = T();
|
||||
|
||||
if (!array.empty())
|
||||
{
|
||||
size_t n = level < 1
|
||||
? level * array.size()
|
||||
: (array.size() - 1);
|
||||
|
||||
std::nth_element(array.begin(), array.begin() + n, array.end()); /// NOTE Можно придумать алгоритм radix-select.
|
||||
|
||||
quantile = array[n];
|
||||
}
|
||||
|
||||
static_cast<ColumnVector<T> &>(to).getData().push_back(quantile);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/** То же самое, но позволяет вычислить сразу несколько квантилей.
|
||||
* Для этого, принимает в качестве параметров несколько уровней. Пример: quantilesExact(0.5, 0.8, 0.9, 0.95)(ConnectTiming).
|
||||
* Возвращает массив результатов.
|
||||
*/
|
||||
template <typename T>
|
||||
class AggregateFunctionQuantilesExact final
|
||||
: public IUnaryAggregateFunction<AggregateFunctionQuantileExactData<T>, AggregateFunctionQuantilesExact<T>>
|
||||
{
|
||||
private:
|
||||
using Levels = std::vector<double>;
|
||||
Levels levels;
|
||||
DataTypePtr type;
|
||||
|
||||
public:
|
||||
String getName() const override { return "quantilesExact"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return new DataTypeArray(type);
|
||||
}
|
||||
|
||||
void setArgument(const DataTypePtr & argument) override
|
||||
{
|
||||
type = argument;
|
||||
}
|
||||
|
||||
void setParameters(const Array & params) override
|
||||
{
|
||||
if (params.empty())
|
||||
throw Exception("Aggregate function " + getName() + " requires at least one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
size_t size = params.size();
|
||||
levels.resize(size);
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
levels[i] = apply_visitor(FieldVisitorConvertToNumber<Float64>(), params[i]);
|
||||
}
|
||||
|
||||
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const
|
||||
{
|
||||
this->data(place).array.push_back(static_cast<const ColumnVector<T> &>(column).getData()[row_num]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
|
||||
{
|
||||
this->data(place).array.insert(this->data(rhs).array.begin(), this->data(rhs).array.end());
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
{
|
||||
const auto & array = this->data(place).array;
|
||||
|
||||
size_t size = array.size();
|
||||
writeVarUInt(size, buf);
|
||||
buf.write(reinterpret_cast<const char *>(&array[0]), size * sizeof(array[0]));
|
||||
}
|
||||
|
||||
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
|
||||
{
|
||||
auto & array = this->data(place).array;
|
||||
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
size_t old_size = array.size();
|
||||
array.resize(old_size + size);
|
||||
buf.read(reinterpret_cast<char *>(&array[old_size]), size * sizeof(array[0]));
|
||||
}
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
/// Сортировка массива не будет считаться нарушением константности.
|
||||
auto & array = const_cast<typename AggregateFunctionQuantileExactData<T>::Array &>(this->data(place).array);
|
||||
|
||||
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets();
|
||||
|
||||
size_t num_levels = levels.size();
|
||||
offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + num_levels);
|
||||
|
||||
typename ColumnVector<T>::Container_t & data_to = static_cast<ColumnVector<T> &>(arr_to.getData()).getData();
|
||||
|
||||
if (!array.empty())
|
||||
{
|
||||
size_t prev_n = 0;
|
||||
for (const auto & level : levels)
|
||||
{
|
||||
size_t n = level < 1
|
||||
? level * array.size()
|
||||
: (array.size() - 1);
|
||||
|
||||
std::nth_element(array.begin() + prev_n, array.begin() + n, array.end());
|
||||
|
||||
data_to.push_back(array[n]);
|
||||
prev_n = n;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (size_t i = 0; i < num_levels; ++i)
|
||||
data_to.push_back(T());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -733,7 +733,7 @@ template <typename ArgumentFieldType, typename WeightFieldType>
|
||||
class AggregateFunctionQuantilesTimingWeighted final : public IAggregateFunctionHelper<QuantileTiming>
|
||||
{
|
||||
private:
|
||||
typedef std::vector<double> Levels;
|
||||
using Levels = std::vector<double>;
|
||||
Levels levels;
|
||||
|
||||
public:
|
||||
|
@ -60,6 +60,7 @@ void registerAggregateFunctionCount(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionGroupArray(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionGroupUniqArray(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionsQuantile(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionsQuantileExact(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionsQuantileDeterministic(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionsQuantileTiming(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory);
|
||||
@ -90,6 +91,7 @@ AggregateFunctionFactory::AggregateFunctionFactory()
|
||||
registerAggregateFunctionSum(*this);
|
||||
registerAggregateFunctionsUniq(*this);
|
||||
registerAggregateFunctionUniqUpTo(*this);
|
||||
registerAggregateFunctionsQuantileExact(*this);
|
||||
}
|
||||
|
||||
|
||||
|
@ -0,0 +1,66 @@
|
||||
#include <DB/AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <DB/AggregateFunctions/Helpers.h>
|
||||
#include <DB/AggregateFunctions/AggregateFunctionQuantileExact.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionQuantileExact(const std::string & name, const DataTypes & argument_types)
|
||||
{
|
||||
if (argument_types.size() != 1)
|
||||
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
const IDataType & argument_type = *argument_types[0];
|
||||
|
||||
if (typeid_cast<const DataTypeUInt8 *>(&argument_type)) return new AggregateFunctionQuantileExact<UInt8>;
|
||||
else if (typeid_cast<const DataTypeUInt16 *>(&argument_type)) return new AggregateFunctionQuantileExact<UInt16>;
|
||||
else if (typeid_cast<const DataTypeUInt32 *>(&argument_type)) return new AggregateFunctionQuantileExact<UInt32>;
|
||||
else if (typeid_cast<const DataTypeUInt64 *>(&argument_type)) return new AggregateFunctionQuantileExact<UInt64>;
|
||||
else if (typeid_cast<const DataTypeInt8 *>(&argument_type)) return new AggregateFunctionQuantileExact<Int8>;
|
||||
else if (typeid_cast<const DataTypeInt16 *>(&argument_type)) return new AggregateFunctionQuantileExact<Int16>;
|
||||
else if (typeid_cast<const DataTypeInt32 *>(&argument_type)) return new AggregateFunctionQuantileExact<Int32>;
|
||||
else if (typeid_cast<const DataTypeInt64 *>(&argument_type)) return new AggregateFunctionQuantileExact<Int64>;
|
||||
else if (typeid_cast<const DataTypeFloat32 *>(&argument_type)) return new AggregateFunctionQuantileExact<Float32>;
|
||||
else if (typeid_cast<const DataTypeFloat64 *>(&argument_type)) return new AggregateFunctionQuantileExact<Float64>;
|
||||
else if (typeid_cast<const DataTypeDate *>(&argument_type)) return new AggregateFunctionQuantileExact<DataTypeDate::FieldType>;
|
||||
else if (typeid_cast<const DataTypeDateTime*>(&argument_type)) return new AggregateFunctionQuantileExact<DataTypeDateTime::FieldType>;
|
||||
else
|
||||
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionQuantilesExact(const std::string & name, const DataTypes & argument_types)
|
||||
{
|
||||
if (argument_types.size() != 1)
|
||||
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
const IDataType & argument_type = *argument_types[0];
|
||||
|
||||
if (typeid_cast<const DataTypeUInt8 *>(&argument_type)) return new AggregateFunctionQuantilesExact<UInt8>;
|
||||
else if (typeid_cast<const DataTypeUInt16 *>(&argument_type)) return new AggregateFunctionQuantilesExact<UInt16>;
|
||||
else if (typeid_cast<const DataTypeUInt32 *>(&argument_type)) return new AggregateFunctionQuantilesExact<UInt32>;
|
||||
else if (typeid_cast<const DataTypeUInt64 *>(&argument_type)) return new AggregateFunctionQuantilesExact<UInt64>;
|
||||
else if (typeid_cast<const DataTypeInt8 *>(&argument_type)) return new AggregateFunctionQuantilesExact<Int8>;
|
||||
else if (typeid_cast<const DataTypeInt16 *>(&argument_type)) return new AggregateFunctionQuantilesExact<Int16>;
|
||||
else if (typeid_cast<const DataTypeInt32 *>(&argument_type)) return new AggregateFunctionQuantilesExact<Int32>;
|
||||
else if (typeid_cast<const DataTypeInt64 *>(&argument_type)) return new AggregateFunctionQuantilesExact<Int64>;
|
||||
else if (typeid_cast<const DataTypeFloat32 *>(&argument_type)) return new AggregateFunctionQuantilesExact<Float32>;
|
||||
else if (typeid_cast<const DataTypeFloat64 *>(&argument_type)) return new AggregateFunctionQuantilesExact<Float64>;
|
||||
else if (typeid_cast<const DataTypeDate *>(&argument_type)) return new AggregateFunctionQuantilesExact<DataTypeDate::FieldType>;
|
||||
else if (typeid_cast<const DataTypeDateTime*>(&argument_type)) return new AggregateFunctionQuantilesExact<DataTypeDateTime::FieldType>;
|
||||
else
|
||||
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionsQuantileExact(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction({"quantileExact", "medianExact"}, createAggregateFunctionQuantileExact);
|
||||
factory.registerFunction({"quantilesExact"}, createAggregateFunctionQuantilesExact);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user