// Mirror of https://github.com/ClickHouse/ClickHouse.git (synced 2024-12-04 05:22:17 +00:00).
// Original listing: 108 lines, 3.6 KiB, C++.
#include <memory>
|
|
#include <type_traits>
|
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
|
#include <AggregateFunctions/AggregateFunctionAvgWeighted.h>
|
|
#include <AggregateFunctions/Helpers.h>
|
|
#include <AggregateFunctions/FactoryHelpers.h>
|
|
|
|
namespace DB
|
|
{
|
|
/// Forward declaration only: the creator function in this file accepts a `const Settings *`
/// (unnamed, unused), so the full definition is not needed here.
struct Settings;

namespace ErrorCodes
{
    /// Error code used when avgWeighted receives argument types it cannot handle.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
|
|
|
|
namespace
|
|
{
|
|
bool allowTypes(const DataTypePtr& left, const DataTypePtr& right) noexcept
|
|
{
|
|
const WhichDataType l_dt(left), r_dt(right);
|
|
|
|
constexpr auto allow = [](WhichDataType t)
|
|
{
|
|
return t.isInt() || t.isUInt() || t.isFloat() || t.isDecimal();
|
|
};
|
|
|
|
return allow(l_dt) && allow(r_dt);
|
|
}
|
|
|
|
/// Expands to a `switch` over `which.idx`. The expanding function must have a WhichDataType
/// local named `which` in scope. The macro emits one `LINE(T)` case per supported type:
/// Float32/Float64, Decimal32..Decimal256, UInt8..UInt256 and Int8..Int256.
/// Any other type index makes the enclosing function return nullptr.
#define AT_SWITCH(LINE) \
    switch (which.idx) \
    { \
        LINE(Float32); LINE(Float64); \
        LINE(Decimal32); LINE(Decimal64); LINE(Decimal128); LINE(Decimal256); \
        LINE(UInt8); LINE(UInt16); LINE(UInt32); LINE(UInt64); LINE(UInt128); LINE(UInt256); \
        LINE(Int8); LINE(Int16); LINE(Int32); LINE(Int64); LINE(Int128); LINE(Int256); \
        default: return nullptr; \
    }
|
|
|
|
template <class First, class ... TArgs>
|
|
IAggregateFunction * create(const IDataType & second_type, TArgs && ... args)
|
|
{
|
|
const WhichDataType which(second_type);
|
|
|
|
#define LINE(Type) \
|
|
case TypeIndex::Type: return new AggregateFunctionAvgWeighted<First, Type>(std::forward<TArgs>(args)...)
|
|
AT_SWITCH(LINE)
|
|
#undef LINE
|
|
}
|
|
|
|
// Not using helper functions because there are no templates for binary decimal/numeric function.
|
|
template <class... TArgs>
|
|
IAggregateFunction * create(const IDataType & first_type, const IDataType & second_type, TArgs && ... args)
|
|
{
|
|
const WhichDataType which(first_type);
|
|
|
|
#define LINE(Type) \
|
|
case TypeIndex::Type: return create<Type, TArgs...>(second_type, std::forward<TArgs>(args)...)
|
|
AT_SWITCH(LINE)
|
|
#undef LINE
|
|
}
|
|
|
|
/// Creator callback registered for `avgWeighted(value, weight)`.
///
/// Rejects parameters (assertNoParameters) and requires a binary call (assertBinary).
/// Checks that both argument types are integer, floating-point or decimal (allowTypes), then
/// instantiates AggregateFunctionAvgWeighted for the concrete pair of argument types.
///
/// Decimal scales are forwarded to the constructor after `argument_types`:
///  - The first scale is the numerator scale. The value is multiplied by the weight, so the
///    numerator scale is <scale of value> + <scale of weight>, with 0 for a non-decimal side.
///  - The second scale (the weight scale, passed only when the weight is decimal) presumably
///    applies to the denominator, i.e. the sum of weights. Confirm against
///    AggregateFunctionAvgWeighted's constructor.
///
/// @throws Exception with ILLEGAL_TYPE_OF_ARGUMENT if the argument types are not supported.
AggregateFunctionPtr
createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
    assertNoParameters(name, parameters);
    assertBinary(name, argument_types);

    /// Bind by reference. Copying a DataTypePtr would cost an atomic refcount round-trip.
    const DataTypePtr & data_type = argument_types[0];
    const DataTypePtr & data_type_weight = argument_types[1];

    if (!allowTypes(data_type, data_type_weight))
        throw Exception(
            "Types " + data_type->getName() +
            " and " + data_type_weight->getName() +
            " are non-conforming as arguments for aggregate function " + name,
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    const bool left_decimal = isDecimal(data_type);
    const bool right_decimal = isDecimal(data_type_weight);

    AggregateFunctionPtr ptr;

    /// We multiply value by weight, so actual scale of numerator is <scale of value> + <scale of weight>
    if (left_decimal && right_decimal)
        ptr.reset(create(*data_type, *data_type_weight,
            argument_types,
            getDecimalScale(*data_type) + getDecimalScale(*data_type_weight), getDecimalScale(*data_type_weight)));
    else if (left_decimal)
        ptr.reset(create(*data_type, *data_type_weight, argument_types,
            getDecimalScale(*data_type)));
    else if (right_decimal)
        ptr.reset(create(*data_type, *data_type_weight, argument_types,
            getDecimalScale(*data_type_weight), getDecimalScale(*data_type_weight)));
    else
        ptr.reset(create(*data_type, *data_type_weight, argument_types));

    /// create() returns nullptr for any type missing from AT_SWITCH. allowTypes() should rule that
    /// out, but the two type lists are maintained separately. Fail loudly here instead of
    /// returning a null AggregateFunctionPtr.
    if (!ptr)
        throw Exception(
            "Types " + data_type->getName() +
            " and " + data_type_weight->getName() +
            " are not supported as arguments for aggregate function " + name,
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    return ptr;
}
|
|
}
|
|
|
|
/// Registers `avgWeighted` with the aggregate function factory. The name is registered as
/// case-sensitive, and createAggregateFunctionAvgWeighted is its creator.
void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory)
{
    constexpr auto function_name = "avgWeighted";
    factory.registerFunction(function_name, createAggregateFunctionAvgWeighted, AggregateFunctionFactory::CaseSensitive);
}
|
|
}
|