dbms: added weighted variants of t-digest [#METR-19024].

This commit is contained in:
Alexey Milovidov 2015-11-21 16:24:51 +03:00
parent 57e8a8fdbb
commit 31c09b4d2d
2 changed files with 141 additions and 7 deletions

View File

@ -11,6 +11,7 @@
#include <DB/Common/RadixSort.h>
#include <DB/Common/PODArray.h>
#include <DB/AggregateFunctions/IUnaryAggregateFunction.h>
#include <DB/AggregateFunctions/IBinaryAggregateFunction.h>
#include <DB/DataTypes/DataTypesNumberFixed.h>
@ -289,13 +290,14 @@ struct AggregateFunctionQuantileTDigestData
};
template <typename T>
template <typename T, bool returns_float = true>
class AggregateFunctionQuantileTDigest final
: public IUnaryAggregateFunction<AggregateFunctionQuantileTDigestData, AggregateFunctionQuantileTDigest<T>>
{
private:
double level;
tdigest::Params<Float32> params;
DataTypePtr type;
public:
AggregateFunctionQuantileTDigest(double level_ = 0.5) : level(level_) {}
@ -304,11 +306,15 @@ public:
DataTypePtr getReturnType() const override
{
return new DataTypeFloat64;
return type;
}
/// Chooses the result type from the single argument:
/// the argument's own type when returns_float is false, Float32 otherwise.
/// NOTE(review): the CRTP base of this class is declared with AggregateFunctionQuantileTDigest<T>,
/// i.e. with returns_float left at its default (true); if the base reaches this method through
/// that type, the returns_float == false instantiation would take the wrong branch - confirm.
void setArgument(const DataTypePtr & argument)
{
    if (!returns_float)
    {
        type = argument;
        return;
    }

    type = new DataTypeFloat32;
}
void setParameters(const Array & params) override
@ -341,8 +347,81 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
static_cast<ColumnFloat64 &>(to).getData().push_back(
this->data(const_cast<AggregateDataPtr>(place)).digest.quantile(params, level));
auto quantile = this->data(const_cast<AggregateDataPtr>(place)).digest.quantile(params, level);
if (returns_float)
static_cast<ColumnFloat32 &>(to).getData().push_back(quantile);
else
static_cast<ColumnVector<T> &>(to).getData().push_back(quantile);
}
};
/** Weighted variant of the t-digest quantile aggregate function.
  *
  * Takes two arguments: a value column with elements of type T and a weight column
  * with elements of type Weight. Every (value, weight) pair is forwarded to the digest's add().
  *
  * Result type: Float32 when returns_float is true, otherwise the type of the first argument.
  *
  * The CRTP base must be given exactly this instantiation (including returns_float);
  * otherwise the base would address a different class than the one actually constructed
  * whenever returns_float is false.
  */
