Fixed bad code with arithmetic ops on aggregate function states

This commit is contained in:
Alexey Milovidov 2019-03-23 05:33:11 +03:00
parent ba474ab21a
commit 847abfdbb1
6 changed files with 65 additions and 20 deletions

View File

@ -547,21 +547,27 @@ class FunctionBinaryArithmetic : public IFunction
throw Exception{"Illegal column " + block.getByPosition(new_arguments[1]).column->getName()
+ " of argument of aggregation state multiply. Should be integer constant", ErrorCodes::ILLEGAL_COLUMN};
const ColumnAggregateFunction * column = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(new_arguments[0]).column.get());
IAggregateFunction * function = column->getAggregateFunction().get();
const IColumn & agg_state_column = *block.getByPosition(new_arguments[0]).column;
bool agg_state_is_const = agg_state_column.isColumnConst();
const ColumnAggregateFunction & column = typeid_cast<const ColumnAggregateFunction &>(
agg_state_is_const ? static_cast<const ColumnConst &>(agg_state_column).getDataColumn() : agg_state_column);
AggregateFunctionPtr function = column.getAggregateFunction();
auto arena = std::make_shared<Arena>();
auto column_to = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena));
column_to->reserve(input_rows_count);
size_t size = agg_state_is_const ? 1 : input_rows_count;
auto column_from = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena));
column_from->reserve(input_rows_count);
auto column_to = ColumnAggregateFunction::create(function, Arenas(1, arena));
column_to->reserve(size);
for (size_t i = 0; i < input_rows_count; ++i)
auto column_from = ColumnAggregateFunction::create(function, Arenas(1, arena));
column_from->reserve(size);
for (size_t i = 0; i < size; ++i)
{
column_to->insertDefault();
column_from->insertFrom(column->getData()[i]);
column_from->insertFrom(column.getData()[i]);
}
auto & vec_to = column_to->getData();
@ -575,38 +581,55 @@ class FunctionBinaryArithmetic : public IFunction
{
if (m % 2)
{
for (size_t i = 0; i < input_rows_count; ++i)
for (size_t i = 0; i < size; ++i)
function->merge(vec_to[i], vec_from[i], arena.get());
--m;
}
else
{
for (size_t i = 0; i < input_rows_count; ++i)
for (size_t i = 0; i < size; ++i)
function->merge(vec_from[i], vec_from[i], arena.get());
m /= 2;
}
}
block.getByPosition(result).column = std::move(column_to);
if (agg_state_is_const)
block.getByPosition(result).column = ColumnConst::create(std::move(column_to), input_rows_count);
else
block.getByPosition(result).column = std::move(column_to);
}
/// Merge two aggregation states together.
void executeAggregateAddition(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) const
{
const ColumnAggregateFunction * columns[2];
for (size_t i = 0; i < 2; ++i)
columns[i] = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(arguments[i]).column.get());
const IColumn & lhs_column = *block.getByPosition(arguments[0]).column;
const IColumn & rhs_column = *block.getByPosition(arguments[1]).column;
auto column_to = ColumnAggregateFunction::create(columns[0]->getAggregateFunction());
column_to->reserve(input_rows_count);
bool lhs_is_const = lhs_column.isColumnConst();
bool rhs_is_const = rhs_column.isColumnConst();
for (size_t i = 0; i < input_rows_count; ++i)
const ColumnAggregateFunction & lhs = typeid_cast<const ColumnAggregateFunction &>(
lhs_is_const ? static_cast<const ColumnConst &>(lhs_column).getDataColumn() : lhs_column);
const ColumnAggregateFunction & rhs = typeid_cast<const ColumnAggregateFunction &>(
rhs_is_const ? static_cast<const ColumnConst &>(rhs_column).getDataColumn() : rhs_column);
AggregateFunctionPtr function = lhs.getAggregateFunction();
size_t size = (lhs_is_const && rhs_is_const) ? 1 : input_rows_count;
auto column_to = ColumnAggregateFunction::create(function);
column_to->reserve(size);
for (size_t i = 0; i < size; ++i)
{
column_to->insertFrom(columns[0]->getData()[i]);
column_to->insertMergeFrom(columns[1]->getData()[i]);
column_to->insertFrom(lhs.getData()[lhs_is_const ? 0 : i]);
column_to->insertMergeFrom(rhs.getData()[rhs_is_const ? 0 : i]);
}
block.getByPosition(result).column = std::move(column_to);
if (lhs_is_const && rhs_is_const)
block.getByPosition(result).column = ColumnConst::create(std::move(column_to), input_rows_count);
else
block.getByPosition(result).column = std::move(column_to);
}
void executeDateTimeIntervalPlusMinus(Block & block, const ColumnNumbers & arguments,

View File

@ -43,6 +43,8 @@ public:
return 1;
}
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());

View File

@ -0,0 +1,6 @@
90
90
90
90
90
90

View File

@ -0,0 +1,6 @@
SELECT finalizeAggregation((SELECT sumState(number) FROM numbers(10)) + (SELECT sumState(number) FROM numbers(10)));
SELECT finalizeAggregation((SELECT sumState(number) FROM numbers(10)) + materialize((SELECT sumState(number) FROM numbers(10))));
SELECT finalizeAggregation(materialize((SELECT sumState(number) FROM numbers(10))) + (SELECT sumState(number) FROM numbers(10)));
SELECT finalizeAggregation(materialize((SELECT sumState(number) FROM numbers(10))) + materialize((SELECT sumState(number) FROM numbers(10))));
SELECT finalizeAggregation(materialize((SELECT sumState(number) FROM numbers(10)) + (SELECT sumState(number) FROM numbers(10))));
SELECT materialize(finalizeAggregation((SELECT sumState(number) FROM numbers(10)) + (SELECT sumState(number) FROM numbers(10))));

View File

@ -0,0 +1,4 @@
450
450
450
450

View File

@ -0,0 +1,4 @@
SELECT finalizeAggregation((SELECT sumState(number) FROM numbers(10)) * 10);
SELECT finalizeAggregation(materialize((SELECT sumState(number) FROM numbers(10))) * 10);
SELECT finalizeAggregation(materialize((SELECT sumState(number) FROM numbers(10)) * 10));
SELECT materialize(finalizeAggregation((SELECT sumState(number) FROM numbers(10)) * 10));