mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-23 16:12:01 +00:00
dbms: added weighted variants of t-digest [#METR-19024].
This commit is contained in:
parent
57e8a8fdbb
commit
31c09b4d2d
@ -11,6 +11,7 @@
|
||||
#include <DB/Common/RadixSort.h>
|
||||
#include <DB/Common/PODArray.h>
|
||||
#include <DB/AggregateFunctions/IUnaryAggregateFunction.h>
|
||||
#include <DB/AggregateFunctions/IBinaryAggregateFunction.h>
|
||||
#include <DB/DataTypes/DataTypesNumberFixed.h>
|
||||
|
||||
|
||||
@ -289,13 +290,14 @@ struct AggregateFunctionQuantileTDigestData
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool returns_float = true>
|
||||
class AggregateFunctionQuantileTDigest final
|
||||
: public IUnaryAggregateFunction<AggregateFunctionQuantileTDigestData, AggregateFunctionQuantileTDigest<T>>
|
||||
{
|
||||
private:
|
||||
double level;
|
||||
tdigest::Params<Float32> params;
|
||||
DataTypePtr type;
|
||||
|
||||
public:
|
||||
AggregateFunctionQuantileTDigest(double level_ = 0.5) : level(level_) {}
|
||||
@ -304,11 +306,15 @@ public:
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return new DataTypeFloat64;
|
||||
return type;
|
||||
}
|
||||
|
||||
void setArgument(const DataTypePtr & argument)
|
||||
{
|
||||
if (returns_float)
|
||||
type = new DataTypeFloat32;
|
||||
else
|
||||
type = argument;
|
||||
}
|
||||
|
||||
void setParameters(const Array & params) override
|
||||
@ -341,8 +347,81 @@ public:
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
static_cast<ColumnFloat64 &>(to).getData().push_back(
|
||||
this->data(const_cast<AggregateDataPtr>(place)).digest.quantile(params, level));
|
||||
auto quantile = this->data(const_cast<AggregateDataPtr>(place)).digest.quantile(params, level);
|
||||
|
||||
if (returns_float)
|
||||
static_cast<ColumnFloat32 &>(to).getData().push_back(quantile);
|
||||
else
|
||||
static_cast<ColumnVector<T> &>(to).getData().push_back(quantile);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename T, typename Weight, bool returns_float = true>
|
||||
class AggregateFunctionQuantileTDigestWeighted final
|
||||
: public IBinaryAggregateFunction<AggregateFunctionQuantileTDigestData, AggregateFunctionQuantileTDigestWeighted<T, Weight>>
|
||||
{
|
||||
private:
|
||||
double level;
|
||||
tdigest::Params<Float32> params;
|
||||
DataTypePtr type;
|
||||
|
||||
public:
|
||||
AggregateFunctionQuantileTDigestWeighted(double level_ = 0.5) : level(level_) {}
|
||||
|
||||
String getName() const override { return "quantileTDigestWeighted"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return type;
|
||||
}
|
||||
|
||||
void setArgumentsImpl(const DataTypes & arguments)
|
||||
{
|
||||
if (returns_float)
|
||||
type = new DataTypeFloat32;
|
||||
else
|
||||
type = arguments.at(0);
|
||||
}
|
||||
|
||||
void setParameters(const Array & params) override
|
||||
{
|
||||
if (params.size() != 1)
|
||||
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
level = apply_visitor(FieldVisitorConvertToNumber<Float64>(), params[0]);
|
||||
}
|
||||
|
||||
void addImpl(AggregateDataPtr place, const IColumn & column_value, const IColumn & column_weight, size_t row_num) const
|
||||
{
|
||||
this->data(place).digest.add(params,
|
||||
static_cast<const ColumnVector<T> &>(column_value).getData()[row_num],
|
||||
static_cast<const ColumnVector<Weight> &>(column_weight).getData()[row_num]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
|
||||
{
|
||||
this->data(place).digest.merge(params, this->data(rhs).digest);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
{
|
||||
this->data(const_cast<AggregateDataPtr>(place)).digest.write(params, buf);
|
||||
}
|
||||
|
||||
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
|
||||
{
|
||||
this->data(place).digest.readAndMerge(params, buf);
|
||||
}
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
auto quantile = this->data(const_cast<AggregateDataPtr>(place)).digest.quantile(params, level);
|
||||
|
||||
if (returns_float)
|
||||
static_cast<ColumnFloat32 &>(to).getData().push_back(quantile);
|
||||
else
|
||||
static_cast<ColumnVector<T> &>(to).getData().push_back(quantile);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -25,10 +25,64 @@ AggregateFunctionPtr createAggregateFunctionQuantileTDigest(const std::string &
|
||||
else if (typeid_cast<const DataTypeInt64 *>(&argument_type)) return new AggregateFunctionQuantileTDigest<Int64>;
|
||||
else if (typeid_cast<const DataTypeFloat32 *>(&argument_type)) return new AggregateFunctionQuantileTDigest<Float32>;
|
||||
else if (typeid_cast<const DataTypeFloat64 *>(&argument_type)) return new AggregateFunctionQuantileTDigest<Float64>;
|
||||
/* else if (typeid_cast<const DataTypeDate *>(&argument_type)) return new AggregateFunctionQuantile<DataTypeDate::FieldType, false>;
|
||||
else if (typeid_cast<const DataTypeDateTime*>(&argument_type)) return new AggregateFunctionQuantile<DataTypeDateTime::FieldType, false>;*/
|
||||
else if (typeid_cast<const DataTypeDate *>(&argument_type)) return new AggregateFunctionQuantileTDigest<DataTypeDate::FieldType, false>;
|
||||
else if (typeid_cast<const DataTypeDateTime*>(&argument_type)) return new AggregateFunctionQuantileTDigest<DataTypeDateTime::FieldType, false>;
|
||||
else
|
||||
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
throw Exception("Illegal type " + argument_type.getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
template <typename T, bool returns_float>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantileTDigestWeightedImpl(const std::string & name, const DataTypes & argument_types)
|
||||
{
|
||||
const IDataType & argument_type = *argument_types[1];
|
||||
|
||||
if (typeid_cast<const DataTypeUInt8 *>(&argument_type)) return new AggregateFunctionQuantileTDigestWeighted<T, UInt8, returns_float>;
|
||||
else if (typeid_cast<const DataTypeUInt16 *>(&argument_type)) return new AggregateFunctionQuantileTDigestWeighted<T, UInt16, returns_float>;
|
||||
else if (typeid_cast<const DataTypeUInt32 *>(&argument_type)) return new AggregateFunctionQuantileTDigestWeighted<T, UInt32, returns_float>;
|
||||
else if (typeid_cast<const DataTypeUInt64 *>(&argument_type)) return new AggregateFunctionQuantileTDigestWeighted<T, UInt64, returns_float>;
|
||||
else if (typeid_cast<const DataTypeInt8 *>(&argument_type)) return new AggregateFunctionQuantileTDigestWeighted<T, Int8, returns_float>;
|
||||
else if (typeid_cast<const DataTypeInt16 *>(&argument_type)) return new AggregateFunctionQuantileTDigestWeighted<T, Int16, returns_float>;
|
||||
else if (typeid_cast<const DataTypeInt32 *>(&argument_type)) return new AggregateFunctionQuantileTDigestWeighted<T, Int32, returns_float>;
|
||||
else if (typeid_cast<const DataTypeInt64 *>(&argument_type)) return new AggregateFunctionQuantileTDigestWeighted<T, Int64, returns_float>;
|
||||
else if (typeid_cast<const DataTypeFloat32 *>(&argument_type)) return new AggregateFunctionQuantileTDigestWeighted<T, Float32, returns_float>;
|
||||
else if (typeid_cast<const DataTypeFloat64 *>(&argument_type)) return new AggregateFunctionQuantileTDigestWeighted<T, Float64, returns_float>;
|
||||
else
|
||||
throw Exception("Illegal type " + argument_type.getName() + " of second argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
/** Factory entry point for the weighted t-digest quantile.
  * Requires exactly two arguments (otherwise throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH),
  * selects the value type from the first argument and delegates the choice of the
  * weight type to createAggregateFunctionQuantileTDigestWeightedImpl.
  * Numeric value types are created with returns_float = true, Date and DateTime with
  * returns_float = false.
  * Throws ILLEGAL_TYPE_OF_ARGUMENT for any other type of the first argument.
  */
AggregateFunctionPtr createAggregateFunctionQuantileTDigestWeighted(const std::string & name, const DataTypes & argument_types)
{
    if (argument_types.size() != 2)
        throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    const IDataType & value_arg_type = *argument_types[0];

    /// Unsigned integers.
    if (typeid_cast<const DataTypeUInt8 *>(&value_arg_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<UInt8, true>(name, argument_types);
    if (typeid_cast<const DataTypeUInt16 *>(&value_arg_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<UInt16, true>(name, argument_types);
    if (typeid_cast<const DataTypeUInt32 *>(&value_arg_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<UInt32, true>(name, argument_types);
    if (typeid_cast<const DataTypeUInt64 *>(&value_arg_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<UInt64, true>(name, argument_types);

    /// Signed integers.
    if (typeid_cast<const DataTypeInt8 *>(&value_arg_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<Int8, true>(name, argument_types);
    if (typeid_cast<const DataTypeInt16 *>(&value_arg_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<Int16, true>(name, argument_types);
    if (typeid_cast<const DataTypeInt32 *>(&value_arg_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<Int32, true>(name, argument_types);
    if (typeid_cast<const DataTypeInt64 *>(&value_arg_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<Int64, true>(name, argument_types);

    /// Floating point.
    if (typeid_cast<const DataTypeFloat32 *>(&value_arg_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<Float32, true>(name, argument_types);
    if (typeid_cast<const DataTypeFloat64 *>(&value_arg_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<Float64, true>(name, argument_types);

    /// Dates: the result keeps the argument type instead of being a float.
    if (typeid_cast<const DataTypeDate *>(&value_arg_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<DataTypeDate::FieldType, false>(name, argument_types);
    if (typeid_cast<const DataTypeDateTime*>(&value_arg_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<DataTypeDateTime::FieldType, false>(name, argument_types);

    throw Exception("Illegal type " + value_arg_type.getName() + " of first argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
|
||||
|
||||
/*
|
||||
@ -60,6 +114,7 @@ AggregateFunctionPtr createAggregateFunctionQuantilesTDigest(const std::string &
|
||||
/// Registers the t-digest quantile aggregate functions in the factory:
/// the plain variant and the weighted variant, each under a "quantile..." and a "median..." name.
void registerAggregateFunctionsQuantileTDigest(AggregateFunctionFactory & factory)
{
    factory.registerFunction(
        {"quantileTDigest", "medianTDigest"},
        createAggregateFunctionQuantileTDigest);

    factory.registerFunction(
        {"quantileTDigestWeighted", "medianTDigestWeighted"},
        createAggregateFunctionQuantileTDigestWeighted);

    /// The multi-level variant is currently not registered.
//  factory.registerFunction({"quantilesTDigest"}, createAggregateFunctionQuantilesTDigest);
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user