fix cores

This commit is contained in:
taiyang-li 2023-01-29 19:48:46 +08:00
parent 3ee792f4e2
commit eeef2dae77
4 changed files with 128 additions and 7 deletions

View File

@ -72,6 +72,9 @@ class IQueryTreeNode : public TypePromotion<IQueryTreeNode>
public:
virtual ~IQueryTreeNode() = default;
IQueryTreeNode & operator=(const IQueryTreeNode &) = default;
IQueryTreeNode & operator=(IQueryTreeNode &&) = default;
/// Get query tree node type
virtual QueryTreeNodeType getNodeType() const = 0;

View File

@ -58,13 +58,20 @@ public:
if (!isInt64OrUInt64FieldType(constant_value_literal.getType()))
return;
if (constant_value_literal.get<UInt64>() != 1 || getSettings().aggregate_functions_null_for_empty)
if (getSettings().aggregate_functions_null_for_empty)
return;
/// Rewrite `sumIf(1, cond)` into `countIf(cond)`
auto multiplier_node = function_node_arguments_nodes[0];
function_node_arguments_nodes[0] = std::move(function_node_arguments_nodes[1]);
function_node_arguments_nodes.resize(1);
resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType());
if (constant_value_literal.get<UInt64>() != 1)
{
/// Rewrite `sumIf(123, cond)` into `123 * countIf(cond)`
node = getMultiplyFunction(std::move(multiplier_node), node);
}
return;
}
@ -79,7 +86,7 @@ public:
if (!nested_function || nested_function->getFunctionName() != "if")
return;
const auto & nested_if_function_arguments_nodes = nested_function->getArguments().getNodes();
const auto nested_if_function_arguments_nodes = nested_function->getArguments().getNodes();
if (nested_if_function_arguments_nodes.size() != 3)
return;
@ -100,19 +107,25 @@ public:
auto if_true_condition_value = if_true_condition_constant_value_literal.get<UInt64>();
auto if_false_condition_value = if_false_condition_constant_value_literal.get<UInt64>();
/// Rewrite `sum(if(cond, 1, 0))` into `countIf(cond)`.
if (if_true_condition_value == 1 && if_false_condition_value == 0)
if (if_false_condition_value == 0)
{
/// Rewrite `sum(if(cond, 1, 0))` into `countIf(cond)`.
function_node_arguments_nodes[0] = nested_if_function_arguments_nodes[0];
function_node_arguments_nodes.resize(1);
resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType());
if (if_true_condition_value != 1)
{
/// Rewrite `sum(if(cond, 123, 0))` into `123 * countIf(cond)`.
node = getMultiplyFunction(nested_if_function_arguments_nodes[1], node);
}
return;
}
/// Rewrite `sum(if(cond, 0, 1))` into `countIf(not(cond))` if condition is not Nullable (otherwise the result can be different).
if (if_true_condition_value == 0 && if_false_condition_value == 1 && !cond_argument->getResultType()->isNullable())
if (if_true_condition_value == 0 && !cond_argument->getResultType()->isNullable())
{
/// Rewrite `sum(if(cond, 0, 1))` into `countIf(not(cond))` if condition is not Nullable (otherwise the result can be different).
DataTypePtr not_function_result_type = std::make_shared<DataTypeUInt8>();
const auto & condition_result_type = nested_if_function_arguments_nodes[0]->getResultType();
@ -130,6 +143,12 @@ public:
function_node_arguments_nodes.resize(1);
resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType());
if (if_false_condition_value != 1)
{
/// Rewrite `sum(if(cond, 0, 123))` into `123 * countIf(not(cond))` if condition is not Nullable (otherwise the result can be different).
node = getMultiplyFunction(nested_if_function_arguments_nodes[2], node);
}
return;
}
}
@ -145,6 +164,18 @@ private:
function_node.resolveAsAggregateFunction(std::move(aggregate_function));
}
inline QueryTreeNodePtr getMultiplyFunction(QueryTreeNodePtr left, QueryTreeNodePtr right)
{
auto multiply_function_node = std::make_shared<FunctionNode>("multiply");
auto & multiply_arguments_nodes = multiply_function_node->getArguments().getNodes();
multiply_arguments_nodes.push_back(std::move(left));
multiply_arguments_nodes.push_back(std::move(right));
auto multiply_function_base = FunctionFactory::instance().get("multiply", getContext())->build(multiply_function_node->getArgumentColumns());
multiply_function_node->resolveAsFunction(std::move(multiply_function_base));
return std::move(multiply_function_node);
}
};
}

