fix name, add exception, add tests

This commit is contained in:
Andrei Bodrov 2019-12-15 16:36:44 +03:00
parent d9c35c3242
commit cae40d3d94
4 changed files with 22 additions and 50 deletions

View File

@ -21,7 +21,7 @@ template <typename T>
struct AggregateFunctionAvgData
{
T numerator = 0;
UInt64 denominator = 0;
T denominator = 0;
template <typename ResultT>
ResultT NO_SANITIZE_UNDEFINED result() const
@ -37,8 +37,8 @@ struct AggregateFunctionAvgData
};
/// Calculates arithmetic mean of numbers.
template <typename Data, typename T, typename F>
class AggregateFunctionAvgBase : public IAggregateFunctionDataHelper<Data, F>
template <typename T, typename Data, typename Derived>
class AggregateFunctionAvgBase : public IAggregateFunctionDataHelper<Data, Derived>
{
public:
using ResultType = std::conditional_t<IsDecimalNumber<T>, T, Float64>;
@ -47,11 +47,11 @@ public:
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<Float64>>;
/// ctor for native types
AggregateFunctionAvgBase(const DataTypes & argument_types_) : IAggregateFunctionDataHelper<Data, F>(argument_types_, {}), scale(0) {}
AggregateFunctionAvgBase(const DataTypes & argument_types_) : IAggregateFunctionDataHelper<Data, Derived>(argument_types_, {}), scale(0) {}
/// ctor for Decimals
AggregateFunctionAvgBase(const IDataType & data_type, const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, F>(argument_types_, {}), scale(getDecimalScale(data_type))
: IAggregateFunctionDataHelper<Data, Derived>(argument_types_, {}), scale(getDecimalScale(data_type))
{
}
@ -94,25 +94,17 @@ protected:
};
template <typename T, typename Data>
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<Data, T, AggregateFunctionAvg<T, Data>>
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<T, Data, AggregateFunctionAvg<T, Data>>
{
public:
AggregateFunctionAvg(const DataTypes & argument_types_)
: AggregateFunctionAvgBase<Data, T, AggregateFunctionAvg<T, Data>>(argument_types_)
{
}
AggregateFunctionAvg(const IDataType & data_type, const DataTypes & argument_types_)
: AggregateFunctionAvgBase<Data, T, AggregateFunctionAvg<T, Data>>(data_type, argument_types_)
{
}
using AggregateFunctionAvgBase<T, Data, AggregateFunctionAvg<T, Data>>::AggregateFunctionAvgBase;
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
const auto & column = static_cast<const ColVecType &>(*columns[0]);
this->data(place).numerator += column.getData()[row_num];
++this->data(place).denominator;
this->data(place).denominator += 1;
}
String getName() const override { return "avg"; }

View File

@ -13,7 +13,7 @@ template <typename T>
struct AvgWeighted
{
using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>;
using Function = AggregateFunctionAvgWeighted<T, AggregateFunctionAvgWeightedData<FieldType>>;
using Function = AggregateFunctionAvgWeighted<T, AggregateFunctionAvgData<FieldType>>;
};
template <typename T>
@ -25,14 +25,18 @@ AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name
assertBinary(name, argument_types);
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
const auto data_type = static_cast<const DataTypePtr>(argument_types[0]);
const auto data_type_weight = static_cast<const DataTypePtr>(argument_types[1]);
if (!data_type->equals(*data_type_weight))
throw Exception("Different types " + data_type->getName() + " and " + data_type_weight->getName() + " of arguments for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (isDecimal(data_type))
res.reset(createWithDecimalType<AggregateFuncAvgWeighted>(*data_type, *data_type, argument_types));
else
res.reset(createWithNumericType<AggregateFuncAvgWeighted>(*data_type, argument_types));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
throw Exception("Illegal type " + data_type->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
@ -41,7 +45,7 @@ AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name
void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory)
{
factory.registerFunction("AvgWeighted", createAggregateFunctionAvgWeighted, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("avgWeighted", createAggregateFunctionAvgWeighted, AggregateFunctionFactory::CaseSensitive);
}
}

View File

@ -4,38 +4,11 @@
namespace DB
{
template <typename T>
struct AggregateFunctionAvgWeightedData
{
T numerator = 0;
T denominator = 0;
template <typename ResultT>
ResultT NO_SANITIZE_UNDEFINED result() const
{
if constexpr (std::is_floating_point_v<ResultT>)
if constexpr (std::numeric_limits<ResultT>::is_iec559)
return static_cast<ResultT>(numerator) / denominator; /// allow division by zero
if (denominator == 0)
return static_cast<ResultT>(0);
return static_cast<ResultT>(numerator / denominator);
}
};
template <typename T, typename Data>
class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase<Data, T, AggregateFunctionAvgWeighted<T, Data>>
class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase<T, Data, AggregateFunctionAvgWeighted<T, Data>>
{
public:
AggregateFunctionAvgWeighted(const DataTypes & argument_types_)
: AggregateFunctionAvgBase<Data, T, AggregateFunctionAvgWeighted<T, Data>>(argument_types_)
{
}
AggregateFunctionAvgWeighted(const IDataType & data_type, const DataTypes & argument_types_)
: AggregateFunctionAvgBase<Data, T, AggregateFunctionAvgWeighted<T, Data>>(data_type, argument_types_)
{
}
using AggregateFunctionAvgBase<T, Data, AggregateFunctionAvgWeighted<T, Data>>::AggregateFunctionAvgBase;
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -1 +1,4 @@
SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]) AS t));
SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 5), (2, 4), (3, 3), (4, 2), (5, 1)]) AS t));
SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)]) AS t));
SELECT avgWeighted(toDecimal64(0, 0), toFloat64(0))