mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-23 08:02:02 +00:00
Merge pull request #43543 from kitaisreal/analyzer-sum-if-to-count-if-fix
Analyzer SumIfToCountIfPass fix
This commit is contained in:
commit
c46a659ad9
@ -61,7 +61,7 @@ public:
|
||||
function_node_arguments_nodes[0] = std::move(function_node_arguments_nodes[1]);
|
||||
function_node_arguments_nodes.resize(1);
|
||||
|
||||
resolveAggregateFunctionNode(*function_node, "countIf");
|
||||
resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType());
|
||||
return;
|
||||
}
|
||||
|
||||
@ -102,15 +102,16 @@ public:
|
||||
function_node_arguments_nodes[0] = std::move(nested_if_function_arguments_nodes[0]);
|
||||
function_node_arguments_nodes.resize(1);
|
||||
|
||||
resolveAggregateFunctionNode(*function_node, "countIf");
|
||||
resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType());
|
||||
return;
|
||||
}
|
||||
|
||||
/// Rewrite `sum(if(cond, 0, 1))` into `countIf(not(cond))`.
|
||||
if (if_true_condition_value == 0 && if_false_condition_value == 1)
|
||||
{
|
||||
auto condition_result_type = nested_if_function_arguments_nodes[0]->getResultType();
|
||||
DataTypePtr not_function_result_type = std::make_shared<DataTypeUInt8>();
|
||||
|
||||
const auto & condition_result_type = nested_if_function_arguments_nodes[0]->getResultType();
|
||||
if (condition_result_type->isNullable())
|
||||
not_function_result_type = makeNullable(not_function_result_type);
|
||||
|
||||
@ -123,23 +124,21 @@ public:
|
||||
function_node_arguments_nodes[0] = std::move(not_function);
|
||||
function_node_arguments_nodes.resize(1);
|
||||
|
||||
resolveAggregateFunctionNode(*function_node, "countIf");
|
||||
resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static inline void resolveAggregateFunctionNode(FunctionNode & function_node, const String & aggregate_function_name)
|
||||
static inline void resolveAsCountIfAggregateFunction(FunctionNode & function_node, const DataTypePtr & argument_type)
|
||||
{
|
||||
auto function_result_type = function_node.getResultType();
|
||||
auto function_aggregate_function = function_node.getAggregateFunction();
|
||||
|
||||
AggregateFunctionProperties properties;
|
||||
auto aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name,
|
||||
function_aggregate_function->getArgumentTypes(),
|
||||
function_aggregate_function->getParameters(),
|
||||
auto aggregate_function = AggregateFunctionFactory::instance().get("countIf",
|
||||
{argument_type},
|
||||
function_node.getAggregateFunction()->getParameters(),
|
||||
properties);
|
||||
|
||||
auto function_result_type = function_node.getResultType();
|
||||
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,77 @@
|
||||
QUERY id: 0
|
||||
PROJECTION COLUMNS
|
||||
sumIf(1, equals(modulo(number, 2), 0)) UInt64
|
||||
PROJECTION
|
||||
LIST id: 1, nodes: 1
|
||||
FUNCTION id: 2, function_name: countIf, function_type: aggregate, result_type: UInt64
|
||||
ARGUMENTS
|
||||
LIST id: 3, nodes: 1
|
||||
FUNCTION id: 4, function_name: equals, function_type: ordinary, result_type: UInt8
|
||||
ARGUMENTS
|
||||
LIST id: 5, nodes: 2
|
||||
FUNCTION id: 6, function_name: modulo, function_type: ordinary, result_type: UInt8
|
||||
ARGUMENTS
|
||||
LIST id: 7, nodes: 2
|
||||
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
|
||||
CONSTANT id: 10, constant_value: UInt64_2, constant_value_type: UInt8
|
||||
CONSTANT id: 11, constant_value: UInt64_0, constant_value_type: UInt8
|
||||
JOIN TREE
|
||||
TABLE_FUNCTION id: 9, table_function_name: numbers
|
||||
ARGUMENTS
|
||||
LIST id: 12, nodes: 1
|
||||
CONSTANT id: 13, constant_value: UInt64_10, constant_value_type: UInt8
|
||||
--
|
||||
5
|
||||
--
|
||||
QUERY id: 0
|
||||
PROJECTION COLUMNS
|
||||
sum(if(equals(modulo(number, 2), 0), 1, 0)) UInt64
|
||||
PROJECTION
|
||||
LIST id: 1, nodes: 1
|
||||
FUNCTION id: 2, function_name: countIf, function_type: aggregate, result_type: UInt64
|
||||
ARGUMENTS
|
||||
LIST id: 3, nodes: 1
|
||||
FUNCTION id: 4, function_name: equals, function_type: ordinary, result_type: UInt8
|
||||
ARGUMENTS
|
||||
LIST id: 5, nodes: 2
|
||||
FUNCTION id: 6, function_name: modulo, function_type: ordinary, result_type: UInt8
|
||||
ARGUMENTS
|
||||
LIST id: 7, nodes: 2
|
||||
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
|
||||
CONSTANT id: 10, constant_value: UInt64_2, constant_value_type: UInt8
|
||||
CONSTANT id: 11, constant_value: UInt64_0, constant_value_type: UInt8
|
||||
JOIN TREE
|
||||
TABLE_FUNCTION id: 9, table_function_name: numbers
|
||||
ARGUMENTS
|
||||
LIST id: 12, nodes: 1
|
||||
CONSTANT id: 13, constant_value: UInt64_10, constant_value_type: UInt8
|
||||
--
|
||||
5
|
||||
--
|
||||
QUERY id: 0
|
||||
PROJECTION COLUMNS
|
||||
sum(if(equals(modulo(number, 2), 0), 0, 1)) UInt64
|
||||
PROJECTION
|
||||
LIST id: 1, nodes: 1
|
||||
FUNCTION id: 2, function_name: countIf, function_type: aggregate, result_type: UInt64
|
||||
ARGUMENTS
|
||||
LIST id: 3, nodes: 1
|
||||
FUNCTION id: 4, function_name: not, function_type: ordinary, result_type: UInt8
|
||||
ARGUMENTS
|
||||
LIST id: 5, nodes: 1
|
||||
FUNCTION id: 6, function_name: equals, function_type: ordinary, result_type: UInt8
|
||||
ARGUMENTS
|
||||
LIST id: 7, nodes: 2
|
||||
FUNCTION id: 8, function_name: modulo, function_type: ordinary, result_type: UInt8
|
||||
ARGUMENTS
|
||||
LIST id: 9, nodes: 2
|
||||
COLUMN id: 10, column_name: number, result_type: UInt64, source_id: 11
|
||||
CONSTANT id: 12, constant_value: UInt64_2, constant_value_type: UInt8
|
||||
CONSTANT id: 13, constant_value: UInt64_0, constant_value_type: UInt8
|
||||
JOIN TREE
|
||||
TABLE_FUNCTION id: 11, table_function_name: numbers
|
||||
ARGUMENTS
|
||||
LIST id: 14, nodes: 1
|
||||
CONSTANT id: 15, constant_value: UInt64_10, constant_value_type: UInt8
|
||||
--
|
||||
5
|
@ -0,0 +1,24 @@
|
||||
SET allow_experimental_analyzer = 1;
|
||||
SET optimize_rewrite_sum_if_to_count_if = 1;
|
||||
|
||||
EXPLAIN QUERY TREE (SELECT sumIf(1, (number % 2) == 0) FROM numbers(10));
|
||||
|
||||
SELECT '--';
|
||||
|
||||
SELECT sumIf(1, (number % 2) == 0) FROM numbers(10);
|
||||
|
||||
SELECT '--';
|
||||
|
||||
EXPLAIN QUERY TREE (SELECT sum(if((number % 2) == 0, 1, 0)) FROM numbers(10));
|
||||
|
||||
SELECT '--';
|
||||
|
||||
SELECT sum(if((number % 2) == 0, 1, 0)) FROM numbers(10);
|
||||
|
||||
SELECT '--';
|
||||
|
||||
EXPLAIN QUERY TREE (SELECT sum(if((number % 2) == 0, 0, 1)) FROM numbers(10));
|
||||
|
||||
SELECT '--';
|
||||
|
||||
SELECT sum(if((number % 2) == 0, 0, 1)) FROM numbers(10);
|
Loading…
Reference in New Issue
Block a user