diff --git a/src/Interpreters/LogicalExpressionsOptimizer.cpp b/src/Interpreters/LogicalExpressionsOptimizer.cpp index 9e30cac2e19..bd24e13b129 100644 --- a/src/Interpreters/LogicalExpressionsOptimizer.cpp +++ b/src/Interpreters/LogicalExpressionsOptimizer.cpp @@ -1,13 +1,17 @@ #include +#include +#include #include #include #include #include +#include #include #include +#include #include @@ -32,8 +36,9 @@ bool LogicalExpressionsOptimizer::OrWithExpression::operator<(const OrWithExpres return std::tie(this->or_function, this->expression) < std::tie(rhs.or_function, rhs.expression); } -LogicalExpressionsOptimizer::LogicalExpressionsOptimizer(ASTSelectQuery * select_query_, UInt64 optimize_min_equality_disjunction_chain_length) - : select_query(select_query_), settings(optimize_min_equality_disjunction_chain_length) +LogicalExpressionsOptimizer::LogicalExpressionsOptimizer(ASTSelectQuery * select_query_, + const TablesWithColumns & tables_with_columns_, UInt64 optimize_min_equality_disjunction_chain_length) + : select_query(select_query_), tables_with_columns(tables_with_columns_), settings(optimize_min_equality_disjunction_chain_length) { } @@ -196,13 +201,39 @@ inline ASTs & getFunctionOperands(const ASTFunction * or_function) } +bool LogicalExpressionsOptimizer::isLowCardinalityEqualityChain(const std::vector & functions) const +{ + if (functions.size() > 1) + { + /// Check if identifier is LowCardinality type + auto & first_operands = getFunctionOperands(functions[0]); + const auto * identifier = first_operands[0]->as(); + if (identifier) + { + auto pos = IdentifierSemantic::getMembership(*identifier); + if (!pos) + pos = IdentifierSemantic::chooseTableColumnMatch(*identifier, tables_with_columns, true); + if (pos) + { + if (auto data_type_and_name = tables_with_columns[*pos].columns.tryGetByName(identifier->shortName())) + { + if (typeid_cast(data_type_and_name->type.get())) + return true; + } + } + } + } + return false; +} + bool LogicalExpressionsOptimizer::mayOptimizeDisjunctiveEqualityChain(const DisjunctiveEqualityChain & chain) const { const auto & equalities = chain.second; const auto & equality_functions = equalities.functions; /// We eliminate too short chains. - if (equality_functions.size() < settings.optimize_min_equality_disjunction_chain_length) + if (equality_functions.size() < settings.optimize_min_equality_disjunction_chain_length && + !isLowCardinalityEqualityChain(equality_functions)) return false; /// We check that the right-hand sides of all equalities have the same type. diff --git a/src/Interpreters/LogicalExpressionsOptimizer.h b/src/Interpreters/LogicalExpressionsOptimizer.h index 4991d31f8b1..a8a0d186394 100644 --- a/src/Interpreters/LogicalExpressionsOptimizer.h +++ b/src/Interpreters/LogicalExpressionsOptimizer.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -36,7 +37,7 @@ class LogicalExpressionsOptimizer final public: /// Constructor. Accepts the root of the query DAG. - LogicalExpressionsOptimizer(ASTSelectQuery * select_query_, UInt64 optimize_min_equality_disjunction_chain_length); + LogicalExpressionsOptimizer(ASTSelectQuery * select_query_, const TablesWithColumns & tables_with_columns_, UInt64 optimize_min_equality_disjunction_chain_length); /** Replace all rather long homogeneous OR-chains expr = x1 OR ... OR expr = xN * on the expressions `expr` IN (x1, ..., xN). @@ -79,6 +80,9 @@ private: */ bool mayOptimizeDisjunctiveEqualityChain(const DisjunctiveEqualityChain & chain) const; + /// Check if is LowCardinality OR chain + bool isLowCardinalityEqualityChain(const std::vector & functions) const; + /// Insert the IN expression into the OR chain. static void addInExpression(const DisjunctiveEqualityChain & chain); @@ -96,6 +100,7 @@ private: using ColumnToPosition = std::unordered_map; ASTSelectQuery * select_query; + const TablesWithColumns & tables_with_columns; const ExtractedSettings settings; /// Information about the OR-chains inside the query. DisjunctiveEqualityChainsMap disjunctive_equality_chains_map; diff --git a/src/Interpreters/TreeRewriter.cpp b/src/Interpreters/TreeRewriter.cpp index eb713019306..c61ba9c3286 100644 --- a/src/Interpreters/TreeRewriter.cpp +++ b/src/Interpreters/TreeRewriter.cpp @@ -1246,7 +1246,7 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect( translateQualifiedNames(query, *select_query, source_columns_set, tables_with_columns); /// Optimizes logical expressions. - LogicalExpressionsOptimizer(select_query, settings.optimize_min_equality_disjunction_chain_length.value).perform(); + LogicalExpressionsOptimizer(select_query, tables_with_columns, settings.optimize_min_equality_disjunction_chain_length.value).perform(); NameSet all_source_columns_set = source_columns_set; if (table_join)