diff --git a/src/AggregateFunctions/AggregateFunctionSumCount.h b/src/AggregateFunctions/AggregateFunctionSumCount.h index 9892c230e1d..63c1704cd7f 100644 --- a/src/AggregateFunctions/AggregateFunctionSumCount.h +++ b/src/AggregateFunctions/AggregateFunctionSumCount.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace DB @@ -12,6 +13,8 @@ class AggregateFunctionSumCount final : public AggregateFunctionAvgBase, UInt64, AggregateFunctionSumCount>; + using Numerator = typename Base::Numerator; + using ColVecType = ColumnVectorOrDecimal; AggregateFunctionSumCount(const DataTypes & argument_types_, UInt32 num_scale_ = 0) : Base(argument_types_, num_scale_), scale(num_scale_) {} @@ -33,10 +36,60 @@ public: void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final { - this->data(place).numerator += static_cast &>(*columns[0]).getData()[row_num]; + this->data(place).numerator += static_cast(*columns[0]).getData()[row_num]; ++this->data(place).denominator; } + void addBatchSinglePlace( + size_t batch_size, AggregateDataPtr place, const IColumn ** columns, Arena *, ssize_t if_argument_pos) const override + { + AggregateFunctionSumData sum_data; + const auto & column = assert_cast(*columns[0]); + if (if_argument_pos >= 0) + { + const auto & flags = assert_cast(*columns[if_argument_pos]).getData(); + sum_data.addManyConditional(column.getData().data(), flags.data(), batch_size); + for (size_t i = 0; i < batch_size; i++) + this->data(place).denominator += (flags[i] != 0); + } + else + { + sum_data.addMany(column.getData().data(), batch_size); + this->data(place).denominator += batch_size; + } + this->data(place).numerator += sum_data.sum; + } + + void addBatchSinglePlaceNotNull( + size_t batch_size, AggregateDataPtr place, const IColumn ** columns, const UInt8 * null_map, Arena *, ssize_t if_argument_pos) + const override + { + AggregateFunctionSumData sum_data; + const auto & column = assert_cast(*columns[0]); + if (if_argument_pos >= 0) + { + /// Merge the 2 sets of flags (null and if) into a single one. This allows us to use parallelizable sums when available + const auto * if_flags = assert_cast(*columns[if_argument_pos]).getData().data(); + auto final_flags = std::make_unique(batch_size); + size_t used_value = 0; + for (size_t i = 0; i < batch_size; ++i) + { + final_flags[i] = (!null_map[i]) & if_flags[i]; + used_value += (!null_map[i]) & if_flags[i]; + } + + sum_data.addManyConditional(column.getData().data(), final_flags.get(), batch_size); + this->data(place).denominator += used_value; + } + else + { + sum_data.addManyNotNull(column.getData().data(), null_map, batch_size); + for (size_t i = 0; i < batch_size; i++) + this->data(place).denominator += (!null_map[i]); + } + this->data(place).numerator += sum_data.sum; + } + String getName() const final { return "sumCount"; } #if USE_EMBEDDED_COMPILER