diff --git a/dbms/include/DB/Interpreters/ExpressionAnalyzer.h b/dbms/include/DB/Interpreters/ExpressionAnalyzer.h index 65bb03c31d0..5d2c326566a 100644 --- a/dbms/include/DB/Interpreters/ExpressionAnalyzer.h +++ b/dbms/include/DB/Interpreters/ExpressionAnalyzer.h @@ -230,6 +230,11 @@ private: /// Удалить из ORDER BY повторяющиеся элементы. void optimizeOrderBy(); + /// remove Function_if AST if condition is constant + void optimizeIfWithConstantCondition(); + void optimizeIfWithConstantConditionImpl(ASTPtr & current_ast) const; + bool tryExtractConstValueFromCondition(const ASTPtr & condition, bool & value) const; + /// Превратить перечисление значений или подзапрос в ASTSet. node - функция in или notIn. void makeSet(ASTFunction * node, const Block & sample_block); diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.cpp b/dbms/src/Interpreters/ExpressionAnalyzer.cpp index b7614e7de70..f10352fcfca 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.cpp +++ b/dbms/src/Interpreters/ExpressionAnalyzer.cpp @@ -169,6 +169,9 @@ void ExpressionAnalyzer::init() /// Выполнение скалярных подзапросов - замена их на значения-константы. executeScalarSubqueries(); + /// Optimize if with constant condition after constats are substituted instead of sclalar subqueries + optimizeIfWithConstantCondition(); + /// GROUP BY injective function elimination. optimizeGroupBy(); @@ -195,6 +198,77 @@ void ExpressionAnalyzer::init() analyzeAggregation(); } +void ExpressionAnalyzer::optimizeIfWithConstantCondition() +{ + optimizeIfWithConstantConditionImpl(ast); +} + +bool ExpressionAnalyzer::tryExtractConstValueFromCondition(const ASTPtr & condition, bool & value) const +{ + /// numeric constant in condition + if (const ASTLiteral * literal = typeid_cast(condition.get())) + { + if (literal->value.getType() == Field::Types::Int64 || + literal->value.getType() == Field::Types::UInt64) + { + value = literal->value.get(); + return true; + } + } + + /// cast of numeric constant in condition to UInt8 + if (const ASTFunction * function = typeid_cast(condition.get())) + { + if (function->name == FunctionCast::name) + { + if (ASTExpressionList * expr_list = typeid_cast(function->arguments.get())) + { + const ASTPtr & type_ast = expr_list->children.at(1); + if (const ASTLiteral * type_literal = typeid_cast(type_ast.get())) + { + if (type_literal->value.getType() == Field::Types::String && + type_literal->value.get() == "UInt8") + return tryExtractConstValueFromCondition(expr_list->children.at(0), value); + } + } + } + } + + return false; +} + +void ExpressionAnalyzer::optimizeIfWithConstantConditionImpl(ASTPtr & current_ast) const +{ + if (!current_ast) + return; + + for (ASTPtr & child : current_ast->children) + { + ASTFunction * function_node = typeid_cast(child.get()); + if (!function_node || function_node->name != FunctionIf::name) + { + optimizeIfWithConstantConditionImpl(child); + continue; + } + + optimizeIfWithConstantConditionImpl(function_node->arguments); + ASTExpressionList * args = typeid_cast(function_node->arguments.get()); + + ASTPtr condition_expr = args->children.at(0); + ASTPtr then_expr = args->children.at(1); + ASTPtr else_expr = args->children.at(2); + + + bool condition; + if (tryExtractConstValueFromCondition(condition_expr, condition)) + { + if (condition) + child = then_expr; + else + child = else_expr; + } + } +} void ExpressionAnalyzer::analyzeAggregation() { diff --git a/dbms/tests/queries/0_stateless/00393_if_with_constant_condition.reference b/dbms/tests/queries/0_stateless/00393_if_with_constant_condition.reference new file mode 100644 index 00000000000..b9a713a016c --- /dev/null +++ b/dbms/tests/queries/0_stateless/00393_if_with_constant_condition.reference @@ -0,0 +1,4 @@ +1 +1 +2 +42 diff --git a/dbms/tests/queries/0_stateless/00393_if_with_constant_condition.sql b/dbms/tests/queries/0_stateless/00393_if_with_constant_condition.sql new file mode 100644 index 00000000000..5efff085b16 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00393_if_with_constant_condition.sql @@ -0,0 +1,5 @@ +SELECT 1 ? 1 : 0; +SELECT 0 ? not_existing_column : 1 FROM system.numbers LIMIT 1; +SELECT if(1, if(0, not_existing_column, 2), 0) FROM system.numbers LIMIT 1; + +SELECT (SELECT hasColumnInTable('system', 'numbers', 'not_existing')) ? not_existing : 42 FROM system.numbers LIMIT 1;