Add an optimization that removes redundant equality

checks on boolean functions. This fixes a bug in
which the primary index is not used for queries like
SELECT * FROM <table> WHERE <pk> in (<n>) = 1
This commit is contained in:
Joshua Hildred 2024-04-02 05:24:16 -07:00
parent f36ae13f97
commit 9d4f1d890e
4 changed files with 194 additions and 0 deletions

View File

@ -19,6 +19,19 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}
using namespace std::literals;

/// Names of the functions that this pass treats as boolean-valued.
/// An `f(...) = 1` / `f(...) = 0` comparison is only simplified when `f` is listed here.
/// The IN family is listed in all of its plain / global / null-aware / negated spellings.
static constexpr std::array boolean_functions{
    /// Comparisons
    "equals"sv, "notEquals"sv, "less"sv, "greaterOrEquals"sv, "greater"sv, "lessOrEquals"sv,
    /// Set membership
    "in"sv, "notIn"sv, "globalIn"sv, "globalNotIn"sv,
    /// Null-aware set membership (the negated global form is spelled `globalNotNullIn`, mirroring `notNullIn`)
    "nullIn"sv, "notNullIn"sv, "globalNullIn"sv, "globalNotNullIn"sv,
    /// NULL checks
    "isNull"sv, "isNotNull"sv,
    /// Pattern matching
    "like"sv, "notLike"sv, "ilike"sv, "notILike"sv,
    /// Emptiness checks
    "empty"sv, "notEmpty"sv,
    /// Logical connectives
    "not"sv, "and"sv, "or"sv};
/// Returns true when `func_name` is exactly equal to one of the entries of `boolean_functions`.
static bool isBooleanFunction(const String & func_name)
{
    for (const auto candidate : boolean_functions)
    {
        if (func_name == candidate)
            return true;
    }
    return false;
}
/// Visitor that optimizes logical expressions _only_ in JOIN ON section
class JoinOnLogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithContext<JoinOnLogicalExpressionOptimizerVisitor>
{
@ -253,6 +266,12 @@ public:
tryOptimizeAndEqualsNotEqualsChain(node);
return;
}
if (function_node->getFunctionName() == "equals")
{
tryOptimizeOutRedundantEquals(node);
return;
}
}
private:
@ -552,6 +571,63 @@ private:
function_node.getArguments().getNodes() = std::move(or_operands);
function_node.resolveAsFunction(or_function_resolver);
}
/// Simplifies an `equals` node that compares a boolean function with the constant 0 or 1:
///   `f(...) = 1`  is replaced by  `f(...)`
///   `f(...) = 0`  is replaced by  `not(f(...))`
/// The constant may be on either side. `node` is left untouched unless the node has exactly two
/// arguments, one of them is a constant holding UInt64 0 or 1, and the other one is a function
/// accepted by isBooleanFunction.
void tryOptimizeOutRedundantEquals(QueryTreeNodePtr & node)
{
    auto & function_node = node->as<FunctionNode &>();
    assert(function_node.getFunctionName() == "equals");

    /// A by-value copy of the argument list: the operands stay referenced by this local vector
    /// while `node` is rebound at the end of the function.
    const auto operands = function_node.getArguments().getNodes();
    if (operands.size() != 2)
        return;

    /// The left operand is checked first; the right one is only considered if the left is not a constant.
    const ConstantNode * constant = operands[0]->as<ConstantNode>();
    const bool constant_on_left = constant != nullptr;
    if (!constant_on_left)
    {
        constant = operands[1]->as<ConstantNode>();
        if (!constant)
            return;
    }

    /// Only the literals 0 and 1 are handled.
    UInt64 literal = 0;
    if (!constant->getValue().tryGet<UInt64>(literal) || literal > 1)
        return;

    const auto & predicate = constant_on_left ? operands[1] : operands[0];
    const FunctionNode * predicate_function = predicate->as<FunctionNode>();
    if (!predicate_function || !isBooleanFunction(predicate_function->getFunctionName()))
        return;

    if (literal == 1)
    {
        /// `f(...) = 1`: the comparison is dropped, the predicate takes its place.
        node = predicate;
        return;
    }

    /// `f(...) = 0`: the comparison is dropped and the predicate is wrapped into `not`.
    auto not_resolver = FunctionFactory::instance().get("not", getContext());
    const auto not_node = std::make_shared<FunctionNode>("not");
    auto & not_arguments = not_node->getArguments().getNodes();
    not_arguments.reserve(1);
    not_arguments.push_back(predicate);
    not_node->resolveAsFunction(not_resolver->build(not_node->getArgumentColumns()));
    node = not_node;
}
};
void LogicalExpressionOptimizerPass::run(QueryTreeNodePtr & query_tree_node, ContextPtr context)

