Added aggregate function sumKahan [#CLICKHOUSE-2].

This commit is contained in:
Alexey Milovidov 2017-12-23 02:26:30 +03:00
parent 59dce5187a
commit 609133ea01
5 changed files with 24 additions and 29 deletions

View File

@ -10,25 +10,23 @@ namespace DB
namespace
{
template <typename T>
using AggregateFunctionSumSimple = AggregateFunctionSum<T, typename NearestFieldType<T>::Type, AggregateFunctionSumData<typename NearestFieldType<T>::Type>>;
template <typename T>
using AggregateFunctionSumWithOverflow = AggregateFunctionSum<T, T, AggregateFunctionSumData<T>>;
template <typename T>
using AggregateFunctionSumKahan = AggregateFunctionSum<T, Float64, AggregateFunctionSumKahanData<Float64>>;
template <template <typename> class Function>
AggregateFunctionPtr createAggregateFunctionSum(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericTypeNearest<AggregateFunctionSum>(*argument_types[0]));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
AggregateFunctionPtr createAggregateFunctionSumWithOverflow(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionSum>(*argument_types[0]));
AggregateFunctionPtr res(createWithNumericType<Function>(*argument_types[0]));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -40,8 +38,9 @@ AggregateFunctionPtr createAggregateFunctionSumWithOverflow(const std::string &
void registerAggregateFunctionSum(AggregateFunctionFactory & factory)
{
factory.registerFunction("sum", createAggregateFunctionSum, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("sumWithOverflow", createAggregateFunctionSumWithOverflow);
factory.registerFunction("sum", createAggregateFunctionSum<AggregateFunctionSumSimple>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("sumWithOverflow", createAggregateFunctionSum<AggregateFunctionSumWithOverflow>);
factory.registerFunction("sumKahan", createAggregateFunctionSum<AggregateFunctionSumKahan>);
}
}

View File

@ -92,7 +92,7 @@ struct AggregateFunctionSumKahanData
/// Counts the sum of the numbers.
template <typename T, typename TResult, typename Data>
class AggregateFunctionSum final : public IAggregateFunctionDataHelper<Data<TResult>, AggregateFunctionSum<T, TResult, Data>>
class AggregateFunctionSum final : public IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>
{
public:
String getName() const override { return "sum"; }

View File

@ -74,18 +74,6 @@ static IAggregateFunction * createWithNumericType(const IDataType & argument_typ
}
template <template <typename, typename> class AggregateFunctionTemplate, typename ... TArgs>
static IAggregateFunction * createWithNumericTypeNearest(const IDataType & argument_type, TArgs && ... args)
{
#define DISPATCH(FIELDTYPE, DATATYPE) \
if (typeid_cast<const DATATYPE *>(&argument_type)) \
return new AggregateFunctionTemplate<FIELDTYPE, NearestFieldType<FIELDTYPE>::Type>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES_AND_ENUMS(DISPATCH)
#undef DISPATCH
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate, template <typename> class Data, typename ... TArgs>
static IAggregateFunction * createWithUnsignedIntegerType(const IDataType & argument_type, TArgs && ... args)
{

View File

@ -0,0 +1,4 @@
1000
232
0
1

View File

@ -0,0 +1,4 @@
SELECT sum(1) FROM (SELECT * FROM system.numbers LIMIT 1000);
SELECT sumWithOverflow(1) FROM (SELECT * FROM system.numbers LIMIT 1000);
SELECT sumKahan(1e100) - 1e100 * 1000 FROM (SELECT * FROM system.numbers LIMIT 1000);
SELECT abs(sum(1e100) - 1e100 * 1000) > 1 FROM (SELECT * FROM system.numbers LIMIT 1000);