template <typename T, typename Weight, bool returns_float = true>
class AggregateFunctionQuantileTDigestWeighted final
    : public IBinaryAggregateFunction<
        AggregateFunctionQuantileTDigestData,
        AggregateFunctionQuantileTDigestWeighted<T, Weight, returns_float>>
{
private:
    /// Quantile level to compute; 0.5 unless overridden via the constructor or setParameters().
    double level;
    /// Default-constructed digest parameters, passed to every digest operation.
    tdigest::Params<Float32> params;
    /// Result type, assigned in setArgumentsImpl().
    DataTypePtr type;

public:
    AggregateFunctionQuantileTDigestWeighted(double level_ = 0.5) : level(level_) {}

    String getName() const override { return "quantileTDigestWeighted"; }

    DataTypePtr getReturnType() const override
    {
        return type;
    }

    /// Selects the result type: Float32, or the type of the first (value) argument.
    void setArgumentsImpl(const DataTypes & arguments)
    {
        if (returns_float)
            type = new DataTypeFloat32;
        else
            type = arguments.at(0);
    }

    /// Expects exactly one parameter - the quantile level - and converts it to Float64.
    /// Note: inside this method `params` names the argument, not the digest parameters member.
    void setParameters(const Array & params) override
    {
        if (params.size() != 1)
            throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        level = apply_visitor(FieldVisitorConvertToNumber<Float64>(), params[0]);
    }

    /// Adds the value at row_num of column_value with the weight at row_num of column_weight.
    void addImpl(AggregateDataPtr place, const IColumn & column_value, const IColumn & column_weight, size_t row_num) const
    {
        this->data(place).digest.add(params,
            static_cast<const ColumnVector<T> &>(column_value).getData()[row_num],
            static_cast<const ColumnVector<Weight> &>(column_weight).getData()[row_num]);
    }

    void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
    {
        this->data(place).digest.merge(params, this->data(rhs).digest);
    }

    void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
    {
        /// digest.write() is invoked on a non-const state, hence the const_cast.
        this->data(const_cast<AggregateDataPtr>(place)).digest.write(params, buf);
    }

    void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
    {
        this->data(place).digest.readAndMerge(params, buf);
    }

    /// Appends the quantile at `level` to the result column, whose concrete column type
    /// matches the type chosen in setArgumentsImpl().
    void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
    {
        auto quantile = this->data(const_cast<AggregateDataPtr>(place)).digest.quantile(params, level);

        if (returns_float)
            static_cast<ColumnFloat32 &>(to).getData().push_back(quantile);
        else
            static_cast<ColumnVector<T> &>(to).getData().push_back(quantile);
    }
};

View File