View File

@ -96,6 +96,18 @@ namespace DB
*
* SELECT * FROM t1 JOIN t2 ON a <=> b
* -------------------------------
*
* 7. Remove redundant equality checks on boolean functions.
* - these redundant checks prevent the primary index from being used when the query involves primary key columns
* -------------------------------
* SELECT * FROM t1 WHERE a IN (n) = 1
* SELECT * FROM t1 WHERE a IN (n) = 0
*
* will be transformed into
*
* SELECT * FROM t1 WHERE a IN (n)
* SELECT * FROM t1 WHERE NOT a IN (n)
* -------------------------------
*/
class LogicalExpressionOptimizerPass final : public IQueryTreePass

View File

@ -0,0 +1,23 @@
100
100
100
100
100
100
0
0
0
1
100
101
100
101
100
101
100
1
1
1
1
1
1

View File

@ -0,0 +1,83 @@
-- Test for removing redundant `= 1` / `= 0` comparisons applied to boolean functions.
-- The first part runs such queries and prints the matching rows; the second part inspects
-- EXPLAIN output for a subset of the same queries.
DROP TABLE IF EXISTS test_table;
-- Single UInt64 column that is also the sorting key.
CREATE TABLE test_table
(
k UInt64,
)
ENGINE = MergeTree
ORDER BY k;
-- Rows 0 .. 9999999.
INSERT INTO test_table SELECT number FROM numbers(10000000);
-- Boolean function compared with a constant on the right-hand side; each selects k = 100.
SELECT * FROM test_table WHERE k in (100) = 1;
SELECT * FROM test_table WHERE k = (100) = 1;
SELECT * FROM test_table WHERE k not in (100) = 0;
SELECT * FROM test_table WHERE k != (100) = 0;
-- Constant on the left-hand side; each selects k = 100.
SELECT * FROM test_table WHERE 1 = (k = 100);
SELECT * FROM test_table WHERE 0 = (k not in (100));
-- Range comparisons: the first two select k = 0, the third selects k = 0 and k = 1.
SELECT * FROM test_table WHERE k < 1 = 1;
SELECT * FROM test_table WHERE k >= 1 = 0;
SELECT * FROM test_table WHERE k > 1 = 0;
-- Nested comparisons inside OR / AND / NOT, themselves compared with a constant.
SELECT * FROM test_table WHERE ((k not in (101) = 0) OR (k in (100) = 1)) = 1;
SELECT * FROM test_table WHERE (NOT ((k not in (100) = 0) OR (k in (100) = 1))) = 0;
SELECT * FROM test_table WHERE (NOT ((k in (101) = 0) OR (k in (100) = 1))) = 1;
-- NOTE(review): the next query is identical to the first one of this group.
SELECT * FROM test_table WHERE ((k not in (101) = 0) OR (k in (100) = 1)) = 1;
SELECT * FROM test_table WHERE ((k not in (99) = 1) AND (k in (100) = 1)) = 1;
-- Each query below counts EXPLAIN lines that contain 'Granules: 1/'.
-- Presumably such a line means the primary key narrowed the read to a single granule,
-- i.e. the index was used after the rewrite -- confirm against the EXPLAIN output format.
SELECT count()
FROM
(
EXPLAIN PLAN indexes=1
SELECT * FROM test_table WHERE k in (100) = 1
)
WHERE
explain LIKE '%Granules: 1/%';
SELECT count()
FROM
(
EXPLAIN PLAN indexes=1
SELECT * FROM test_table WHERE k >= 1 = 0
)
WHERE
explain LIKE '%Granules: 1/%';
SELECT count()
FROM
(
EXPLAIN PLAN indexes=1
SELECT * FROM test_table WHERE k not in (100) = 0
)
WHERE
explain LIKE '%Granules: 1/%';
SELECT count()
FROM
(
EXPLAIN PLAN indexes=1
SELECT * FROM test_table WHERE k > 1 = 0
)
WHERE
explain LIKE '%Granules: 1/%';
SELECT count()
FROM
(
EXPLAIN PLAN indexes=1
SELECT * FROM test_table WHERE (NOT ((k not in (100) = 0) OR (k in (100) = 1))) = 0
)
WHERE
explain LIKE '%Granules: 1/%';
SELECT count()
FROM
(
EXPLAIN PLAN indexes=1
SELECT * FROM test_table WHERE (NOT ((k in (101) = 0) OR (k in (100) = 1))) = 1
)
WHERE
explain LIKE '%Granules: 1/%';
-- Clean up.
DROP TABLE test_table;