mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-27 10:02:01 +00:00
Merge pull request #43543 from kitaisreal/analyzer-sum-if-to-count-if-fix
Analyzer SumIfToCountIfPass fix
This commit is contained in:
commit
c46a659ad9
@ -61,7 +61,7 @@ public:
|
|||||||
function_node_arguments_nodes[0] = std::move(function_node_arguments_nodes[1]);
|
function_node_arguments_nodes[0] = std::move(function_node_arguments_nodes[1]);
|
||||||
function_node_arguments_nodes.resize(1);
|
function_node_arguments_nodes.resize(1);
|
||||||
|
|
||||||
resolveAggregateFunctionNode(*function_node, "countIf");
|
resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,15 +102,16 @@ public:
|
|||||||
function_node_arguments_nodes[0] = std::move(nested_if_function_arguments_nodes[0]);
|
function_node_arguments_nodes[0] = std::move(nested_if_function_arguments_nodes[0]);
|
||||||
function_node_arguments_nodes.resize(1);
|
function_node_arguments_nodes.resize(1);
|
||||||
|
|
||||||
resolveAggregateFunctionNode(*function_node, "countIf");
|
resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Rewrite `sum(if(cond, 0, 1))` into `countIf(not(cond))`.
|
/// Rewrite `sum(if(cond, 0, 1))` into `countIf(not(cond))`.
|
||||||
if (if_true_condition_value == 0 && if_false_condition_value == 1)
|
if (if_true_condition_value == 0 && if_false_condition_value == 1)
|
||||||
{
|
{
|
||||||
auto condition_result_type = nested_if_function_arguments_nodes[0]->getResultType();
|
|
||||||
DataTypePtr not_function_result_type = std::make_shared<DataTypeUInt8>();
|
DataTypePtr not_function_result_type = std::make_shared<DataTypeUInt8>();
|
||||||
|
|
||||||
|
const auto & condition_result_type = nested_if_function_arguments_nodes[0]->getResultType();
|
||||||
if (condition_result_type->isNullable())
|
if (condition_result_type->isNullable())
|
||||||
not_function_result_type = makeNullable(not_function_result_type);
|
not_function_result_type = makeNullable(not_function_result_type);
|
||||||
|
|
||||||
@ -123,23 +124,21 @@ public:
|
|||||||
function_node_arguments_nodes[0] = std::move(not_function);
|
function_node_arguments_nodes[0] = std::move(not_function);
|
||||||
function_node_arguments_nodes.resize(1);
|
function_node_arguments_nodes.resize(1);
|
||||||
|
|
||||||
resolveAggregateFunctionNode(*function_node, "countIf");
|
resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static inline void resolveAggregateFunctionNode(FunctionNode & function_node, const String & aggregate_function_name)
|
static inline void resolveAsCountIfAggregateFunction(FunctionNode & function_node, const DataTypePtr & argument_type)
|
||||||
{
|
{
|
||||||
auto function_result_type = function_node.getResultType();
|
|
||||||
auto function_aggregate_function = function_node.getAggregateFunction();
|
|
||||||
|
|
||||||
AggregateFunctionProperties properties;
|
AggregateFunctionProperties properties;
|
||||||
auto aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name,
|
auto aggregate_function = AggregateFunctionFactory::instance().get("countIf",
|
||||||
function_aggregate_function->getArgumentTypes(),
|
{argument_type},
|
||||||
function_aggregate_function->getParameters(),
|
function_node.getAggregateFunction()->getParameters(),
|
||||||
properties);
|
properties);
|
||||||
|
|
||||||
|
auto function_result_type = function_node.getResultType();
|
||||||
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
|
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,77 @@
|
|||||||
|
QUERY id: 0
|
||||||
|
PROJECTION COLUMNS
|
||||||
|
sumIf(1, equals(modulo(number, 2), 0)) UInt64
|
||||||
|
PROJECTION
|
||||||
|
LIST id: 1, nodes: 1
|
||||||
|
FUNCTION id: 2, function_name: countIf, function_type: aggregate, result_type: UInt64
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 3, nodes: 1
|
||||||
|
FUNCTION id: 4, function_name: equals, function_type: ordinary, result_type: UInt8
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 5, nodes: 2
|
||||||
|
FUNCTION id: 6, function_name: modulo, function_type: ordinary, result_type: UInt8
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 7, nodes: 2
|
||||||
|
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
|
||||||
|
CONSTANT id: 10, constant_value: UInt64_2, constant_value_type: UInt8
|
||||||
|
CONSTANT id: 11, constant_value: UInt64_0, constant_value_type: UInt8
|
||||||
|
JOIN TREE
|
||||||
|
TABLE_FUNCTION id: 9, table_function_name: numbers
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 12, nodes: 1
|
||||||
|
CONSTANT id: 13, constant_value: UInt64_10, constant_value_type: UInt8
|
||||||
|
--
|
||||||
|
5
|
||||||
|
--
|
||||||
|
QUERY id: 0
|
||||||
|
PROJECTION COLUMNS
|
||||||
|
sum(if(equals(modulo(number, 2), 0), 1, 0)) UInt64
|
||||||
|
PROJECTION
|
||||||
|
LIST id: 1, nodes: 1
|
||||||
|
FUNCTION id: 2, function_name: countIf, function_type: aggregate, result_type: UInt64
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 3, nodes: 1
|
||||||
|
FUNCTION id: 4, function_name: equals, function_type: ordinary, result_type: UInt8
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 5, nodes: 2
|
||||||
|
FUNCTION id: 6, function_name: modulo, function_type: ordinary, result_type: UInt8
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 7, nodes: 2
|
||||||
|
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
|
||||||
|
CONSTANT id: 10, constant_value: UInt64_2, constant_value_type: UInt8
|
||||||
|
CONSTANT id: 11, constant_value: UInt64_0, constant_value_type: UInt8
|
||||||
|
JOIN TREE
|
||||||
|
TABLE_FUNCTION id: 9, table_function_name: numbers
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 12, nodes: 1
|
||||||
|
CONSTANT id: 13, constant_value: UInt64_10, constant_value_type: UInt8
|
||||||
|
--
|
||||||
|
5
|
||||||
|
--
|
||||||
|
QUERY id: 0
|
||||||
|
PROJECTION COLUMNS
|
||||||
|
sum(if(equals(modulo(number, 2), 0), 0, 1)) UInt64
|
||||||
|
PROJECTION
|
||||||
|
LIST id: 1, nodes: 1
|
||||||
|
FUNCTION id: 2, function_name: countIf, function_type: aggregate, result_type: UInt64
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 3, nodes: 1
|
||||||
|
FUNCTION id: 4, function_name: not, function_type: ordinary, result_type: UInt8
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 5, nodes: 1
|
||||||
|
FUNCTION id: 6, function_name: equals, function_type: ordinary, result_type: UInt8
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 7, nodes: 2
|
||||||
|
FUNCTION id: 8, function_name: modulo, function_type: ordinary, result_type: UInt8
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 9, nodes: 2
|
||||||
|
COLUMN id: 10, column_name: number, result_type: UInt64, source_id: 11
|
||||||
|
CONSTANT id: 12, constant_value: UInt64_2, constant_value_type: UInt8
|
||||||
|
CONSTANT id: 13, constant_value: UInt64_0, constant_value_type: UInt8
|
||||||
|
JOIN TREE
|
||||||
|
TABLE_FUNCTION id: 11, table_function_name: numbers
|
||||||
|
ARGUMENTS
|
||||||
|
LIST id: 14, nodes: 1
|
||||||
|
CONSTANT id: 15, constant_value: UInt64_10, constant_value_type: UInt8
|
||||||
|
--
|
||||||
|
5
|
@ -0,0 +1,24 @@
|
|||||||
|
SET allow_experimental_analyzer = 1;
|
||||||
|
SET optimize_rewrite_sum_if_to_count_if = 1;
|
||||||
|
|
||||||
|
EXPLAIN QUERY TREE (SELECT sumIf(1, (number % 2) == 0) FROM numbers(10));
|
||||||
|
|
||||||
|
SELECT '--';
|
||||||
|
|
||||||
|
SELECT sumIf(1, (number % 2) == 0) FROM numbers(10);
|
||||||
|
|
||||||
|
SELECT '--';
|
||||||
|
|
||||||
|
EXPLAIN QUERY TREE (SELECT sum(if((number % 2) == 0, 1, 0)) FROM numbers(10));
|
||||||
|
|
||||||
|
SELECT '--';
|
||||||
|
|
||||||
|
SELECT sum(if((number % 2) == 0, 1, 0)) FROM numbers(10);
|
||||||
|
|
||||||
|
SELECT '--';
|
||||||
|
|
||||||
|
EXPLAIN QUERY TREE (SELECT sum(if((number % 2) == 0, 0, 1)) FROM numbers(10));
|
||||||
|
|
||||||
|
SELECT '--';
|
||||||
|
|
||||||
|
SELECT sum(if((number % 2) == 0, 0, 1)) FROM numbers(10);
|
Loading…
Reference in New Issue
Block a user