@ -25,10 +25,64 @@ AggregateFunctionPtr createAggregateFunctionQuantileTDigest(const std::string &
else if (typeid_cast<const DataTypeInt64 *>(&argument_type)) return new AggregateFunctionQuantileTDigest<Int64>;
else if (typeid_cast<const DataTypeFloat32 *>(&argument_type)) return new AggregateFunctionQuantileTDigest<Float32>;
else if (typeid_cast<const DataTypeFloat64 *>(&argument_type)) return new AggregateFunctionQuantileTDigest<Float64>;
/* else if (typeid_cast<const DataTypeDate *>(&argument_type)) return new AggregateFunctionQuantile<DataTypeDate::FieldType, false>;
else if (typeid_cast<const DataTypeDateTime*>(&argument_type)) return new AggregateFunctionQuantile<DataTypeDateTime::FieldType, false>;*/
else if (typeid_cast<const DataTypeDate *>(&argument_type)) return new AggregateFunctionQuantileTDigest<DataTypeDate::FieldType, false>;
else if (typeid_cast<const DataTypeDateTime*>(&argument_type)) return new AggregateFunctionQuantileTDigest<DataTypeDateTime::FieldType, false>;
else
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception("Illegal type " + argument_type.getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
/** Second dispatch stage for the weighted t-digest quantile: the value type T and returns_float
  * are already fixed by the caller; this picks the Weight type from the second argument.
  * Throws ILLEGAL_TYPE_OF_ARGUMENT if the second argument is not one of the listed numeric types.
  */
template <typename T, bool returns_float>
AggregateFunctionPtr createAggregateFunctionQuantileTDigestWeightedImpl(const std::string & name, const DataTypes & argument_types)
{
    const IDataType & weight_type = *argument_types[1];

    if (typeid_cast<const DataTypeUInt8 *>(&weight_type))
        return new AggregateFunctionQuantileTDigestWeighted<T, UInt8, returns_float>;
    if (typeid_cast<const DataTypeUInt16 *>(&weight_type))
        return new AggregateFunctionQuantileTDigestWeighted<T, UInt16, returns_float>;
    if (typeid_cast<const DataTypeUInt32 *>(&weight_type))
        return new AggregateFunctionQuantileTDigestWeighted<T, UInt32, returns_float>;
    if (typeid_cast<const DataTypeUInt64 *>(&weight_type))
        return new AggregateFunctionQuantileTDigestWeighted<T, UInt64, returns_float>;

    if (typeid_cast<const DataTypeInt8 *>(&weight_type))
        return new AggregateFunctionQuantileTDigestWeighted<T, Int8, returns_float>;
    if (typeid_cast<const DataTypeInt16 *>(&weight_type))
        return new AggregateFunctionQuantileTDigestWeighted<T, Int16, returns_float>;
    if (typeid_cast<const DataTypeInt32 *>(&weight_type))
        return new AggregateFunctionQuantileTDigestWeighted<T, Int32, returns_float>;
    if (typeid_cast<const DataTypeInt64 *>(&weight_type))
        return new AggregateFunctionQuantileTDigestWeighted<T, Int64, returns_float>;

    if (typeid_cast<const DataTypeFloat32 *>(&weight_type))
        return new AggregateFunctionQuantileTDigestWeighted<T, Float32, returns_float>;
    if (typeid_cast<const DataTypeFloat64 *>(&weight_type))
        return new AggregateFunctionQuantileTDigestWeighted<T, Float64, returns_float>;

    throw Exception("Illegal type " + weight_type.getName() + " of second argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
/** Factory entry point for the weighted t-digest quantile.
  * Requires exactly two arguments. The first (value) argument selects T; numeric values produce
  * a function instantiated with returns_float = true, Date and DateTime with returns_float = false.
  * The second (weight) argument is resolved by createAggregateFunctionQuantileTDigestWeightedImpl.
  */
AggregateFunctionPtr createAggregateFunctionQuantileTDigestWeighted(const std::string & name, const DataTypes & argument_types)
{
    if (argument_types.size() != 2)
        throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    const IDataType & value_type = *argument_types[0];

    /// Unsigned integers.
    if (typeid_cast<const DataTypeUInt8 *>(&value_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<UInt8, true>(name, argument_types);
    if (typeid_cast<const DataTypeUInt16 *>(&value_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<UInt16, true>(name, argument_types);
    if (typeid_cast<const DataTypeUInt32 *>(&value_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<UInt32, true>(name, argument_types);
    if (typeid_cast<const DataTypeUInt64 *>(&value_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<UInt64, true>(name, argument_types);

    /// Signed integers.
    if (typeid_cast<const DataTypeInt8 *>(&value_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<Int8, true>(name, argument_types);
    if (typeid_cast<const DataTypeInt16 *>(&value_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<Int16, true>(name, argument_types);
    if (typeid_cast<const DataTypeInt32 *>(&value_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<Int32, true>(name, argument_types);
    if (typeid_cast<const DataTypeInt64 *>(&value_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<Int64, true>(name, argument_types);

    /// Floating point.
    if (typeid_cast<const DataTypeFloat32 *>(&value_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<Float32, true>(name, argument_types);
    if (typeid_cast<const DataTypeFloat64 *>(&value_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<Float64, true>(name, argument_types);

    /// Date and DateTime: instantiated with returns_float = false.
    if (typeid_cast<const DataTypeDate *>(&value_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<DataTypeDate::FieldType, false>(name, argument_types);
    if (typeid_cast<const DataTypeDateTime*>(&value_type))
        return createAggregateFunctionQuantileTDigestWeightedImpl<DataTypeDateTime::FieldType, false>(name, argument_types);

    throw Exception("Illegal type " + value_type.getName() + " of first argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
/*
@ -60,6 +114,7 @@ AggregateFunctionPtr createAggregateFunctionQuantilesTDigest(const std::string &
/// Registers the t-digest quantile aggregate functions in the given factory.
/// Each call passes a list of names together with the creator function used for them.
void registerAggregateFunctionsQuantileTDigest(AggregateFunctionFactory & factory)
{
/// Unweighted variant, registered under both names.
factory.registerFunction({"quantileTDigest", "medianTDigest"}, createAggregateFunctionQuantileTDigest);
/// Weighted variant (value, weight), registered under both names.
factory.registerFunction({"quantileTDigestWeighted", "medianTDigestWeighted"}, createAggregateFunctionQuantileTDigestWeighted);
/// Registration of quantilesTDigest is currently disabled (left commented out).
// factory.registerFunction({"quantilesTDigest"}, createAggregateFunctionQuantilesTDigest);
}