Share code between Avg and SumCount

This commit is contained in:
Raúl Marín 2021-11-24 13:13:54 +01:00
parent b6b75c28dd
commit 5dd3cc6595
2 changed files with 8 additions and 67 deletions

View File

@ -224,7 +224,7 @@ using AvgFieldType = std::conditional_t<is_decimal<T>,
NearestFieldType<T>>; NearestFieldType<T>>;
template <typename T> template <typename T>
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>> class AggregateFunctionAvg : public AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>
{ {
public: public:
using Base = AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>; using Base = AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>;
@ -242,8 +242,8 @@ public:
++this->data(place).denominator; ++this->data(place).denominator;
} }
void addBatchSinglePlace( void
size_t batch_size, AggregateDataPtr place, const IColumn ** columns, Arena *, ssize_t if_argument_pos) const override addBatchSinglePlace(size_t batch_size, AggregateDataPtr place, const IColumn ** columns, Arena *, ssize_t if_argument_pos) const final
{ {
AggregateFunctionSumData<Numerator> sum_data; AggregateFunctionSumData<Numerator> sum_data;
const auto & column = assert_cast<const ColVecType &>(*columns[0]); const auto & column = assert_cast<const ColVecType &>(*columns[0]);
@ -264,7 +264,7 @@ public:
void addBatchSinglePlaceNotNull( void addBatchSinglePlaceNotNull(
size_t batch_size, AggregateDataPtr place, const IColumn ** columns, const UInt8 * null_map, Arena *, ssize_t if_argument_pos) size_t batch_size, AggregateDataPtr place, const IColumn ** columns, const UInt8 * null_map, Arena *, ssize_t if_argument_pos)
const override const final
{ {
AggregateFunctionSumData<Numerator> sum_data; AggregateFunctionSumData<Numerator> sum_data;
const auto & column = assert_cast<const ColVecType &>(*columns[0]); const auto & column = assert_cast<const ColVecType &>(*columns[0]);
@ -292,7 +292,7 @@ public:
this->data(place).numerator += sum_data.sum; this->data(place).numerator += sum_data.sum;
} }
String getName() const final { return "avg"; } String getName() const override { return "avg"; }
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER

View File

@ -3,20 +3,17 @@
#include <type_traits> #include <type_traits>
#include <DataTypes/DataTypeTuple.h> #include <DataTypes/DataTypeTuple.h>
#include <AggregateFunctions/AggregateFunctionAvg.h> #include <AggregateFunctions/AggregateFunctionAvg.h>
#include <AggregateFunctions/AggregateFunctionSum.h>
namespace DB namespace DB
{ {
template <typename T> template <typename T>
class AggregateFunctionSumCount final : public AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionSumCount<T>> class AggregateFunctionSumCount final : public AggregateFunctionAvg<T>
{ {
public: public:
using Base = AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionSumCount<T>>; using Base = AggregateFunctionAvg<T>;
using Numerator = typename Base::Numerator;
using ColVecType = ColumnVectorOrDecimal<T>;
AggregateFunctionSumCount(const DataTypes & argument_types_, UInt32 num_scale_ = 0) explicit AggregateFunctionSumCount(const DataTypes & argument_types_, UInt32 num_scale_ = 0)
: Base(argument_types_, num_scale_), scale(num_scale_) {} : Base(argument_types_, num_scale_), scale(num_scale_) {}
DataTypePtr getReturnType() const override DataTypePtr getReturnType() const override
@ -34,62 +31,6 @@ public:
this->data(place).denominator); this->data(place).denominator);
} }
/// Folds a single row into the aggregation state: the value at `row_num` of the first
/// argument column is added to the numerator and the denominator is incremented by one.
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final
{
    const auto & source = static_cast<const ColVecType &>(*columns[0]);
    auto & state = this->data(place);

    state.numerator += source.getData()[row_num];
    ++state.denominator;
}
void addBatchSinglePlace(
size_t batch_size, AggregateDataPtr place, const IColumn ** columns, Arena *, ssize_t if_argument_pos) const override
{
AggregateFunctionSumData<Numerator> sum_data;
const auto & column = assert_cast<const ColVecType &>(*columns[0]);
if (if_argument_pos >= 0)
{
const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
sum_data.addManyConditional(column.getData().data(), flags.data(), batch_size);
for (size_t i = 0; i < batch_size; i++)
this->data(place).denominator += (flags[i] != 0);
}
else
{
sum_data.addMany(column.getData().data(), batch_size);
this->data(place).denominator += batch_size;
}
this->data(place).numerator += sum_data.sum;
}
void addBatchSinglePlaceNotNull(
size_t batch_size, AggregateDataPtr place, const IColumn ** columns, const UInt8 * null_map, Arena *, ssize_t if_argument_pos)
const override
{
AggregateFunctionSumData<Numerator> sum_data;
const auto & column = assert_cast<const ColVecType &>(*columns[0]);
if (if_argument_pos >= 0)
{
/// Merge the 2 sets of flags (null and if) into a single one. This allows us to use parallelizable sums when available
const auto * if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData().data();
auto final_flags = std::make_unique<UInt8[]>(batch_size);
size_t used_value = 0;
for (size_t i = 0; i < batch_size; ++i)
{
final_flags[i] = (!null_map[i]) & if_flags[i];
used_value += (!null_map[i]) & if_flags[i];
}
sum_data.addManyConditional(column.getData().data(), final_flags.get(), batch_size);
this->data(place).denominator += used_value;
}
else
{
sum_data.addManyNotNull(column.getData().data(), null_map, batch_size);
for (size_t i = 0; i < batch_size; i++)
this->data(place).denominator += (!null_map[i]);
}
this->data(place).numerator += sum_data.sum;
}
String getName() const final { return "sumCount"; } String getName() const final { return "sumCount"; }
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER