This commit is contained in:
Alexey Milovidov 2018-09-08 04:42:34 +03:00
parent 8dcf59389c
commit 38b812ddba

View File

@ -1147,22 +1147,22 @@ class FunctionBinaryArithmetic : public IFunction
{ {
if constexpr (!std::is_same_v<Op<UInt8, UInt8>, PlusImpl<UInt8, UInt8>>) if constexpr (!std::is_same_v<Op<UInt8, UInt8>, PlusImpl<UInt8, UInt8>>)
return false; return false;
WhichDataType which0(type0); WhichDataType which0(type0);
WhichDataType which1(type1); WhichDataType which1(type1);
return which0.isAggregateFunction() && which1.isAggregateFunction(); return which0.isAggregateFunction() && which1.isAggregateFunction();
} }
/// Multiply aggregation state by integer constant: by merging it with itself specified number of times. /// Multiply aggregation state by integer constant: by merging it with itself specified number of times.
void executeAggregateMultiply(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) const void executeAggregateMultiply(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) const
{ {
ColumnNumbers new_arguments = arguments; ColumnNumbers new_arguments = arguments;
if (WhichDataType(block.getByPosition(new_arguments[1]).type).isAggregateFunction()) if (WhichDataType(block.getByPosition(new_arguments[1]).type).isAggregateFunction())
std::swap(new_arguments[0], new_arguments[1]); std::swap(new_arguments[0], new_arguments[1]);
if (!block.getByPosition(new_arguments[1]).column->isColumnConst()) if (!block.getByPosition(new_arguments[1]).column->isColumnConst())
throw Exception{"Illegal column " + block.getByPosition(new_arguments[1]).column->getName() throw Exception{"Illegal column " + block.getByPosition(new_arguments[1]).column->getName()
+ " of argument of aggregation state multiply. Should be integer constant", ErrorCodes::ILLEGAL_COLUMN}; + " of argument of aggregation state multiply. Should be integer constant", ErrorCodes::ILLEGAL_COLUMN};
const ColumnAggregateFunction * column = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(new_arguments[0]).column.get()); const ColumnAggregateFunction * column = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(new_arguments[0]).column.get());
@ -1216,7 +1216,7 @@ class FunctionBinaryArithmetic : public IFunction
columns[i] = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(arguments[i]).column.get()); columns[i] = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(arguments[i]).column.get());
auto column_to = ColumnAggregateFunction::create(columns[0]->getAggregateFunction()); auto column_to = ColumnAggregateFunction::create(columns[0]->getAggregateFunction());
column_to->reserve(input_rows_count); column_to->reserve(input_rows_count);
for(size_t i = 0; i < input_rows_count; ++i) for(size_t i = 0; i < input_rows_count; ++i)
{ {
@ -1227,7 +1227,7 @@ class FunctionBinaryArithmetic : public IFunction
block.getByPosition(result).column = std::move(column_to); block.getByPosition(result).column = std::move(column_to);
} }
void executeDateTimeIntervalPlusMinus(Block & block, const ColumnNumbers & arguments, void executeDateTimeIntervalPlusMinus(Block & block, const ColumnNumbers & arguments,
size_t result, size_t input_rows_count, const FunctionBuilderPtr & function_builder) const size_t result, size_t input_rows_count, const FunctionBuilderPtr & function_builder) const
{ {
ColumnNumbers new_arguments = arguments; ColumnNumbers new_arguments = arguments;
@ -1278,7 +1278,7 @@ public:
if (isAggregateAddition(arguments[0], arguments[1])) if (isAggregateAddition(arguments[0], arguments[1]))
{ {
if (!arguments[0]->equals(*arguments[1])) if (!arguments[0]->equals(*arguments[1]))
throw Exception("Cannot add aggregate states of different functions: " throw Exception("Cannot add aggregate states of different functions: "
+ arguments[0]->getName() + " and " + arguments[1]->getName(), ErrorCodes::CANNOT_ADD_DIFFERENT_AGGREGATE_STATES); + arguments[0]->getName() + " and " + arguments[1]->getName(), ErrorCodes::CANNOT_ADD_DIFFERENT_AGGREGATE_STATES);
return arguments[0]; return arguments[0];
@ -1334,18 +1334,6 @@ public:
return type_res; return type_res;
} }
bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const
{
if constexpr (!std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>)
return false;
WhichDataType which0(type0);
WhichDataType which1(type1);
return (which0.isAggregateFunction() && which1.isNativeUInt())
|| (which0.isNativeUInt() && which1.isAggregateFunction());
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
{ {
/// Special case when multiply aggregate function state /// Special case when multiply aggregate function state
@ -1358,7 +1346,7 @@ public:
/// Special case - addition of two aggregate functions states /// Special case - addition of two aggregate functions states
if (isAggregateAddition(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type)) if (isAggregateAddition(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
{ {
executeAggregateAddition(block, arguments, result, input_rows_count); executeAggregateAddition(block, arguments, result, input_rows_count);
return; return;
} }