View File

@ -34,3 +34,84 @@ SELECT 123 * countIf((number % 2) = 0)
FROM numbers(100)
SELECT 123 * countIf(NOT ((number % 2) = 0))
FROM numbers(100)
QUERY id: 0
PROJECTION COLUMNS
sumIf(123, equals(modulo(number, 2), 0)) UInt64
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: multiply, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 3, nodes: 2
CONSTANT id: 4, constant_value: UInt64_123, constant_value_type: UInt8
FUNCTION id: 5, function_name: countIf, function_type: aggregate, result_type: UInt64
ARGUMENTS
LIST id: 6, nodes: 1
FUNCTION id: 7, function_name: equals, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 8, nodes: 2
FUNCTION id: 9, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 10, nodes: 2
COLUMN id: 11, column_name: number, result_type: UInt64, source_id: 12
CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8
CONSTANT id: 14, constant_value: UInt64_0, constant_value_type: UInt8
JOIN TREE
TABLE_FUNCTION id: 12, table_function_name: numbers
ARGUMENTS
LIST id: 15, nodes: 1
CONSTANT id: 16, constant_value: UInt64_100, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
sum(if(equals(modulo(number, 2), 0), 123, 0)) UInt64
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: multiply, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 3, nodes: 2
CONSTANT id: 4, constant_value: UInt64_123, constant_value_type: UInt8
FUNCTION id: 5, function_name: countIf, function_type: aggregate, result_type: UInt64
ARGUMENTS
LIST id: 6, nodes: 1
FUNCTION id: 7, function_name: equals, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 8, nodes: 2
FUNCTION id: 9, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 10, nodes: 2
COLUMN id: 11, column_name: number, result_type: UInt64, source_id: 12
CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8
CONSTANT id: 14, constant_value: UInt64_0, constant_value_type: UInt8
JOIN TREE
TABLE_FUNCTION id: 12, table_function_name: numbers
ARGUMENTS
LIST id: 15, nodes: 1
CONSTANT id: 16, constant_value: UInt64_100, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
sum(if(equals(modulo(number, 2), 0), 0, 123)) UInt64
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: multiply, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 3, nodes: 2
CONSTANT id: 4, constant_value: UInt64_123, constant_value_type: UInt8
FUNCTION id: 5, function_name: countIf, function_type: aggregate, result_type: UInt64
ARGUMENTS
LIST id: 6, nodes: 1
FUNCTION id: 7, function_name: not, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 8, nodes: 1
FUNCTION id: 9, function_name: equals, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 10, nodes: 2
FUNCTION id: 11, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 12, nodes: 2
COLUMN id: 13, column_name: number, result_type: UInt64, source_id: 14
CONSTANT id: 15, constant_value: UInt64_2, constant_value_type: UInt8
CONSTANT id: 16, constant_value: UInt64_0, constant_value_type: UInt8
JOIN TREE
TABLE_FUNCTION id: 14, table_function_name: numbers
ARGUMENTS
LIST id: 17, nodes: 1
CONSTANT id: 18, constant_value: UInt64_100, constant_value_type: UInt8

View File

@ -41,3 +41,9 @@ SELECT countIf(number % 2 != 0) FROM numbers(100);
EXPLAIN SYNTAX SELECT sumIf(123, number % 2 == 0) FROM numbers(100);
EXPLAIN SYNTAX SELECT sum(if(number % 2 == 0, 123, 0)) FROM numbers(100);
EXPLAIN SYNTAX SELECT sum(if(number % 2 == 0, 0, 123)) FROM numbers(100);
set allow_experimental_analyzer = true;
EXPLAIN QUERY TREE run_passes=1 SELECT sumIf(123, number % 2 == 0) FROM numbers(100);
EXPLAIN QUERY TREE run_passes=1 SELECT sum(if(number % 2 == 0, 123, 0)) FROM numbers(100);
EXPLAIN QUERY TREE run_passes=1 SELECT sum(if(number % 2 == 0, 0, 123)) FROM numbers(100);