mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-21 17:20:50 +00:00
avgWeighted
This commit is contained in:
parent
7bcebe2742
commit
b676a292d4
@ -21,26 +21,26 @@ namespace ErrorCodes
|
||||
template <typename T>
|
||||
struct AggregateFunctionAvgData
|
||||
{
|
||||
T sum = 0;
|
||||
UInt64 count = 0;
|
||||
T numerator = 0;
|
||||
UInt64 denominator = 0;
|
||||
|
||||
template <typename ResultT>
|
||||
ResultT NO_SANITIZE_UNDEFINED result() const
|
||||
{
|
||||
if constexpr (std::is_floating_point_v<ResultT>)
|
||||
if constexpr (std::numeric_limits<ResultT>::is_iec559)
|
||||
return static_cast<ResultT>(sum) / count; /// allow division by zero
|
||||
return static_cast<ResultT>(numerator) / denominator; /// allow division by zero
|
||||
|
||||
if (count == 0)
|
||||
if (denominator == 0)
|
||||
return static_cast<ResultT>(0);
|
||||
return static_cast<ResultT>(sum / count);
|
||||
return static_cast<ResultT>(numerator / denominator);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Calculates arithmetic mean of numbers.
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionAvg final : public IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>
|
||||
template <typename Data, typename T, typename F>
|
||||
class AggregateFunctionAvgBase : public IAggregateFunctionDataHelper<Data, F>
|
||||
{
|
||||
public:
|
||||
using ResultType = std::conditional_t<IsDecimalNumber<T>, T, Float64>;
|
||||
@ -49,19 +49,17 @@ public:
|
||||
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<Float64>>;
|
||||
|
||||
/// ctor for native types
|
||||
AggregateFunctionAvg(const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_, {})
|
||||
AggregateFunctionAvgBase(const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, F>(argument_types_, {})
|
||||
, scale(0)
|
||||
{}
|
||||
|
||||
/// ctor for Decimals
|
||||
AggregateFunctionAvg(const IDataType & data_type, const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_, {})
|
||||
AggregateFunctionAvgBase(const IDataType & data_type, const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, F>(argument_types_, {})
|
||||
, scale(getDecimalScale(data_type))
|
||||
{}
|
||||
|
||||
String getName() const override { return "avg"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
if constexpr (IsDecimalNumber<T>)
|
||||
@ -70,29 +68,22 @@ public:
|
||||
return std::make_shared<ResultDataType>();
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
const auto & column = static_cast<const ColVecType &>(*columns[0]);
|
||||
this->data(place).sum += column.getData()[row_num];
|
||||
++this->data(place).count;
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).sum += this->data(rhs).sum;
|
||||
this->data(place).count += this->data(rhs).count;
|
||||
this->data(place).numerator += this->data(rhs).numerator;
|
||||
this->data(place).denominator += this->data(rhs).denominator;
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
{
|
||||
writeBinary(this->data(place).sum, buf);
|
||||
writeVarUInt(this->data(place).count, buf);
|
||||
writeBinary(this->data(place).numerator, buf);
|
||||
writeVarUInt(this->data(place).denominator, buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
|
||||
{
|
||||
readBinary(this->data(place).sum, buf);
|
||||
readVarUInt(this->data(place).count, buf);
|
||||
readBinary(this->data(place).numerator, buf);
|
||||
readVarUInt(this->data(place).denominator, buf);
|
||||
}
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
@ -103,9 +94,33 @@ public:
|
||||
|
||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||
|
||||
private:
|
||||
protected:
|
||||
UInt32 scale;
|
||||
};
|
||||
|
||||
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<Data, T, AggregateFunctionAvg<T, Data>>
|
||||
{
|
||||
public:
|
||||
|
||||
AggregateFunctionAvg(const DataTypes & argument_types_)
|
||||
: AggregateFunctionAvgBase<Data, T, AggregateFunctionAvg<T, Data>>(argument_types_) {}
|
||||
|
||||
AggregateFunctionAvg(const IDataType & data_type, const DataTypes & argument_types_)
|
||||
: AggregateFunctionAvgBase<Data, T, AggregateFunctionAvg<T, Data>>(data_type, argument_types_)
|
||||
{}
|
||||
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
const auto & column = static_cast<const ColVecType &>(*columns[0]);
|
||||
this->data(place).numerator += column.getData()[row_num];
|
||||
++this->data(place).denominator;
|
||||
}
|
||||
|
||||
String getName() const override { return "avg"; }
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
47
dbms/src/AggregateFunctions/AggregateFunctionAvgWeighted.cpp
Normal file
47
dbms/src/AggregateFunctions/AggregateFunctionAvgWeighted.cpp
Normal file
@ -0,0 +1,47 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionAvgWeighted.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
struct AvgWeighted
|
||||
{
|
||||
using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>;
|
||||
using Function = AggregateFunctionAvgWeighted<T, AggregateFunctionAvgWeightedData<FieldType>>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using AggregateFuncAvgWeighted = typename AvgWeighted<T>::Function;
|
||||
|
||||
/// Factory callback that builds the avgWeighted(value, weight) aggregate function.
///
/// `name` is used in error messages, `argument_types` are the types of the two arguments,
/// `parameters` must be empty (checked by assertNoParameters).
/// Throws ILLEGAL_TYPE_OF_ARGUMENT when the two arguments have different type ids, or when
/// no implementation exists for the value type.
AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
    assertNoParameters(name, parameters);
    assertBinary(name, argument_types);

    const DataTypePtr & value_type = argument_types[0];
    const DataTypePtr & weight_type = argument_types[1];

    /// AggregateFunctionAvgWeighted::add casts BOTH columns to the column type selected from the
    /// value type, so a weight column backed by any other type would be miscast. Reject it here.
    if (value_type->getTypeId() != weight_type->getTypeId())
        throw Exception("Different types " + value_type->getName() + " and " + weight_type->getName()
            + " of arguments for aggregate function " + name,
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    AggregateFunctionPtr res;
    if (isDecimal(value_type))
        res.reset(createWithDecimalType<AggregateFuncAvgWeighted>(*value_type, *value_type, argument_types));
    else
        res.reset(createWithNumericType<AggregateFuncAvgWeighted>(*value_type, argument_types));

    /// The create* helpers left `res` empty: no implementation for this value type.
    if (!res)
        throw Exception("Illegal type " + value_type->getName() + " of argument for aggregate function " + name,
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    return res;
}
|
||||
|
||||
}
|
||||
|
||||
/// Registers the avgWeighted factory callback.
/// The registered spelling matches AggregateFunctionAvgWeighted::getName() ("avgWeighted");
/// the CaseInsensitive flag is kept so that other capitalisations presumably keep resolving
/// (NOTE(review): confirm against AggregateFunctionFactory's lookup rules).
void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory)
{
    factory.registerFunction("avgWeighted", createAggregateFunctionAvgWeighted, AggregateFunctionFactory::CaseInsensitive);
}
|
||||
|
||||
}
|
54
dbms/src/AggregateFunctions/AggregateFunctionAvgWeighted.h
Normal file
54
dbms/src/AggregateFunctions/AggregateFunctionAvgWeighted.h
Normal file
@ -0,0 +1,54 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionAvg.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionAvgWeightedData
|
||||
{
|
||||
T numerator = 0;
|
||||
T denominator = 0;
|
||||
|
||||
template <typename ResultT>
|
||||
ResultT NO_SANITIZE_UNDEFINED result() const
|
||||
{
|
||||
if constexpr (std::is_floating_point_v<ResultT>)
|
||||
if constexpr (std::numeric_limits<ResultT>::is_iec559)
|
||||
return static_cast<ResultT>(numerator) / denominator; /// allow division by zero
|
||||
|
||||
if (denominator == 0)
|
||||
return static_cast<ResultT>(0);
|
||||
return static_cast<ResultT>(numerator / denominator);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase<Data, T, AggregateFunctionAvgWeighted<T, Data>>
|
||||
{
|
||||
public:
|
||||
|
||||
AggregateFunctionAvgWeighted(const DataTypes & argument_types_)
|
||||
: AggregateFunctionAvgBase<Data, T, AggregateFunctionAvgWeighted<T, Data>>(argument_types_) {}
|
||||
|
||||
AggregateFunctionAvgWeighted(const IDataType & data_type, const DataTypes & argument_types_)
|
||||
: AggregateFunctionAvgBase<Data, T, AggregateFunctionAvgWeighted<T, Data>>(data_type, argument_types_)
|
||||
{}
|
||||
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
const auto & values = static_cast<const ColVecType &>(*columns[0]);
|
||||
const auto & weights = static_cast<const ColVecType &>(*columns[1]);
|
||||
|
||||
this->data(place).numerator += values.getData()[row_num] * weights.getData()[row_num];
|
||||
this->data(place).denominator += weights.getData()[row_num];
|
||||
}
|
||||
|
||||
String getName() const override { return "avgWeighted"; }
|
||||
|
||||
};
|
||||
|
||||
}
|
@ -8,6 +8,7 @@ namespace DB
|
||||
{
|
||||
|
||||
void registerAggregateFunctionAvg(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionCount(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionGroupArray(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionGroupUniqArray(AggregateFunctionFactory &);
|
||||
@ -51,6 +52,7 @@ void registerAggregateFunctions()
|
||||
auto & factory = AggregateFunctionFactory::instance();
|
||||
|
||||
registerAggregateFunctionAvg(factory);
|
||||
registerAggregateFunctionAvgWeighted(factory);
|
||||
registerAggregateFunctionCount(factory);
|
||||
registerAggregateFunctionGroupArray(factory);
|
||||
registerAggregateFunctionGroupUniqArray(factory);
|
||||
|
@ -0,0 +1 @@
|
||||
3
|
1
dbms/tests/queries/0_stateless/01035_avg_weighted.sql
Normal file
1
dbms/tests/queries/0_stateless/01035_avg_weighted.sql
Normal file
@ -0,0 +1 @@
|
||||
-- Every weight is 1, so the weighted mean equals the plain mean of 1..5, i.e. 3.
SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]) AS t));
|
Loading…
Reference in New Issue
Block a user