Merge pull request #3034 from CurtizJ/CLICKHOUSE-3723.2

CLICKHOUSE-3723 Multiply aggregate states. Fix and optimize #2527.
This commit is contained in:
alexey-milovidov 2018-09-06 03:49:07 +03:00 committed by GitHub
commit f77cd9950c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 109 additions and 0 deletions

View File

@ -5,10 +5,12 @@
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeInterval.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/Native.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Functions/IFunction.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionFactory.h>
@ -1145,6 +1147,14 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
/// Special case: multiplying an aggregate function state by an unsigned integer.
if (isAggregateMultiply(arguments[0], arguments[1]))
{
if (checkDataType<DataTypeAggregateFunction>(arguments[0].get()))
return arguments[0];
return arguments[1];
}
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1]))
{
@ -1195,8 +1205,72 @@ public:
return type_res;
}
bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const
{
if constexpr (!std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>)
return false;
auto is_uint_type = [](const DataTypePtr & type)
{
return checkDataType<DataTypeUInt8>(type.get()) || checkDataType<DataTypeUInt16>(type.get())
|| checkDataType<DataTypeUInt32>(type.get()) || checkDataType<DataTypeUInt64>(type.get());
};
return ((checkDataType<DataTypeAggregateFunction>(type0.get()) && is_uint_type(type1))
|| (is_uint_type(type0) && checkDataType<DataTypeAggregateFunction>(type1.get())));
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
{
/// Special case: multiplying an aggregate function state by an unsigned integer.
if (isAggregateMultiply(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
{
ColumnNumbers new_arguments = arguments;
if (checkDataType<DataTypeAggregateFunction>(block.getByPosition(new_arguments[1]).type.get()))
std::swap(new_arguments[0], new_arguments[1]);
const ColumnAggregateFunction * column = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(new_arguments[0]).column.get());
IAggregateFunction * function = column->getAggregateFunction().get();
auto arena = std::make_shared<Arena>();
auto column_to = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena));
column_to->reserve(input_rows_count);
auto column_from = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena));
column_from->reserve(input_rows_count);
for (size_t i = 0; i < input_rows_count; ++i)
{
column_to->insertDefault();
column_from->insertFrom(column->getData()[i]);
}
auto & vec_to = column_to->getData();
auto & vec_from = column_from->getData();
UInt64 m = block.getByPosition(new_arguments[1]).column->getUInt(0);
/// We use the exponentiation by squaring algorithm to multiply aggregate states by N in O(log(N)) operations
/// https://en.wikipedia.org/wiki/Exponentiation_by_squaring
while (m)
{
if (m % 2)
{
for (size_t i = 0; i < input_rows_count; ++i)
function->merge(vec_to[i], vec_from[i], arena.get());
--m;
}
else
{
for (size_t i = 0; i < input_rows_count; ++i)
function->merge(vec_from[i], vec_from[i], arena.get());
m /= 2;
}
}
block.getByPosition(result).column = std::move(column_to);
return;
}
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (auto function_builder = getFunctionForIntervalArithmetic(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
{

View File

@ -0,0 +1,13 @@
2
0
33
2
0
18
20 4
2
3
2
[1,1,1,1,1]
[1,1]
[1]

View File

@ -0,0 +1,22 @@
SELECT countMerge(x) AS y FROM ( SELECT countState() * 2 AS x FROM ( SELECT 1 ));
SELECT countMerge(x) AS y FROM ( SELECT countState() * 0 AS x FROM ( SELECT 1 UNION ALL SELECT 2));
SELECT sumMerge(y) AS z FROM ( SELECT sumState(x) * 11 AS y FROM ( SELECT 1 AS x UNION ALL SELECT 2 AS x));
SELECT countMerge(x) AS y FROM ( SELECT 2 * countState() AS x FROM ( SELECT 1 ));
SELECT countMerge(x) AS y FROM ( SELECT 0 * countState() AS x FROM ( SELECT 1 UNION ALL SELECT 2));
SELECT sumMerge(y) AS z FROM ( SELECT 3 * sumState(x) * 2 AS y FROM ( SELECT 1 AS x UNION ALL SELECT 2 AS x));
DROP TABLE IF EXISTS test.mult_aggregation;
CREATE TABLE test.mult_aggregation(a UInt32, b UInt32) ENGINE = Memory;
INSERT INTO test.mult_aggregation VALUES(1, 1);
INSERT INTO test.mult_aggregation VALUES(1, 3);
SELECT sumMerge(x * 5), sumMerge(x) FROM (SELECT sumState(b) AS x FROM test.mult_aggregation);
SELECT uniqMerge(x * 10) FROM (SELECT uniqState(b) AS x FROM test.mult_aggregation);
SELECT maxMerge(x * 10) FROM (SELECT maxState(b) AS x FROM test.mult_aggregation);
SELECT avgMerge(x * 10) FROM (SELECT avgState(b) AS x FROM test.mult_aggregation);
SELECT groupArrayMerge(y * 5) FROM (SELECT groupArrayState(x) AS y FROM (SELECT 1 AS x));
SELECT groupArrayMerge(2)(y * 5) FROM (SELECT groupArrayState(2)(x) AS y FROM (SELECT 1 AS x));
SELECT groupUniqArrayMerge(y * 5) FROM (SELECT groupUniqArrayState(x) AS y FROM (SELECT 1 AS x));
DROP TABLE IF EXISTS test.mult_aggregation;