From 69559a4fd9f62488532d5ba319f0ee5be38986b4 Mon Sep 17 00:00:00 2001 From: Anton Popov Date: Thu, 18 Nov 2021 17:24:06 +0300 Subject: [PATCH] fix convertion to CNF --- src/Interpreters/TreeCNFConverter.cpp | 85 ++++++++++++++++--- src/Interpreters/TreeCNFConverter.h | 13 ++- src/Interpreters/TreeOptimizer.cpp | 18 ++-- ..._constraints_simple_optimization.reference | 8 +- .../01623_constraints_column_swap.reference | 2 +- .../0_stateless/01626_cnf_fuzz_long.python | 73 ++++++++++++++++ .../0_stateless/01626_cnf_fuzz_long.reference | 1 + .../0_stateless/01626_cnf_fuzz_long.sh | 10 +++ .../0_stateless/01626_cnf_test.reference | 8 +- tests/queries/0_stateless/01626_cnf_test.sql | 6 +- 10 files changed, 194 insertions(+), 30 deletions(-) create mode 100644 tests/queries/0_stateless/01626_cnf_fuzz_long.python create mode 100644 tests/queries/0_stateless/01626_cnf_fuzz_long.reference create mode 100755 tests/queries/0_stateless/01626_cnf_fuzz_long.sh diff --git a/src/Interpreters/TreeCNFConverter.cpp b/src/Interpreters/TreeCNFConverter.cpp index 46002f5be99..a6b46c46589 100644 --- a/src/Interpreters/TreeCNFConverter.cpp +++ b/src/Interpreters/TreeCNFConverter.cpp @@ -1,7 +1,9 @@ #include #include #include -#include +#include +#include + namespace DB { @@ -10,14 +12,37 @@ namespace ErrorCodes { extern const int LOGICAL_ERROR; extern const int INCORRECT_QUERY; + extern const int TOO_MANY_TEMPORARY_COLUMNS; } namespace { +bool isLogicalFunction(const ASTFunction & func) +{ + return func.name == "and" || func.name == "or" || func.name == "not"; +} + +size_t countAtoms(const ASTPtr & node) +{ + checkStackSize(); + if (node->as()) + return 1; + + const auto * func = node->as(); + if (func && !isLogicalFunction(*func)) + return 1; + + size_t num_atoms = 0; + for (const auto & child : node->children) + num_atoms += countAtoms(child); + return num_atoms; +} + /// Splits AND(a, b, c) to AND(a, AND(b, c)) for AND/OR void splitMultiLogic(ASTPtr & node) { + checkStackSize(); auto * 
func = node->as(); if (func && (func->name == "and" || func->name == "or")) @@ -29,9 +54,8 @@ void splitMultiLogic(ASTPtr & node) { ASTPtr res = func->arguments->children[0]->clone(); for (size_t i = 1; i < func->arguments->children.size(); ++i) - { res = makeASTFunction(func->name, res, func->arguments->children[i]->clone()); - } + node = res; } @@ -49,6 +73,7 @@ void splitMultiLogic(ASTPtr & node) /// Push NOT to leafs, remove NOT NOT ... void traversePushNot(ASTPtr & node, bool add_negation) { + checkStackSize(); auto * func = node->as(); if (func && (func->name == "and" || func->name == "or")) @@ -86,14 +111,19 @@ void traversePushNot(ASTPtr & node, bool add_negation) } /// Push Or inside And (actually pull AND to top) -void traversePushOr(ASTPtr & node) +bool traversePushOr(ASTPtr & node, size_t num_atoms, size_t max_atoms) { + if (max_atoms && num_atoms > max_atoms) + return false; + + checkStackSize(); auto * func = node->as(); if (func && (func->name == "or" || func->name == "and")) { for (auto & child : func->arguments->children) - traversePushOr(child); + if (!traversePushOr(child, num_atoms, max_atoms)) + return false; } if (func && func->name == "or") @@ -105,15 +135,15 @@ void traversePushOr(ASTPtr & node) auto & child = func->arguments->children[i]; auto * and_func = child->as(); if (and_func && and_func->name == "and") - { and_node_id = i; - } } - if (and_node_id == func->arguments->children.size()) - return; - const size_t other_node_id = 1 - and_node_id; + if (and_node_id == func->arguments->children.size()) + return true; + + const size_t other_node_id = 1 - and_node_id; const auto * and_func = func->arguments->children[and_node_id]->as(); + auto a = func->arguments->children[other_node_id]; auto b = and_func->arguments->children[0]; auto c = and_func->arguments->children[1]; @@ -124,13 +154,19 @@ void traversePushOr(ASTPtr & node) makeASTFunction("or", a->clone(), b), makeASTFunction("or", a, c)); - traversePushOr(node); + /// Count all atoms 
from 'a', because it was cloned. + num_atoms += countAtoms(a); + return traversePushOr(node, num_atoms, max_atoms); } + + return true; } /// transform ast into cnf groups void traverseCNF(const ASTPtr & node, CNFQuery::AndGroup & and_group, CNFQuery::OrGroup & or_group) { + checkStackSize(); + auto * func = node->as(); if (func && func->name == "and") { @@ -171,13 +207,22 @@ void traverseCNF(const ASTPtr & node, CNFQuery::AndGroup & result) } -CNFQuery TreeCNFConverter::toCNF(const ASTPtr & query) +std::optional TreeCNFConverter::tryConvertToCNF( + const ASTPtr & query, size_t max_growth_multipler) { auto cnf = query->clone(); + size_t num_atoms = countAtoms(cnf); splitMultiLogic(cnf); traversePushNot(cnf, false); - traversePushOr(cnf); + + size_t max_atoms = max_growth_multipler + ? std::max(MAX_ATOMS_WITHOUT_CHECK, num_atoms * max_growth_multipler) + : 0; + + if (!traversePushOr(cnf, num_atoms, max_atoms)) + return {}; + CNFQuery::AndGroup and_group; traverseCNF(cnf, and_group); @@ -186,6 +231,18 @@ CNFQuery TreeCNFConverter::toCNF(const ASTPtr & query) return result; } +CNFQuery TreeCNFConverter::toCNF( + const ASTPtr & query, size_t max_growth_multipler) +{ + auto cnf = tryConvertToCNF(query, max_growth_multipler); + if (!cnf) + throw Exception(ErrorCodes::TOO_MANY_TEMPORARY_COLUMNS, + "Cannot convert expression to CNF, because it produces too many clauses. " 
+ "Size of formula in CNF can be exponential of size of source formula."); + + return *cnf; +} + ASTPtr TreeCNFConverter::fromCNF(const CNFQuery & cnf) { const auto & groups = cnf.getStatements(); @@ -208,7 +265,7 @@ ASTPtr TreeCNFConverter::fromCNF(const CNFQuery & cnf) auto * func = or_groups.back()->as(); for (const auto & atom : group) { - if ((*group.begin()).negative) + if (atom.negative) func->arguments->children.push_back(makeASTFunction("not", atom.ast->clone())); else func->arguments->children.push_back(atom.ast->clone()); diff --git a/src/Interpreters/TreeCNFConverter.h b/src/Interpreters/TreeCNFConverter.h index ba7fb299644..52f997d83c9 100644 --- a/src/Interpreters/TreeCNFConverter.h +++ b/src/Interpreters/TreeCNFConverter.h @@ -145,8 +145,19 @@ private: class TreeCNFConverter { public: + static constexpr size_t DEFAULT_MAX_GROWTH_MULTIPLIER = 20; + static constexpr size_t MAX_ATOMS_WITHOUT_CHECK = 200; - static CNFQuery toCNF(const ASTPtr & query); + /// @max_growth_multipler means that it's allowed to grow size of formula only + /// in that amount of times. It's needed to avoid exponential explosion of formula. + /// CNF of boolean formula with N clauses can have 2^N clauses. + /// If the amount of atomic formulas is exceeded, tryConvertToCNF returns nullopt and toCNF throws. + /// 0 - means unlimited. + static std::optional tryConvertToCNF( + const ASTPtr & query, size_t max_growth_multipler = DEFAULT_MAX_GROWTH_MULTIPLIER); + + static CNFQuery toCNF( + const ASTPtr & query, size_t max_growth_multipler = DEFAULT_MAX_GROWTH_MULTIPLIER); static ASTPtr fromCNF(const CNFQuery & cnf); }; diff --git a/src/Interpreters/TreeOptimizer.cpp b/src/Interpreters/TreeOptimizer.cpp index 1b816984647..fd53e6a0b7f 100644 --- a/src/Interpreters/TreeOptimizer.cpp +++ b/src/Interpreters/TreeOptimizer.cpp @@ -564,13 +564,20 @@ void optimizeSubstituteColumn(ASTSelectQuery * select_query, } /// Transform WHERE to CNF for more convenient optimization. 
-void convertQueryToCNF(ASTSelectQuery * select_query) +bool convertQueryToCNF(ASTSelectQuery * select_query) { if (select_query->where()) { - auto cnf_form = TreeCNFConverter::toCNF(select_query->where()).pushNotInFuntions(); - select_query->refWhere() = TreeCNFConverter::fromCNF(cnf_form); + auto cnf_form = TreeCNFConverter::tryConvertToCNF(select_query->where()); + if (!cnf_form) + return false; + + cnf_form->pushNotInFuntions(); + select_query->refWhere() = TreeCNFConverter::fromCNF(*cnf_form); + return true; } + + return false; } /// Remove duplicated columns from USING(...). @@ -734,10 +741,11 @@ void TreeOptimizer::apply(ASTPtr & query, TreeRewriterResult & result, if (settings.optimize_arithmetic_operations_in_aggregate_functions) optimizeAggregationFunctions(query); + bool converted_to_cnf = false; if (settings.convert_query_to_cnf) - convertQueryToCNF(select_query); + converted_to_cnf = convertQueryToCNF(select_query); - if (settings.convert_query_to_cnf && settings.optimize_using_constraints) + if (converted_to_cnf && settings.optimize_using_constraints) { optimizeWithConstraints(select_query, result.aliases, result.source_columns_set, tables_with_columns, result.metadata_snapshot, settings.optimize_append_index); diff --git a/tests/queries/0_stateless/01622_constraints_simple_optimization.reference b/tests/queries/0_stateless/01622_constraints_simple_optimization.reference index 800d77ea8c6..7e012e1a17b 100644 --- a/tests/queries/0_stateless/01622_constraints_simple_optimization.reference +++ b/tests/queries/0_stateless/01622_constraints_simple_optimization.reference @@ -32,14 +32,14 @@ 1 1 0 -SELECT count() +SELECT count() AS `count()` FROM constraint_test_constants WHERE (c > 100) OR (b > 100) -SELECT count() +SELECT count() AS `count()` FROM constraint_test_constants WHERE c > 100 -SELECT count() +SELECT count() AS `count()` FROM constraint_test_constants WHERE c > 100 -SELECT count() +SELECT count() AS `count()` FROM constraint_test_constants diff 
--git a/tests/queries/0_stateless/01623_constraints_column_swap.reference b/tests/queries/0_stateless/01623_constraints_column_swap.reference index c287ed073fc..7ae4516fe9e 100644 --- a/tests/queries/0_stateless/01623_constraints_column_swap.reference +++ b/tests/queries/0_stateless/01623_constraints_column_swap.reference @@ -49,5 +49,5 @@ WHERE a = \'c\' SELECT a AS `substring(reverse(b), 1, 1)` FROM column_swap_test_test WHERE a = \'c\' -SELECT toUInt32(s) AS a +SELECT a FROM t_bad_constraint diff --git a/tests/queries/0_stateless/01626_cnf_fuzz_long.python b/tests/queries/0_stateless/01626_cnf_fuzz_long.python new file mode 100644 index 00000000000..10c12d14182 --- /dev/null +++ b/tests/queries/0_stateless/01626_cnf_fuzz_long.python @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +import os +from random import randint, choices +import sys + +CURDIR = os.path.dirname(os.path.realpath(__file__)) +sys.path.insert(0, os.path.join(CURDIR, 'helpers')) + +from pure_http_client import ClickHouseClient + +client = ClickHouseClient() + +N = 10 +create_query = "CREATE TABLE t_cnf_fuzz(" + ", ".join([f"c{i} UInt8" for i in range(N)]) + ") ENGINE = Memory" + +client.query("DROP TABLE IF EXISTS t_cnf_fuzz") +client.query(create_query) + +# Insert all possible combinations of bool columns. +insert_query = "INSERT INTO t_cnf_fuzz VALUES " +for i in range(2**N): + values = [] + cur = i + for _ in range(N): + values.append(cur % 2) + cur //= 2 + + insert_query += "(" + ", ".join(map(str, values)) + ")" + +client.query(insert_query) + +# Let's try to convert DNF to CNF, +# because it's a worst case in a sense. 
+ +MAX_CLAUSES = 10 +MAX_ATOMS = 5 + +def generate_dnf(): + clauses = [] + num_clauses = randint(1, MAX_CLAUSES) + for _ in range(num_clauses): + num_atoms = randint(1, MAX_ATOMS) + atom_ids = choices(range(N), k=num_atoms) + negates = choices([0, 1], k=num_atoms) + atoms = [f"(NOT c{i})" if neg else f"c{i}" for (i, neg) in zip(atom_ids, negates)] + clauses.append("(" + " AND ".join(atoms) + ")") + + return " OR ".join(clauses) + +select_query = "SELECT count() FROM t_cnf_fuzz WHERE {} SETTINGS convert_query_to_cnf = {}" + +fail_report = """ +Failed query: '{}'. +Result without optimization: {}. +Result with optimization: {}. +""" + +T = 500 +for _ in range(T): + condition = generate_dnf() + + query = select_query.format(condition, 0) + res = client.query(query).strip() + + query_cnf = select_query.format(condition, 1) + res_cnf = client.query(query_cnf).strip() + + if res != res_cnf: + print(fail_report.format(query_cnf, res, res_cnf)) + sys.exit(1) + +client.query("DROP TABLE t_cnf_fuzz") +print("OK") diff --git a/tests/queries/0_stateless/01626_cnf_fuzz_long.reference b/tests/queries/0_stateless/01626_cnf_fuzz_long.reference new file mode 100644 index 00000000000..d86bac9de59 --- /dev/null +++ b/tests/queries/0_stateless/01626_cnf_fuzz_long.reference @@ -0,0 +1 @@ +OK diff --git a/tests/queries/0_stateless/01626_cnf_fuzz_long.sh b/tests/queries/0_stateless/01626_cnf_fuzz_long.sh new file mode 100755 index 00000000000..bdf53cdb252 --- /dev/null +++ b/tests/queries/0_stateless/01626_cnf_fuzz_long.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# Tags: no-fasttest, long +# Tag no-fasttest: Require python libraries like scipy, pandas and numpy + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. 
"$CURDIR"/../shell_config.sh + +# We should have correct env vars from shell_config.sh to run this test +python3 "$CURDIR"/01626_cnf_fuzz_long.python diff --git a/tests/queries/0_stateless/01626_cnf_test.reference b/tests/queries/0_stateless/01626_cnf_test.reference index b8de3d3a57c..081215c9fb2 100644 --- a/tests/queries/0_stateless/01626_cnf_test.reference +++ b/tests/queries/0_stateless/01626_cnf_test.reference @@ -6,13 +6,13 @@ FROM cnf_test WHERE (i <= 2) OR (i <= 1) SELECT i FROM cnf_test -WHERE ((i > 4) OR (i > 1) OR (i > 6)) AND ((i > 4) OR (i > 1) OR (i > 5)) AND ((i > 4) OR (i > 6) OR (i > 2)) AND ((i > 4) OR (i > 2) OR (i > 5)) AND ((i > 1) OR (i > 6) OR (i > 3)) AND ((i > 1) OR (i > 3) OR (i > 5)) AND ((i > 6) OR (i > 2) OR (i > 3)) AND ((i > 2) OR (i > 3) OR (i > 5)) +WHERE ((i > 2) OR (i > 5) OR (i > 3)) AND ((i > 2) OR (i > 5) OR (i > 4)) AND ((i > 2) OR (i > 6) OR (i > 3)) AND ((i > 2) OR (i > 6) OR (i > 4)) AND ((i > 1) OR (i > 5) OR (i > 3)) AND ((i > 1) OR (i > 5) OR (i > 4)) AND ((i > 1) OR (i > 6) OR (i > 3)) AND ((i > 1) OR (i > 6) OR (i > 4)) SELECT i FROM cnf_test -WHERE ((i <= 5) OR (i <= 2) OR (i <= 3)) AND ((i <= 5) OR (i <= 2) OR (i <= 4)) AND ((i <= 5) OR (i <= 3) OR (i <= 1)) AND ((i <= 5) OR (i <= 4) OR (i <= 1)) AND ((i <= 2) OR (i <= 3) OR (i <= 6)) AND ((i <= 2) OR (i <= 4) OR (i <= 6)) AND ((i <= 3) OR (i <= 1) OR (i <= 6)) AND ((i <= 4) OR (i <= 1) OR (i <= 6)) +WHERE ((i <= 3) OR (i <= 2) OR (i <= 5)) AND ((i <= 3) OR (i <= 2) OR (i <= 6)) AND ((i <= 3) OR (i <= 5) OR (i <= 1)) AND ((i <= 3) OR (i <= 6) OR (i <= 1)) AND ((i <= 2) OR (i <= 5) OR (i <= 4)) AND ((i <= 2) OR (i <= 6) OR (i <= 4)) AND ((i <= 5) OR (i <= 1) OR (i <= 4)) AND ((i <= 6) OR (i <= 1) OR (i <= 4)) SELECT i FROM cnf_test -WHERE ((i > 4) OR (i > 1) OR (i > 6)) AND ((i > 4) OR (i > 1) OR (i > 5)) AND ((i > 4) OR (i > 6) OR (i > 2)) AND ((i > 4) OR (i > 6) OR (i > 7)) AND ((i > 4) OR (i > 2) OR (i > 5)) AND ((i > 4) OR (i > 7) OR (i > 5)) AND ((i > 1) OR (i > 
8) OR (i > 6)) AND ((i > 1) OR (i > 8) OR (i > 5)) AND ((i > 1) OR (i > 6) OR (i > 3)) AND ((i > 1) OR (i > 3) OR (i > 5)) AND ((i > 8) OR (i > 6) OR (i > 2)) AND ((i > 8) OR (i > 6) OR (i > 7)) AND ((i > 8) OR (i > 2) OR (i > 5)) AND ((i > 8) OR (i > 7) OR (i > 5)) AND ((i > 6) OR (i > 2) OR (i > 3)) AND ((i > 6) OR (i > 3) OR (i > 7)) AND ((i > 2) OR (i > 3) OR (i > 5)) AND ((i > 3) OR (i > 7) OR (i > 5)) +WHERE ((i > 2) OR (i > 5) OR (i > 3)) AND ((i > 2) OR (i > 5) OR (i > 4)) AND ((i > 2) OR (i > 5) OR (i > 8)) AND ((i > 2) OR (i > 6) OR (i > 3)) AND ((i > 2) OR (i > 6) OR (i > 4)) AND ((i > 2) OR (i > 6) OR (i > 8)) AND ((i > 1) OR (i > 5) OR (i > 3)) AND ((i > 1) OR (i > 5) OR (i > 4)) AND ((i > 1) OR (i > 5) OR (i > 8)) AND ((i > 1) OR (i > 6) OR (i > 3)) AND ((i > 1) OR (i > 6) OR (i > 4)) AND ((i > 1) OR (i > 6) OR (i > 8)) AND ((i > 5) OR (i > 3) OR (i > 7)) AND ((i > 5) OR (i > 4) OR (i > 7)) AND ((i > 5) OR (i > 8) OR (i > 7)) AND ((i > 6) OR (i > 3) OR (i > 7)) AND ((i > 6) OR (i > 4) OR (i > 7)) AND ((i > 6) OR (i > 8) OR (i > 7)) SELECT i FROM cnf_test -WHERE ((i > 4) OR (i > 8) OR (i > 3)) AND (i <= 5) AND ((i > 1) OR (i > 2) OR (i > 7)) AND (i <= 6) +WHERE ((i > 2) OR (i > 1) OR (i > 7)) AND (i <= 5) AND (i <= 6) AND ((i > 3) OR (i > 4) OR (i > 8)) diff --git a/tests/queries/0_stateless/01626_cnf_test.sql b/tests/queries/0_stateless/01626_cnf_test.sql index e014441cbb3..8db732bc227 100644 --- a/tests/queries/0_stateless/01626_cnf_test.sql +++ b/tests/queries/0_stateless/01626_cnf_test.sql @@ -1,6 +1,8 @@ SET convert_query_to_cnf = 1; -CREATE TABLE cnf_test (i Int64) ENGINE = MergeTree() ORDER BY i; +DROP TABLE IF EXISTS cnf_test; + +CREATE TABLE cnf_test (i Int64) ENGINE = MergeTree() ORDER BY i; EXPLAIN SYNTAX SELECT i FROM cnf_test WHERE NOT ((i > 1) OR (i > 2)); EXPLAIN SYNTAX SELECT i FROM cnf_test WHERE NOT ((i > 1) AND (i > 2)); @@ -12,3 +14,5 @@ EXPLAIN SYNTAX SELECT i FROM cnf_test WHERE NOT (((i > 1) OR (i > 2)) AND ((i > EXPLAIN SYNTAX 
SELECT i FROM cnf_test WHERE ((i > 1) AND (i > 2) AND (i > 7)) OR ((i > 3) AND (i > 4) AND (i > 8)) OR ((i > 5) AND (i > 6)); EXPLAIN SYNTAX SELECT i FROM cnf_test WHERE ((i > 1) OR (i > 2) OR (i > 7)) AND ((i > 3) OR (i > 4) OR (i > 8)) AND NOT ((i > 5) OR (i > 6)); + +DROP TABLE cnf_test;