mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-28 18:42:26 +00:00
add condition
This commit is contained in:
parent
48545bc390
commit
f53a3be4a8
534
src/Storages/MergeTree/CommonCondition.cpp
Normal file
534
src/Storages/MergeTree/CommonCondition.cpp
Normal file
@ -0,0 +1,534 @@
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <Parsers/ASTFunction.h>
|
||||
|
||||
#include "Core/Block.h"
|
||||
#include "Core/Field.h"
|
||||
#include "IO/ReadBuffer.h"
|
||||
#include "Interpreters/Context_fwd.h"
|
||||
#include "Parsers/ASTExpressionList.h"
|
||||
#include "Parsers/ASTFunctionWithKeyValueArguments.h"
|
||||
#include "Parsers/ASTIdentifier.h"
|
||||
#include "Parsers/ASTIdentifier_fwd.h"
|
||||
#include "Parsers/ASTLiteral.h"
|
||||
#include "Parsers/ASTOrderByElement.h"
|
||||
#include "Parsers/ASTSelectQuery.h"
|
||||
#include "Parsers/ASTSetQuery.h"
|
||||
#include "Parsers/ASTTablesInSelectQuery.h"
|
||||
#include "Parsers/Access/ASTCreateUserQuery.h"
|
||||
#include "Parsers/Access/ASTRolesOrUsersSet.h"
|
||||
#include "Parsers/Access/ASTSettingsProfileElement.h"
|
||||
#include "Parsers/IAST_fwd.h"
|
||||
|
||||
#include <Storages/MergeTree/CommonCondition.h>
|
||||
#include <Storages/MergeTree/KeyCondition.h>
|
||||
|
||||
#include "Storages/SelectQueryInfo.h"
|
||||
#include "base/logger_useful.h"
|
||||
#include "base/types.h"
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
namespace Condition
|
||||
{
|
||||
|
||||
CommonCondition::CommonCondition(const SelectQueryInfo & query_info,
|
||||
ContextPtr context)
|
||||
{
|
||||
buildRPN(query_info, context);
|
||||
index_is_useful = matchAllRPNS();
|
||||
}
|
||||
|
||||
bool CommonCondition::alwaysUnknownOrTrue() const
|
||||
{
|
||||
return !index_is_useful;
|
||||
}
|
||||
|
||||
float CommonCondition::getComparisonDistance() const
|
||||
{
|
||||
if (where_query_type)
|
||||
{
|
||||
return ann_expr->distance;
|
||||
}
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Not supported method for this query type");
|
||||
}
|
||||
|
||||
std::vector<float> CommonCondition::getTargetVector() const
|
||||
{
|
||||
return ann_expr->target;
|
||||
}
|
||||
|
||||
String CommonCondition::getColumnName() const
|
||||
{
|
||||
return ann_expr->column_name;
|
||||
}
|
||||
|
||||
String CommonCondition::getMetric() const
|
||||
{
|
||||
return ann_expr->metric_name;
|
||||
}
|
||||
|
||||
size_t CommonCondition::getSpaceDim() const
|
||||
{
|
||||
return ann_expr->target.size();
|
||||
}
|
||||
|
||||
float CommonCondition::getPForLpDistance() const
|
||||
{
|
||||
return ann_expr->p_for_lp_dist;
|
||||
}
|
||||
|
||||
bool CommonCondition::queryHasWhereClause() const
|
||||
{
|
||||
return where_query_type;
|
||||
}
|
||||
|
||||
bool CommonCondition::queryHasOrderByClause() const
|
||||
{
|
||||
return order_by_query_type && has_limit;
|
||||
}
|
||||
|
||||
std::optional<UInt64> CommonCondition::getLimitLength() const
|
||||
{
|
||||
return has_limit ? std::optional<UInt64>(limit_expr->length) : std::nullopt;
|
||||
}
|
||||
|
||||
String CommonCondition::getSettingsStr() const
|
||||
{
|
||||
return ann_index_params;
|
||||
}
|
||||
|
||||
void CommonCondition::buildRPN(const SelectQueryInfo & query, ContextPtr context)
|
||||
{
|
||||
block_with_constants = KeyCondition::getBlockWithConstants(query.query, query.syntax_analyzer_result, context);
|
||||
|
||||
const auto & select = query.query->as<ASTSelectQuery &>();
|
||||
|
||||
if (select.prewhere())
|
||||
{
|
||||
traverseAST(select.prewhere(), rpn_prewhere_clause);
|
||||
}
|
||||
|
||||
if (select.where())
|
||||
{
|
||||
traverseAST(select.where(), rpn_where_clause);
|
||||
}
|
||||
|
||||
if (select.limitLength())
|
||||
{
|
||||
traverseAST(select.limitLength(), rpn_limit_clause);
|
||||
}
|
||||
|
||||
if (select.settings())
|
||||
{
|
||||
parseSettings(select.settings());
|
||||
}
|
||||
|
||||
if (select.orderBy())
|
||||
{
|
||||
if (const auto * expr_list = select.orderBy()->as<ASTExpressionList>())
|
||||
{
|
||||
if (const auto * order_by_element = expr_list->children.front()->as<ASTOrderByElement>())
|
||||
{
|
||||
traverseAST(order_by_element->children.front(), rpn_order_by_clause);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::reverse(rpn_prewhere_clause.begin(), rpn_prewhere_clause.end());
|
||||
std::reverse(rpn_where_clause.begin(), rpn_where_clause.end());
|
||||
std::reverse(rpn_order_by_clause.begin(), rpn_order_by_clause.end());
|
||||
}
|
||||
|
||||
void CommonCondition::traverseAST(const ASTPtr & node, RPN & rpn)
|
||||
{
|
||||
if (const auto * func = node->as<ASTFunction>())
|
||||
{
|
||||
const ASTs & args = func->arguments->children;
|
||||
|
||||
for (const auto& arg : args)
|
||||
{
|
||||
traverseAST(arg, rpn);
|
||||
}
|
||||
}
|
||||
|
||||
RPNElement element;
|
||||
|
||||
if (!traverseAtomAST(node, element))
|
||||
{
|
||||
element.function = RPNElement::FUNCTION_UNKNOWN;
|
||||
}
|
||||
|
||||
rpn.emplace_back(std::move(element));
|
||||
}
|
||||
|
||||
bool CommonCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out)
|
||||
{
|
||||
|
||||
if (const auto * order_by_element = node->as<ASTOrderByElement>())
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_ORDER_BY_ELEMENT;
|
||||
out.func_name = "order by elemnet";
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (const auto * function = node->as<ASTFunction>())
|
||||
{
|
||||
// Set the name
|
||||
out.func_name = function->name;
|
||||
|
||||
// TODO: Add support for LpDistance
|
||||
if (function->name == "L1Distance" ||
|
||||
function->name == "L2Distance" ||
|
||||
function->name == "LinfDistance" ||
|
||||
function->name == "cosineDistance" ||
|
||||
function->name == "dotProduct" ||
|
||||
function->name == "LpDistance")
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_DISTANCE;
|
||||
}
|
||||
else if (function->name == "tuple")
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_TUPLE;
|
||||
}
|
||||
else if (function->name == "less" ||
|
||||
function->name == "greater" ||
|
||||
function->name == "lessOrEquals" ||
|
||||
function->name == "greaterOrEquals")
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_COMPARISON;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
// Match identifier
|
||||
else if (const auto * identifier = node->as<ASTIdentifier>())
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_IDENTIFIER;
|
||||
out.identifier.emplace(identifier->name());
|
||||
out.func_name = "column identifier";
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check if we have constants behind the node
|
||||
{
|
||||
Field const_value;
|
||||
DataTypePtr const_type;
|
||||
|
||||
if (KeyCondition::getConstant(node, block_with_constants, const_value, const_type))
|
||||
{
|
||||
/// Check constant type (use Float64 because all Fields implementation contains Float64 (for Float32 too))
|
||||
if (const_value.getType() == Field::Types::Float64)
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_FLOAT_LITERAL;
|
||||
out.float_literal.emplace(const_value.get<Float32>());
|
||||
out.func_name = "Float literal";
|
||||
return true;
|
||||
}
|
||||
if (const_value.getType() == Field::Types::UInt64)
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_INT_LITERAL;
|
||||
out.int_literal.emplace(const_value.get<UInt64>());
|
||||
out.func_name = "Int literal";
|
||||
return true;
|
||||
}
|
||||
if (const_value.getType() == Field::Types::Int64)
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_INT_LITERAL;
|
||||
out.int_literal.emplace(const_value.get<Int64>());
|
||||
out.func_name = "Int literal";
|
||||
return true;
|
||||
}
|
||||
if (const_value.getType() == Field::Types::String)
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_STRING;
|
||||
out.identifier.emplace(const_value.get<String>());
|
||||
out.func_name = "setting string";
|
||||
return true;
|
||||
}
|
||||
if (const_value.getType() == Field::Types::Tuple)
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_LITERAL_TUPLE;
|
||||
out.tuple_literal = const_value.get<Tuple>();
|
||||
out.func_name = "Tuple literal";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CommonCondition::matchAllRPNS()
|
||||
{
|
||||
ANNExpression expr_prewhere;
|
||||
ANNExpression expr_where;
|
||||
ANNExpression expr_order_by;
|
||||
LimitExpression expr_limit;
|
||||
bool prewhere_is_valid = matchRPNWhere(rpn_prewhere_clause, expr_prewhere);
|
||||
bool where_is_valid = matchRPNWhere(rpn_where_clause, expr_where);
|
||||
bool limit_is_valid = matchRPNLimit(rpn_limit_clause, expr_limit);
|
||||
bool order_by_is_valid = matchRPNOrderBy(rpn_order_by_clause, expr_order_by);
|
||||
|
||||
// Unxpected situation
|
||||
if (prewhere_is_valid && where_is_valid)
|
||||
{
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Have both where and prewhere valid clauses - is not supported");
|
||||
}
|
||||
|
||||
if (prewhere_is_valid || where_is_valid)
|
||||
{
|
||||
ann_expr = std::move(where_is_valid ? expr_where : expr_prewhere);
|
||||
where_query_type = true;
|
||||
}
|
||||
if (order_by_is_valid)
|
||||
{
|
||||
ann_expr = std::move(expr_order_by);
|
||||
order_by_query_type = true;
|
||||
}
|
||||
if (limit_is_valid)
|
||||
{
|
||||
limit_expr = std::move(expr_limit);
|
||||
has_limit = true;
|
||||
}
|
||||
|
||||
if (where_query_type && (has_limit && order_by_query_type))
|
||||
{
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR,
|
||||
"The query with Valid Where Clause and valid OrderBy clause - is not supported");
|
||||
}
|
||||
|
||||
return where_query_type || (has_limit && order_by_query_type);
|
||||
}
|
||||
|
||||
bool CommonCondition::matchRPNLimit(RPN & rpn, LimitExpression & expr)
|
||||
{
|
||||
if (rpn.size() != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if (rpn.front().function == RPNElement::FUNCTION_INT_LITERAL)
|
||||
{
|
||||
expr.length = rpn.front().int_literal.value();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void CommonCondition::parseSettings(const ASTPtr & node)
|
||||
{
|
||||
if (const auto * set = node->as<ASTSetQuery>())
|
||||
{
|
||||
for (const auto & change : set->changes)
|
||||
{
|
||||
if (change.name == "ann_index_params")
|
||||
{
|
||||
ann_index_params = change.value.get<String>();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
ann_index_params = "";
|
||||
}
|
||||
|
||||
bool CommonCondition::matchRPNOrderBy(RPN & rpn, ANNExpression & expr)
|
||||
{
|
||||
if (rpn.size() < 3)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
auto iter = rpn.begin();
|
||||
auto end = rpn.end();
|
||||
bool identifier_found = false;
|
||||
|
||||
return CommonCondition::matchMainParts(iter, end, expr, identifier_found);
|
||||
}
|
||||
|
||||
bool CommonCondition::matchMainParts(RPN::iterator & iter, RPN::iterator & end,
|
||||
ANNExpression & expr, bool & identifier_found)
|
||||
{
|
||||
|
||||
if (iter->function != RPNElement::FUNCTION_DISTANCE)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
expr.metric_name = iter->func_name;
|
||||
++iter;
|
||||
|
||||
if (expr.metric_name == "LpDistance")
|
||||
{
|
||||
if (iter->function != RPNElement::FUNCTION_FLOAT_LITERAL &&
|
||||
iter->function != RPNElement::FUNCTION_INT_LITERAL)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
expr.p_for_lp_dist = getFloatOrIntLiteralOrPanic(iter);
|
||||
++iter;
|
||||
}
|
||||
|
||||
|
||||
if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
|
||||
{
|
||||
identifier_found = true;
|
||||
expr.column_name = getIdentifierOrPanic(iter);
|
||||
++iter;
|
||||
}
|
||||
|
||||
if (iter->function == RPNElement::FUNCTION_TUPLE)
|
||||
{
|
||||
++iter;
|
||||
}
|
||||
|
||||
if (iter->function == RPNElement::FUNCTION_LITERAL_TUPLE)
|
||||
{
|
||||
for (const auto & value : iter->tuple_literal.value())
|
||||
{
|
||||
expr.target.emplace_back(value.get<float>());
|
||||
}
|
||||
++iter;
|
||||
}
|
||||
|
||||
|
||||
while (iter != end)
|
||||
{
|
||||
if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL ||
|
||||
iter->function == RPNElement::FUNCTION_INT_LITERAL)
|
||||
{
|
||||
expr.target.emplace_back(getFloatOrIntLiteralOrPanic(iter));
|
||||
}
|
||||
else if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
|
||||
{
|
||||
if (identifier_found)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
expr.column_name = getIdentifierOrPanic(iter);
|
||||
identifier_found = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
++iter;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CommonCondition::matchRPNWhere(RPN & rpn, ANNExpression & expr)
|
||||
{
|
||||
const size_t minimal_elemets_count = 6;// At least 6 AST nodes in querry
|
||||
if (rpn.size() < minimal_elemets_count)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
auto iter = rpn.begin();
|
||||
bool identifier_found = false;
|
||||
|
||||
// Query starts from operator less
|
||||
if (iter->function != RPNElement::FUNCTION_COMPARISON)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool greater_case = iter->func_name == "greater" || iter->func_name == "greaterOrEquals";
|
||||
const bool less_case = iter->func_name == "less" || iter->func_name == "lessOrEquals";
|
||||
|
||||
++iter;
|
||||
|
||||
if (less_case)
|
||||
{
|
||||
if (iter->function != RPNElement::FUNCTION_FLOAT_LITERAL)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
expr.distance = getFloatOrIntLiteralOrPanic(iter);
|
||||
++iter;
|
||||
|
||||
}
|
||||
else if (!greater_case)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
auto end = rpn.end();
|
||||
if (!matchMainParts(iter, end, expr, identifier_found))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Final checks of correctness
|
||||
|
||||
if (!identifier_found || expr.target.empty())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (greater_case)
|
||||
{
|
||||
if (expr.target.size() < 2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
expr.distance = expr.target.back();
|
||||
expr.target.pop_back();
|
||||
}
|
||||
|
||||
// Querry is ok
|
||||
return true;
|
||||
}
|
||||
|
||||
String CommonCondition::getIdentifierOrPanic(RPN::iterator& iter)
|
||||
{
|
||||
String identifier;
|
||||
try
|
||||
{
|
||||
identifier = std::move(iter->identifier.value());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
CommonCondition::panicIfWrongBuiltRPN();
|
||||
}
|
||||
return identifier;
|
||||
}
|
||||
|
||||
float CommonCondition::getFloatOrIntLiteralOrPanic(RPN::iterator& iter)
|
||||
{
|
||||
if (iter->float_literal.has_value())
|
||||
{
|
||||
return iter->float_literal.value();
|
||||
}
|
||||
if (iter->int_literal.has_value())
|
||||
{
|
||||
return static_cast<float>(iter->int_literal.value());
|
||||
}
|
||||
CommonCondition::panicIfWrongBuiltRPN();
|
||||
}
|
||||
|
||||
void CommonCondition::panicIfWrongBuiltRPN()
|
||||
{
|
||||
LOG_DEBUG(&Poco::Logger::get("CommonCondition"), "Wrong parsing of AST");
|
||||
throw Exception(
|
||||
"Wrong parsed AST in buildRPN\n", DB::ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
175
src/Storages/MergeTree/CommonCondition.h
Normal file
175
src/Storages/MergeTree/CommonCondition.h
Normal file
@ -0,0 +1,175 @@
|
||||
#pragma once
|
||||
|
||||
#include <Storages/MergeTree/KeyCondition.h>
|
||||
#include "base/types.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace Condition
|
||||
{
|
||||
|
||||
class CommonCondition
|
||||
{
|
||||
public:
|
||||
CommonCondition(const SelectQueryInfo & query_info,
|
||||
ContextPtr context);
|
||||
|
||||
bool alwaysUnknownOrTrue() const;
|
||||
|
||||
float getComparisonDistance() const;
|
||||
|
||||
std::vector<float> getTargetVector() const;
|
||||
|
||||
size_t getSpaceDim() const;
|
||||
|
||||
String getColumnName() const;
|
||||
|
||||
String getMetric() const;
|
||||
|
||||
float getPForLpDistance() const;
|
||||
|
||||
bool queryHasOrderByClause() const;
|
||||
|
||||
bool queryHasWhereClause() const;
|
||||
|
||||
std::optional<UInt64> getLimitLength() const;
|
||||
|
||||
String getSettingsStr() const;
|
||||
|
||||
private:
|
||||
// Type of the vector to use as a target in the distance function
|
||||
using Target = std::vector<float>;
|
||||
|
||||
// Extracted data from the query like WHERE L2Distance(column_name, target) < distance
|
||||
struct ANNExpression
|
||||
{
|
||||
Target target;
|
||||
float distance = -1.0;
|
||||
String metric_name = "Unknown"; // Metric name, maybe some Enum for all indices
|
||||
String column_name = "Unknown"; // Coloumn name stored in IndexGranule
|
||||
float p_for_lp_dist = -1.0; // The P parametr for Lp Distance
|
||||
};
|
||||
|
||||
struct LimitExpression
|
||||
{
|
||||
Int64 length;
|
||||
};
|
||||
|
||||
using ANNExprOpt = std::optional<ANNExpression>;
|
||||
using LimitExprOpt = std::optional<LimitExpression>;
|
||||
struct RPNElement
|
||||
{
|
||||
enum Function
|
||||
{
|
||||
// l2 dist
|
||||
FUNCTION_DISTANCE,
|
||||
|
||||
//tuple(10, 15)
|
||||
FUNCTION_TUPLE,
|
||||
|
||||
// Operator <, >
|
||||
FUNCTION_COMPARISON,
|
||||
|
||||
// Numeric float value
|
||||
FUNCTION_FLOAT_LITERAL,
|
||||
|
||||
// Numeric int value
|
||||
FUNCTION_INT_LITERAL,
|
||||
|
||||
// Column identifier
|
||||
FUNCTION_IDENTIFIER,
|
||||
|
||||
// Unknown, can be any value
|
||||
FUNCTION_UNKNOWN,
|
||||
|
||||
FUNCTION_STRING,
|
||||
|
||||
FUNCTION_LITERAL_TUPLE,
|
||||
|
||||
FUNCTION_ORDER_BY_ELEMENT,
|
||||
};
|
||||
|
||||
explicit RPNElement(Function function_ = FUNCTION_UNKNOWN)
|
||||
: function(function_), func_name("Unknown"), float_literal(std::nullopt), identifier(std::nullopt) {}
|
||||
|
||||
Function function;
|
||||
String func_name;
|
||||
|
||||
std::optional<float> float_literal;
|
||||
std::optional<String> identifier;
|
||||
std::optional<int64_t> int_literal{std::nullopt};
|
||||
std::optional<Tuple> tuple_literal{std::nullopt};
|
||||
|
||||
UInt32 dim{0};
|
||||
};
|
||||
|
||||
using RPN = std::vector<RPNElement>;
|
||||
|
||||
void buildRPN(const SelectQueryInfo & query, ContextPtr context);
|
||||
|
||||
// Util functions for the traversal of AST
|
||||
void traverseAST(const ASTPtr & node, RPN & rpn);
|
||||
// Return true if we can identify our node type
|
||||
bool traverseAtomAST(const ASTPtr & node, RPNElement & out);
|
||||
|
||||
// Checks that at least one rpn is matching for index
|
||||
// New RPNs for other query types can be added here
|
||||
bool matchAllRPNS();
|
||||
|
||||
/* Returns true and stores ANNExpr if the query matches the template:
|
||||
* WHERE DistFunc(column_name, tuple(float_1, float_2, ..., float_dim)) < float_literal */
|
||||
static bool matchRPNWhere(RPN & rpn, ANNExpression & expr);
|
||||
|
||||
/* Returns true and stores OrderByExpr if the query has valid OrderBy section*/
|
||||
static bool matchRPNOrderBy(RPN & rpn, ANNExpression & expr);
|
||||
|
||||
/* Returns true if we have valid limit clause in query*/
|
||||
static bool matchRPNLimit(RPN & rpn, LimitExpression & expr);
|
||||
|
||||
/* Getting settings for ann_index_param */
|
||||
void parseSettings(const ASTPtr & node);
|
||||
|
||||
|
||||
/* Matches dist function, target vector, coloumn name */
|
||||
static bool matchMainParts(RPN::iterator & iter, RPN::iterator & end, ANNExpression & expr, bool & identifier_found);
|
||||
|
||||
// Util methods
|
||||
static void panicIfWrongBuiltRPN [[noreturn]] ();
|
||||
static String getIdentifierOrPanic(RPN::iterator& iter);
|
||||
|
||||
static float getFloatOrIntLiteralOrPanic(RPN::iterator& iter);
|
||||
|
||||
|
||||
// Here we store RPN-s for different types of Queries
|
||||
RPN rpn_prewhere_clause;
|
||||
RPN rpn_where_clause;
|
||||
RPN rpn_limit_clause;
|
||||
RPN rpn_order_by_clause;
|
||||
|
||||
Block block_with_constants;
|
||||
|
||||
ANNExprOpt ann_expr{std::nullopt};
|
||||
LimitExprOpt limit_expr{std::nullopt};
|
||||
String ann_index_params; // Empty string if no params
|
||||
|
||||
|
||||
bool order_by_query_type{false};
|
||||
bool where_query_type{false};
|
||||
bool has_limit{false};
|
||||
|
||||
// true if we had extracted ANNExpression from query
|
||||
bool index_is_useful{false};
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
}
|
@ -65,6 +65,11 @@ void AnnoyIndexSerialize<Dist>::deserialize(ReadBuffer& istr)
|
||||
Base::_built = true;
|
||||
}
|
||||
|
||||
template<typename Dist>
|
||||
float AnnoyIndexSerialize<Dist>::getSpaceDim() const {
|
||||
return Base::get_f();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -168,24 +173,21 @@ MergeTreeIndexConditionAnnoy::MergeTreeIndexConditionAnnoy(
|
||||
const IndexDescription & index,
|
||||
const SelectQueryInfo & query,
|
||||
ContextPtr context)
|
||||
: index_data_types(index.data_types)
|
||||
: condition(query, context)
|
||||
{
|
||||
RPN rpn = buildRPN(query, context);
|
||||
matchRPN(rpn);
|
||||
}
|
||||
|
||||
|
||||
bool MergeTreeIndexConditionAnnoy::mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const
|
||||
{
|
||||
// TODO: Change assert to the exception
|
||||
assert(expression.has_value());
|
||||
|
||||
std::vector<float> target_vec = expression.value().target;
|
||||
float min_distance = expression.value().distance;
|
||||
|
||||
auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleAnnoy>(idx_granule);
|
||||
auto annoy = std::dynamic_pointer_cast<Annoy::AnnoyIndexSerialize<>>(granule->index_base);
|
||||
|
||||
assert(condition.getMetric() == "L2Distance");
|
||||
assert(condition.getSpaceDim() == annoy->getSpaceDim());
|
||||
std::vector<float> target_vec = condition.getTargetVec();
|
||||
float max_distance = condition.getComparisonDistance();
|
||||
|
||||
std::vector<int32_t> items;
|
||||
std::vector<float> dist;
|
||||
items.reserve(1);
|
||||
@ -193,200 +195,13 @@ bool MergeTreeIndexConditionAnnoy::mayBeTrueOnGranule(MergeTreeIndexGranulePtr i
|
||||
|
||||
// 1 - num of nearest neighbour (NN)
|
||||
// next number - upper limit on the size of the internal queue; -1 means, that it is equal to num of trees * num of NN
|
||||
annoy->get_nns_by_vector(&target_vec[0], 1, 200, &items, &dist);
|
||||
return dist[0] < min_distance;
|
||||
annoy->get_nns_by_vector(&target_vec[0], 1, -1, &items, &dist);
|
||||
return dist[0] < max_distance;
|
||||
}
|
||||
|
||||
bool MergeTreeIndexConditionAnnoy::alwaysUnknownOrTrue() const
|
||||
{
|
||||
return !expression.has_value();
|
||||
}
|
||||
|
||||
MergeTreeIndexConditionAnnoy::RPN MergeTreeIndexConditionAnnoy::buildRPN(const SelectQueryInfo & query, ContextPtr context)
|
||||
{
|
||||
RPN rpn;
|
||||
|
||||
// Get block_with_constants for the future usage from query
|
||||
block_with_constants = KeyCondition::getBlockWithConstants(query.query, query.syntax_analyzer_result, context);
|
||||
|
||||
const auto & select = query.query->as<ASTSelectQuery &>();
|
||||
|
||||
// Sometimes our ANN expression in where can be placed in prewhere section
|
||||
// In this case we populate RPN from both source, but it can be dangerous in case
|
||||
// of some additional expressions in our query
|
||||
// We can either check prewhere or where, either match independently where and
|
||||
// prewhere
|
||||
// TODO: Need to think
|
||||
if (select.where())
|
||||
{
|
||||
traverseAST(select.where(), rpn);
|
||||
}
|
||||
if (select.prewhere())
|
||||
{
|
||||
traverseAST(select.prewhere(), rpn);
|
||||
}
|
||||
|
||||
// Return prefix rpn, so reverse the result
|
||||
std::reverse(rpn.begin(), rpn.end());
|
||||
return rpn;
|
||||
}
|
||||
|
||||
void MergeTreeIndexConditionAnnoy::traverseAST(const ASTPtr & node, RPN & rpn)
|
||||
{
|
||||
RPNElement element;
|
||||
|
||||
// We need to go deeper only if we have ASTFunction in this node
|
||||
if (const auto * func = node->as<ASTFunction>())
|
||||
{
|
||||
const ASTs & args = func->arguments->children;
|
||||
|
||||
// Traverse children
|
||||
for (const auto & arg : args)
|
||||
{
|
||||
traverseAST(arg, rpn);
|
||||
}
|
||||
}
|
||||
|
||||
// Extract information about current node and populate it in the element
|
||||
if (!traverseAtomAST(node, element)) {
|
||||
// If we cannot identify our node type
|
||||
element.function = RPNElement::FUNCTION_UNKNOWN;
|
||||
}
|
||||
|
||||
rpn.emplace_back(std::move(element));
|
||||
}
|
||||
|
||||
bool MergeTreeIndexConditionAnnoy::traverseAtomAST(const ASTPtr & node, RPNElement & out) {
|
||||
// Firstly check if we have contants behind the node
|
||||
{
|
||||
Field const_value;
|
||||
DataTypePtr const_type;
|
||||
|
||||
|
||||
if (KeyCondition::getConstant(node, block_with_constants, const_value, const_type))
|
||||
{
|
||||
/// Check constant type (use Float64 because all Fields implementation contains Float64 (for Float32 too))
|
||||
if (const_value.getType() == Field::Types::Float64)
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_FLOAT_LITERAL;
|
||||
out.literal.emplace(const_value.get<Float32>());
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Match function naming with a type
|
||||
if (const auto * function = node->as<ASTFunction>())
|
||||
{
|
||||
// TODO: Add support for other metrics
|
||||
if (function->name == "L2Distance")
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_DISTANCE;
|
||||
}
|
||||
else if (function->name == "tuple")
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_TUPLE;
|
||||
}
|
||||
else if (function->name == "less")
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_LESS;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
// Match identifier
|
||||
else if (const auto * identifier = node->as<ASTIdentifier>())
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_IDENTIFIER;
|
||||
out.identifier.emplace(identifier->name());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MergeTreeIndexConditionAnnoy::matchRPN(const RPN & rpn)
|
||||
{
|
||||
// Can we place it outside the function?
|
||||
// Use for match the rpn
|
||||
// Take care of matching tuples (because it can contains arbitary number of fields)
|
||||
RPN prefix_template_rpn{
|
||||
RPNElement{RPNElement::FUNCTION_LESS},
|
||||
RPNElement{RPNElement::FUNCTION_FLOAT_LITERAL},
|
||||
RPNElement{RPNElement::FUNCTION_DISTANCE},
|
||||
RPNElement{RPNElement::FUNCTION_TUPLE},
|
||||
RPNElement{RPNElement::FUNCTION_IDENTIFIER},
|
||||
};
|
||||
|
||||
// Placeholders for the extracted data
|
||||
Target target_vec;
|
||||
float distance = 0;
|
||||
|
||||
size_t rpn_idx = 0;
|
||||
size_t template_idx = 0;
|
||||
|
||||
// TODO: Should we check what we have the same size of RPNs?
|
||||
// If we wand to support complex expressions, we will not check it
|
||||
while (rpn_idx < rpn.size() && template_idx < prefix_template_rpn.size())
|
||||
{
|
||||
const auto & element = rpn[rpn_idx];
|
||||
const auto & template_element = prefix_template_rpn[template_idx];
|
||||
|
||||
if (element.function != template_element.function)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (element.function == RPNElement::FUNCTION_FLOAT_LITERAL)
|
||||
{
|
||||
assert(element.literal.has_value());
|
||||
auto value = element.literal.value();
|
||||
|
||||
distance = value;
|
||||
}
|
||||
|
||||
if (element.function == RPNElement::FUNCTION_TUPLE)
|
||||
{
|
||||
// TODO: Better tuple extraction
|
||||
// Extract target vec
|
||||
++rpn_idx;
|
||||
while (rpn_idx < rpn.size()) {
|
||||
if (rpn[rpn_idx].function == RPNElement::FUNCTION_FLOAT_LITERAL)
|
||||
{
|
||||
// Extract tuple element
|
||||
assert(rpn[rpn_idx].literal.has_value());
|
||||
auto value = rpn[rpn_idx].literal.value();
|
||||
target_vec.push_back(value);
|
||||
++rpn_idx;
|
||||
} else {
|
||||
++template_idx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (element.function == RPNElement::FUNCTION_IDENTIFIER)
|
||||
{
|
||||
// TODO: Check that we have the same columns
|
||||
}
|
||||
|
||||
++rpn_idx;
|
||||
++template_idx;
|
||||
}
|
||||
|
||||
expression.emplace(ANNExpression{
|
||||
.target = std::move(target_vec),
|
||||
.distance = distance,
|
||||
});
|
||||
|
||||
return true;
|
||||
return condition.alwaysUnknownOrTrue();
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <Storates/MergeTree/CommonCondition.h>
|
||||
#include <Storages/MergeTree/MergeTreeIndices.h>
|
||||
#include <Storages/MergeTree/MergeTreeData.h>
|
||||
#include <Storages/MergeTree/KeyCondition.h>
|
||||
@ -24,6 +25,7 @@ namespace Annoy
|
||||
AnnoyIndexSerialize(const int dim) : Base::AnnoyIndex(dim) {}
|
||||
void serialize(WriteBuffer& ostr) const;
|
||||
void deserialize(ReadBuffer& istr);
|
||||
float gedSpaceDim() const;
|
||||
};
|
||||
}
|
||||
|
||||
@ -82,78 +84,9 @@ public:
|
||||
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const override;
|
||||
|
||||
~MergeTreeIndexConditionAnnoy() override = default;
|
||||
|
||||
private:
|
||||
// Type of the vector to use as a target in the distance function
|
||||
using Target = std::vector<float>;
|
||||
|
||||
// Extracted data from the query like WHERE L2Distance(column_name, target) < distance
|
||||
struct ANNExpression {
|
||||
Target target;
|
||||
float distance;
|
||||
};
|
||||
|
||||
using ANNExpressionOpt = std::optional<ANNExpression>;
|
||||
|
||||
// Item of the Reverse Polish notation
|
||||
struct RPNElement
|
||||
{
|
||||
enum Function
|
||||
{
|
||||
// Atoms of an ANN expression
|
||||
|
||||
// Function like L2Distance
|
||||
FUNCTION_DISTANCE,
|
||||
|
||||
// Function like tuple(...)
|
||||
FUNCTION_TUPLE,
|
||||
|
||||
// Operator <
|
||||
FUNCTION_LESS,
|
||||
|
||||
// Numeric float value
|
||||
FUNCTION_FLOAT_LITERAL,
|
||||
|
||||
// Identifier of the column, e.g. L2Distance(number, target), number is a identifier of the column
|
||||
FUNCTION_IDENTIFIER,
|
||||
|
||||
FUNCTION_UNKNOWN, /// Can take any value.
|
||||
/// Operators of the logical expression.
|
||||
FUNCTION_NOT,
|
||||
FUNCTION_AND,
|
||||
FUNCTION_OR,
|
||||
};
|
||||
|
||||
explicit RPNElement(Function function_ = FUNCTION_UNKNOWN)
|
||||
: function(function_)
|
||||
{}
|
||||
|
||||
Function function;
|
||||
|
||||
// TODO: Use not optional, but variant
|
||||
// Value for the FUNCTION_FLOAT_LITERAL
|
||||
std::optional<float> literal;
|
||||
|
||||
// Value for the FUNCTION_IDENTIDIER
|
||||
std::optional<String> identifier;
|
||||
};
|
||||
|
||||
using RPN = std::vector<RPNElement>;
|
||||
|
||||
// Build RPN of the query, return with copy ellision
|
||||
RPN buildRPN(const SelectQueryInfo & query, ContextPtr context);
|
||||
|
||||
// Util functions for the traversal of AST
|
||||
void traverseAST(const ASTPtr & node, RPN & rpn);
|
||||
// Return true if we can identify our node type
|
||||
bool traverseAtomAST(const ASTPtr & node, RPNElement & out);
|
||||
|
||||
// Check that rpn matches the template rpn (TODO: put template RPN outside this function)
|
||||
bool matchRPN(const RPN & rpn);
|
||||
|
||||
Block block_with_constants;
|
||||
|
||||
DataTypes index_data_types;
|
||||
ANNExpressionOpt expression;
|
||||
CommonCondition condition;
|
||||
};
|
||||
|
||||
|
||||
@ -179,6 +112,4 @@ public:
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,91 @@
|
||||
{
|
||||
"cmake" :
|
||||
{
|
||||
"generator" :
|
||||
{
|
||||
"name" : "Ninja"
|
||||
},
|
||||
"paths" :
|
||||
{
|
||||
"cmake" : "/usr/bin/cmake",
|
||||
"cpack" : "/usr/bin/cpack",
|
||||
"ctest" : "/usr/bin/ctest",
|
||||
"root" : "/usr/share/cmake-3.16"
|
||||
},
|
||||
"version" :
|
||||
{
|
||||
"isDirty" : false,
|
||||
"major" : 3,
|
||||
"minor" : 16,
|
||||
"patch" : 3,
|
||||
"string" : "3.16.3",
|
||||
"suffix" : ""
|
||||
}
|
||||
},
|
||||
"objects" :
|
||||
[
|
||||
{
|
||||
"jsonFile" : "codemodel-v2-bb290fe28ba2e684d61e.json",
|
||||
"kind" : "codemodel",
|
||||
"version" :
|
||||
{
|
||||
"major" : 2,
|
||||
"minor" : 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"jsonFile" : "cache-v2-efa0f9ba8e19226714e8.json",
|
||||
"kind" : "cache",
|
||||
"version" :
|
||||
{
|
||||
"major" : 2,
|
||||
"minor" : 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"jsonFile" : "cmakeFiles-v1-953c51e4923a86e8f869.json",
|
||||
"kind" : "cmakeFiles",
|
||||
"version" :
|
||||
{
|
||||
"major" : 1,
|
||||
"minor" : 0
|
||||
}
|
||||
}
|
||||
],
|
||||
"reply" :
|
||||
{
|
||||
"client-integration-vscode" :
|
||||
{
|
||||
"cache-v2" :
|
||||
{
|
||||
"jsonFile" : "cache-v2-efa0f9ba8e19226714e8.json",
|
||||
"kind" : "cache",
|
||||
"version" :
|
||||
{
|
||||
"major" : 2,
|
||||
"minor" : 0
|
||||
}
|
||||
},
|
||||
"cmakeFiles-v1" :
|
||||
{
|
||||
"jsonFile" : "cmakeFiles-v1-953c51e4923a86e8f869.json",
|
||||
"kind" : "cmakeFiles",
|
||||
"version" :
|
||||
{
|
||||
"major" : 1,
|
||||
"minor" : 0
|
||||
}
|
||||
},
|
||||
"codemodel-v2" :
|
||||
{
|
||||
"jsonFile" : "codemodel-v2-bb290fe28ba2e684d61e.json",
|
||||
"kind" : "codemodel",
|
||||
"version" :
|
||||
{
|
||||
"major" : 2,
|
||||
"minor" : 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user