fix modifying source columns, add tests

This commit is contained in:
CurtizJ 2018-09-05 16:18:47 +03:00
parent 8d8aeb51e5
commit 7ae4c1072b
3 changed files with 39 additions and 8 deletions

View File

@ -909,16 +909,24 @@ public:
const ColumnAggregateFunction * column = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(new_arguments[0]).column.get());
IAggregateFunction * function = column->getAggregateFunction().get();
MutableColumnPtr current = column->cloneEmpty();
auto arena = std::make_shared<Arena>();
auto & res = typeid_cast<ColumnAggregateFunction &>(*current);
auto & vec_to = res.getData();
const auto & vec_from = column->getData();
auto column_to = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena));
column_to->reserve(input_rows_count);
auto column_from = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena));
column_from->reserve(input_rows_count);
for (size_t i = 0; i < input_rows_count; ++i)
res.insertDefault();
{
column_to->insertDefault();
column_from->insertFrom(column->getData()[i]);
}
size_t m = block.getByPosition(new_arguments[1]).column->getUInt(0);
auto & vec_to = column_to->getData();
auto & vec_from = column_from->getData();
UInt64 m = block.getByPosition(new_arguments[1]).column->getUInt(0);
/// We use exponentiation by squaring algorithm to perform multiplying aggregate states by N in O(log(N)) operations
/// https://en.wikipedia.org/wiki/Exponentiation_by_squaring
@ -938,7 +946,7 @@ public:
}
}
block.getByPosition(result).column = std::move(current);
block.getByPosition(result).column = std::move(column_to);
return;
}

View File

@ -4,3 +4,10 @@
2
0
18
20 4
2
3
2
[1,1,1,1,1]
[1,1]
[1]

View File

@ -4,3 +4,19 @@ SELECT sumMerge(y) AS z FROM ( SELECT sumState(x) * 11 AS y FROM ( SELECT 1 AS x
SELECT countMerge(x) AS y FROM ( SELECT 2 * countState() AS x FROM ( SELECT 1 ));
SELECT countMerge(x) AS y FROM ( SELECT 0 * countState() AS x FROM ( SELECT 1 UNION ALL SELECT 2));
SELECT sumMerge(y) AS z FROM ( SELECT 3 * sumState(x) * 2 AS y FROM ( SELECT 1 AS x UNION ALL SELECT 2 AS x));
DROP TABLE IF EXISTS test.mult_aggregation;
CREATE TABLE test.mult_aggregation(a UInt32, b UInt32) ENGINE = Memory;
INSERT INTO test.mult_aggregation VALUES(1, 1);
INSERT INTO test.mult_aggregation VALUES(1, 3);
SELECT sumMerge(x * 5), sumMerge(x) FROM (SELECT sumState(b) AS x FROM test.mult_aggregation);
SELECT uniqMerge(x * 10) FROM (SELECT uniqState(b) AS x FROM test.mult_aggregation);
SELECT maxMerge(x * 10) FROM (SELECT maxState(b) AS x FROM test.mult_aggregation);
SELECT avgMerge(x * 10) FROM (SELECT avgState(b) AS x FROM test.mult_aggregation);
SELECT groupArrayMerge(y * 5) FROM (SELECT groupArrayState(x) AS y FROM (SELECT 1 AS x));
SELECT groupArrayMerge(2)(y * 5) FROM (SELECT groupArrayState(2)(x) AS y FROM (SELECT 1 AS x));
SELECT groupUniqArrayMerge(y * 5) FROM (SELECT groupUniqArrayState(x) AS y FROM (SELECT 1 AS x));
DROP TABLE IF EXISTS test.mult_aggregation;