/// ClickHouse/src/Interpreters/TreeCNFConverter.cpp
/// (380 lines, 11 KiB, C++; earliest revision shown: 2021-01-04 20:55:32 +00:00)
#include <Interpreters/TreeCNFConverter.h>
#include <Parsers/IAST.h>
#include <Parsers/ASTFunction.h>
#include <Poco/Logger.h>
namespace DB
{
/// Rewrites every n-ary AND/OR into nested binary calls by folding from the left,
/// e.g. AND(a, b, c) -> AND(AND(a, b), c), so that later passes can assume exactly two arguments.
/// Descends through AND, OR and NOT nodes; any other node is a leaf and is left untouched.
/// The rewritten nodes are built from clones of the original arguments; `node` is replaced in place.
void splitMultiLogic(ASTPtr & node)
{
    auto * func = node->as<ASTFunction>();
    if (!func)
        return;

    if (func->name == "and" || func->name == "or")
    {
        if (func->arguments->children.size() > 2)
        {
            ASTPtr res = func->arguments->children.front()->clone();
            for (size_t i = 1; i < func->arguments->children.size(); ++i)
                res = makeASTFunction(func->name, res, func->arguments->children[i]->clone());
            node = std::move(res);
        }

        auto * new_func = node->as<ASTFunction>();
        for (auto & child : new_func->arguments->children)
            splitMultiLogic(child);
    }
    else if (func->name == "not")
    {
        /// An AND/OR under NOT must be binarized too: traversePushNot applies De Morgan's law
        /// to exactly two arguments, so a third argument of NOT(AND(a, b, c)) would otherwise be lost.
        for (auto & child : func->arguments->children)
            splitMultiLogic(child);
    }
}
2021-01-05 20:51:19 +00:00
/// Push NOT to leafs, remove NOT NOT ...
2021-01-04 20:55:32 +00:00
void traversePushNot(ASTPtr & node, bool add_negation)
{
auto * func = node->as<ASTFunction>();
if (func && (func->name == "and" || func->name == "or"))
{
2021-05-04 18:43:58 +00:00
if (add_negation)
{
2021-05-05 08:51:25 +00:00
ASSERT(func->arguments->size() == 2)
2021-01-04 20:55:32 +00:00
/// apply De Morgan's Law
node = makeASTFunction(
(func->name == "and" ? "or" : "and"),
2021-02-14 14:47:15 +00:00
func->arguments->children[0]->clone(),
func->arguments->children[1]->clone());
2021-01-04 20:55:32 +00:00
}
auto * new_func = node->as<ASTFunction>();
for (auto & child : new_func->arguments->children)
traversePushNot(child, add_negation);
}
else if (func && func->name == "not")
{
2021-05-05 08:51:25 +00:00
ASSERT(func->arguments->size() == 1)
2021-01-04 20:55:32 +00:00
/// delete NOT
2021-02-14 14:47:15 +00:00
node = func->arguments->children[0]->clone();
2021-01-04 20:55:32 +00:00
traversePushNot(node, !add_negation);
}
else
{
if (add_negation)
2021-02-14 14:47:15 +00:00
node = makeASTFunction("not", node->clone());
2021-01-04 20:55:32 +00:00
}
}
/// Records a reference to every "or" function slot in the subtree rooted at `node`.
/// Nodes are recorded in pre-order: an OR is appended before any OR nested inside it.
/// Only function nodes are descended into.
void findOrs(ASTPtr & node, std::vector<std::reference_wrapper<ASTPtr>> & ors)
{
    auto * function = node->as<ASTFunction>();
    if (!function)
        return;

    if (function->name == "or")
        ors.emplace_back(node);

    auto & arguments = function->arguments->children;
    for (size_t idx = 0; idx < arguments.size(); ++idx)
        findOrs(arguments[idx], ors);
}
/// Push OR inside AND (i.e. pull AND up to the top) by repeatedly applying the distributive law
///     a OR (b AND c)  ->  (a OR b) AND (a OR c)
/// Expects binary AND/OR nodes (see splitMultiLogic) with NOT already pushed to the leaves.
/// `query` is modified in place.
void pushOr(ASTPtr & query)
{
    /// Stack of references to OR slots in the tree. findOrs records them in pre-order,
    /// so popping from the back handles nested ORs before their ancestors.
    std::vector<std::reference_wrapper<ASTPtr>> ors;
    findOrs(query, ors);

    while (!ors.empty())
    {
        std::reference_wrapper<ASTPtr> or_node = ors.back();
        ors.pop_back();

        auto * or_func = or_node.get()->as<ASTFunction>();
        ASSERT(or_func)
        ASSERT(or_func->name == "or")
        ASSERT(or_func->arguments->children.size() == 2)

        /// find an AND directly under this OR (if both arguments are ANDs, the last one is used)
        size_t and_node_id = or_func->arguments->children.size();
        for (size_t i = 0; i < or_func->arguments->children.size(); ++i)
        {
            auto * child_func = or_func->arguments->children[i]->as<ASTFunction>();
            if (child_func && child_func->name == "and")
                and_node_id = i;
        }
        if (and_node_id == or_func->arguments->children.size())
            continue;
        const size_t other_node_id = 1 - and_node_id;

        auto * and_func = or_func->arguments->children[and_node_id]->as<ASTFunction>();
        ASSERT(and_func)
        ASSERT(and_func->name == "and")
        ASSERT(and_func->arguments->children.size() == 2)

        auto a = or_func->arguments->children[other_node_id];
        auto b = and_func->arguments->children[0];
        auto c = and_func->arguments->children[1];

        /// apply the distributive law ( a or (b and c) -> (a or b) and (a or c) )
        or_node.get() = makeASTFunction(
            "and",
            makeASTFunction("or", a->clone(), b->clone()),
            makeASTFunction("or", a->clone(), c->clone()));

        /// The two new ORs may still have an AND argument, so push them onto the stack.
        auto * new_func = or_node.get()->as<ASTFunction>();
        for (auto & new_or : new_func->arguments->children)
            ors.push_back(new_or);
    }
}
/// Transform a CNF-shaped AST (AND of ORs of possibly negated atoms) into CNF groups:
/// - AND: each argument is collected into a fresh OR group, which is added to `and_group` if non-empty;
/// - OR: arguments are added to the current `or_group`;
/// - NOT(x): x becomes a negative atom of `or_group`;
/// - anything else: the node becomes a positive atom of `or_group`.
/// Atoms reference the nodes of `node` directly (no clone).
/// Expects the output of pushOr, where no AND appears below an OR.
void traverseCNF(const ASTPtr & node, CNFQuery::AndGroup & and_group, CNFQuery::OrGroup & or_group)
{
    auto * func = node->as<ASTFunction>();
    if (func && func->name == "and")
    {
        for (auto & child : func->arguments->children)
        {
            CNFQuery::OrGroup group;
            traverseCNF(child, and_group, group);
            if (!group.empty())
                and_group.insert(std::move(group));
        }
    }
    else if (func && func->name == "or")
    {
        for (auto & child : func->arguments->children)
            traverseCNF(child, and_group, or_group);
    }
    else if (func && func->name == "not")
    {
        or_group.insert(CNFQuery::AtomicFormula{true, func->arguments->children.front()});
    }
    else
    {
        or_group.insert(CNFQuery::AtomicFormula{false, node});
    }
}
void traverseCNF(const ASTPtr & node, CNFQuery::AndGroup & result)
{
CNFQuery::OrGroup or_group;
traverseCNF(node, result, or_group);
if (!or_group.empty())
result.insert(or_group);
}
/// Converts a boolean expression AST into CNF.
/// Works on a clone of `query`. The steps are:
/// 1. binarize AND/OR;
/// 2. push NOT to the leaves;
/// 3. distribute OR over AND;
/// 4. collect the groups.
CNFQuery TreeCNFConverter::toCNF(const ASTPtr & query)
{
    auto cnf = query->clone();
    splitMultiLogic(cnf);
    traversePushNot(cnf, false);
    pushOr(cnf);

    CNFQuery::AndGroup and_group;
    traverseCNF(cnf, and_group);
    CNFQuery result{std::move(and_group)};

    /// The dump can be large and this runs for every conversion: keep it at trace level
    /// and only build the string when trace logging is enabled.
    Poco::Logger & log = Poco::Logger::get("CNF CONVERSION");
    if (log.trace())
        log.trace("DONE: " + result.dump());
    return result;
}
/// Builds an AST from a CNF: an AND of the groups.
/// - Each group becomes an OR of its atoms; a negative atom is wrapped in NOT.
/// - A single group is returned without the AND wrapper, and a single-atom group without the OR wrapper.
/// - An empty CNF yields nullptr.
/// Atom ASTs are cloned, so the result shares no nodes with `cnf`.
ASTPtr TreeCNFConverter::fromCNF(const CNFQuery & cnf)
{
    const auto & groups = cnf.getStatements();
    if (groups.empty())
        return nullptr;

    /// Each atom is negated according to its own flag.
    auto atom_to_ast = [](const CNFQuery::AtomicFormula & atom) -> ASTPtr
    {
        if (atom.negative)
            return makeASTFunction("not", atom.ast->clone());
        return atom.ast->clone();
    };

    ASTs or_groups;
    for (const auto & group : groups)
    {
        if (group.size() == 1)
        {
            or_groups.push_back(atom_to_ast(*group.begin()));
        }
        else if (group.size() > 1)
        {
            or_groups.push_back(makeASTFunction("or"));
            auto * func = or_groups.back()->as<ASTFunction>();
            for (const auto & atom : group)
                func->arguments->children.push_back(atom_to_ast(atom));
        }
        /// NOTE(review): an empty group is skipped. reduce() can produce one (from x AND NOT x),
        /// and an empty OR presumably means false — confirm whether it should become a false literal.
    }

    if (or_groups.size() == 1)
        return or_groups.front();

    ASTPtr res = makeASTFunction("and");
    auto * func = res->as<ASTFunction>();
    for (const auto & group : or_groups)
        func->arguments->children.push_back(group);
    return res;
}
2021-03-04 12:11:43 +00:00
void pushPullNotInAtom(CNFQuery::AtomicFormula & atom, const std::map<std::string, std::string> & inverse_relations)
{
auto * func = atom.ast->as<ASTFunction>();
if (!func)
return;
if (auto it = inverse_relations.find(func->name); it != std::end(inverse_relations))
{
/// inverse func
atom.ast = atom.ast->clone();
auto * new_func = atom.ast->as<ASTFunction>();
new_func->name = it->second;
/// add not
atom.negative = !atom.negative;
}
}
/// Replaces a "negated" relation by its positive counterpart and toggles the atom's negation flag,
/// e.g. notEquals(a, b) -> NOT equals(a, b), greater(a, b) -> NOT lessOrEquals(a, b).
/// Functions not listed below are left unchanged.
void pullNotOut(CNFQuery::AtomicFormula & atom)
{
    static const std::map<std::string, std::string> inverse_relations = {
        {"notEquals", "equals"},
        {"greaterOrEquals", "less"},
        {"greater", "lessOrEquals"},
        {"notIn", "in"},
        {"notLike", "like"},
        {"notEmpty", "empty"},
    };

    pushPullNotInAtom(atom, inverse_relations);
}
2021-03-04 12:11:43 +00:00
void pushNotIn(CNFQuery::AtomicFormula & atom)
2021-01-05 20:51:19 +00:00
{
2021-03-04 12:11:43 +00:00
if (!atom.negative)
return;
2021-01-05 20:51:19 +00:00
static const std::map<std::string, std::string> inverse_relations = {
{"equals", "notEquals"},
{"less", "greaterOrEquals"},
{"lessOrEquals", "greater"},
{"in", "notIn"},
{"like", "notLike"},
{"empty", "notEmpty"},
2021-03-04 12:11:43 +00:00
{"notEquals", "equals"},
{"greaterOrEquals", "less"},
{"greater", "lessOrEquals"},
{"notIn", "in"},
{"notLike", "like"},
{"notEmpty", "empty"},
2021-01-05 20:51:19 +00:00
};
2021-03-04 12:11:43 +00:00
pushPullNotInAtom(atom, inverse_relations);
2021-01-05 20:51:19 +00:00
}
/// Passes to transformAtoms a transformation that applies pullNotOut to a copy of each atom
/// (with a cloned AST). Returns *this for chaining.
CNFQuery & CNFQuery::pullNotOutFunctions()
{
    transformAtoms([](const AtomicFormula & atom) -> AtomicFormula
    {
        AtomicFormula result{atom.negative, atom.ast->clone()};
        pullNotOut(result);
        return result;
    });
    return *this;
}
/// Passes to transformAtoms a transformation that applies pushNotIn to a copy of each atom
/// (with a cloned AST). Returns *this for chaining.
/// (The misspelled name is part of the public interface and is kept as is.)
CNFQuery & CNFQuery::pushNotInFuntions()
{
    transformAtoms([](const AtomicFormula & atom) -> AtomicFormula
    {
        AtomicFormula result{atom.negative, atom.ast->clone()};
        pushNotIn(result);
        return result;
    });
    return *this;
}
2021-05-05 08:51:25 +00:00
namespace
{
CNFQuery::AndGroup reduceOnce(const CNFQuery::AndGroup & groups)
{
CNFQuery::AndGroup result;
for (const CNFQuery::OrGroup & group : groups)
{
CNFQuery::OrGroup copy(group);
bool inserted = false;
for (const CNFQuery::AtomicFormula & atom : group)
{
copy.erase(atom);
CNFQuery::AtomicFormula negative_atom(atom);
negative_atom.negative = !atom.negative;
copy.insert(negative_atom);
if (groups.contains(copy))
{
copy.erase(negative_atom);
result.insert(copy);
inserted = true;
break;
}
copy.erase(negative_atom);
copy.insert(atom);
}
if (!inserted)
result.insert(group);
}
return result;
}
}
/// Applies reduceOnce repeatedly until the statements stop changing. Returns *this for chaining.
CNFQuery & CNFQuery::reduce()
{
    for (;;)
    {
        AndGroup next = reduceOnce(statements);
        if (next == statements)
            break;
        statements = std::move(next);
    }
    return *this;
}
2021-01-04 20:55:32 +00:00
std::string CNFQuery::dump() const
{
2021-05-04 18:43:58 +00:00
WriteBufferFromOwnString res;
2021-01-04 20:55:32 +00:00
bool first = true;
for (const auto & group : statements)
{
if (!first)
res << " AND ";
first = false;
res << "(";
bool first_in_group = true;
2021-03-04 12:11:43 +00:00
for (const auto & atom : group)
2021-01-04 20:55:32 +00:00
{
if (!first_in_group)
res << " OR ";
first_in_group = false;
2021-03-04 12:11:43 +00:00
if (atom.negative)
res << " NOT ";
res << atom.ast->getColumnName();
2021-01-04 20:55:32 +00:00
}
res << ")";
}
return res.str();
}
}