mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
shift arguments, some problem seems found!
This commit is contained in:
parent
32557a36c5
commit
987644e1e7
@ -780,12 +780,33 @@ class FunctionBinaryArithmetic : public IFunction
|
||||
return castType(left, [&](const auto & left) { return castType(right, [&](const auto & right) { return f(left, right); }); });
|
||||
}
|
||||
|
||||
bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const
|
||||
bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1, bool & shift) const
|
||||
{
|
||||
return std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>
|
||||
&& checkDataType<DataTypeAggregateFunction>(type0.get())
|
||||
&& (checkDataType<DataTypeUInt8>(type1.get())
|
||||
|| checkDataType<DataTypeUInt16>(type1.get()));
|
||||
&&
|
||||
(
|
||||
(
|
||||
checkDataType<DataTypeAggregateFunction>(type0.get())
|
||||
&&
|
||||
(
|
||||
checkDataType<DataTypeUInt8>(type1.get())
|
||||
|| checkDataType<DataTypeUInt16>(type1.get())
|
||||
)
|
||||
&&
|
||||
!(shift = false)
|
||||
)
|
||||
||
|
||||
(
|
||||
checkDataType<DataTypeAggregateFunction>(type1.get())
|
||||
&&
|
||||
(
|
||||
checkDataType<DataTypeUInt8>(type0.get())
|
||||
|| checkDataType<DataTypeUInt16>(type0.get())
|
||||
)
|
||||
&&
|
||||
(shift = true)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
FunctionBuilderPtr getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1) const
|
||||
@ -845,8 +866,9 @@ public:
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
/// Special case when multiply aggregate function state
|
||||
if (isAggregateMultiply(arguments[0], arguments[1]))
|
||||
return arguments[0];
|
||||
bool shift;
|
||||
if (isAggregateMultiply(arguments[0], arguments[1], shift))
|
||||
return arguments[shift?1:0];
|
||||
|
||||
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
|
||||
if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1]))
|
||||
@ -888,12 +910,13 @@ public:
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
|
||||
{
|
||||
if (isAggregateMultiply(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
|
||||
bool shift;
|
||||
if (isAggregateMultiply(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type, shift))
|
||||
{
|
||||
auto c = block.getByPosition(arguments[0]).column->cloneEmpty();
|
||||
size_t m = block.getByPosition(arguments[1]).column->getUInt(0);
|
||||
auto c = block.getByPosition(arguments[shift?1:0]).column->cloneEmpty();
|
||||
size_t m = block.getByPosition(arguments[shift?0:1]).column->getUInt(0);
|
||||
for (size_t i = 0; i < m; ++i)
|
||||
c->insertRangeFrom(*(block.getByPosition(arguments[0]).column.get()), 0, input_rows_count);
|
||||
c->insertRangeFrom(*(block.getByPosition(arguments[shift?1:0]).column.get()), 0, input_rows_count);
|
||||
block.getByPosition(result).column = std::move(c);
|
||||
return;
|
||||
}
|
||||
|
@ -1,3 +1,6 @@
|
||||
2
|
||||
0
|
||||
33
|
||||
2
|
||||
0
|
||||
18
|
||||
|
@ -1,3 +1,6 @@
|
||||
SELECT countMerge(x) AS y FROM ( SELECT countState() * 2 AS x FROM ( SELECT 1 ));
|
||||
SELECT countMerge(x) AS y FROM ( SELECT countState() * 0 AS x FROM ( SELECT 1 UNION ALL SELECT 2));
|
||||
SELECT sumMerge(y) AS z FROM ( SELECT sumState(x) * 11 AS y FROM ( SELECT 1 AS x UNION ALL SELECT 2 AS x));
|
||||
SELECT countMerge(x) AS y FROM ( SELECT 2 * countState() AS x FROM ( SELECT 1 ));
|
||||
SELECT countMerge(x) AS y FROM ( SELECT 0 * countState() AS x FROM ( SELECT 1 UNION ALL SELECT 2));
|
||||
SELECT sumMerge(y) AS z FROM ( SELECT 3 * sumState(x) * 2 AS y FROM ( SELECT 1 AS x UNION ALL SELECT 2 AS x));
|
||||
|
Loading…
Reference in New Issue
Block a user