shift arguments, some problem seems found!

This commit is contained in:
Sergei Tsetlin (rekub) 2018-06-19 16:36:53 +03:00
parent 32557a36c5
commit 987644e1e7
3 changed files with 39 additions and 10 deletions

View File

@ -780,12 +780,33 @@ class FunctionBinaryArithmetic : public IFunction
return castType(left, [&](const auto & left) { return castType(right, [&](const auto & right) { return f(left, right); }); }); return castType(left, [&](const auto & left) { return castType(right, [&](const auto & right) { return f(left, right); }); });
} }
bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1, bool & shift) const
{ {
return std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>> return std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>
&& checkDataType<DataTypeAggregateFunction>(type0.get()) &&
&& (checkDataType<DataTypeUInt8>(type1.get()) (
|| checkDataType<DataTypeUInt16>(type1.get())); (
checkDataType<DataTypeAggregateFunction>(type0.get())
&&
(
checkDataType<DataTypeUInt8>(type1.get())
|| checkDataType<DataTypeUInt16>(type1.get())
)
&&
!(shift = false)
)
||
(
checkDataType<DataTypeAggregateFunction>(type1.get())
&&
(
checkDataType<DataTypeUInt8>(type0.get())
|| checkDataType<DataTypeUInt16>(type0.get())
)
&&
(shift = true)
)
);
} }
FunctionBuilderPtr getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1) const FunctionBuilderPtr getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1) const
@ -845,8 +866,9 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{ {
/// Special case when multiply aggregate function state /// Special case when multiply aggregate function state
if (isAggregateMultiply(arguments[0], arguments[1])) bool shift;
return arguments[0]; if (isAggregateMultiply(arguments[0], arguments[1], shift))
return arguments[shift?1:0];
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval. /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1])) if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1]))
@ -888,12 +910,13 @@ public:
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
{ {
if (isAggregateMultiply(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type)) bool shift;
if (isAggregateMultiply(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type, shift))
{ {
auto c = block.getByPosition(arguments[0]).column->cloneEmpty(); auto c = block.getByPosition(arguments[shift?1:0]).column->cloneEmpty();
size_t m = block.getByPosition(arguments[1]).column->getUInt(0); size_t m = block.getByPosition(arguments[shift?0:1]).column->getUInt(0);
for (size_t i = 0; i < m; ++i) for (size_t i = 0; i < m; ++i)
c->insertRangeFrom(*(block.getByPosition(arguments[0]).column.get()), 0, input_rows_count); c->insertRangeFrom(*(block.getByPosition(arguments[shift?1:0]).column.get()), 0, input_rows_count);
block.getByPosition(result).column = std::move(c); block.getByPosition(result).column = std::move(c);
return; return;
} }

View File

@ -1,3 +1,6 @@
2 2
0 0
33 33
2
0
18

View File

@ -1,3 +1,6 @@
SELECT countMerge(x) AS y FROM ( SELECT countState() * 2 AS x FROM ( SELECT 1 )); SELECT countMerge(x) AS y FROM ( SELECT countState() * 2 AS x FROM ( SELECT 1 ));
SELECT countMerge(x) AS y FROM ( SELECT countState() * 0 AS x FROM ( SELECT 1 UNION ALL SELECT 2)); SELECT countMerge(x) AS y FROM ( SELECT countState() * 0 AS x FROM ( SELECT 1 UNION ALL SELECT 2));
SELECT sumMerge(y) AS z FROM ( SELECT sumState(x) * 11 AS y FROM ( SELECT 1 AS x UNION ALL SELECT 2 AS x)); SELECT sumMerge(y) AS z FROM ( SELECT sumState(x) * 11 AS y FROM ( SELECT 1 AS x UNION ALL SELECT 2 AS x));
SELECT countMerge(x) AS y FROM ( SELECT 2 * countState() AS x FROM ( SELECT 1 ));
SELECT countMerge(x) AS y FROM ( SELECT 0 * countState() AS x FROM ( SELECT 1 UNION ALL SELECT 2));
SELECT sumMerge(y) AS z FROM ( SELECT 3 * sumState(x) * 2 AS y FROM ( SELECT 1 AS x UNION ALL SELECT 2 AS x));