2022-04-01 18:47:48 +00:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include <Storages/MergeTree/MergeTreeIndices.h>
|
|
|
|
#include <Storages/MergeTree/MergeTreeData.h>
|
|
|
|
#include <Storages/MergeTree/KeyCondition.h>
|
|
|
|
|
|
|
|
#include <annoylib.h>
|
|
|
|
#include <kissrandom.h>
|
|
|
|
|
|
|
|
|
|
|
|
namespace DB
|
|
|
|
{
|
|
|
|
|
|
|
|
|
|
|
|
namespace Annoy
|
|
|
|
{
|
|
|
|
// Build policy handed to the Annoy index template as its last template argument.
// Note: despite the "Threaded" in the name, this alias currently resolves to
// Annoy's single-threaded build policy.
using AnnoyIndexThreadedBuildPolicy = ::Annoy::AnnoyIndexSingleThreadedBuildPolicy;
|
|
|
|
template <typename Dist = ::Annoy::Euclidean>
|
|
|
|
class AnnoyIndexSerialize : public ::Annoy::AnnoyIndex<Int32, Float32, Dist, ::Annoy::Kiss64Random, AnnoyIndexThreadedBuildPolicy>
|
|
|
|
{
|
|
|
|
using Base = ::Annoy::AnnoyIndex<Int32, Float32, Dist, ::Annoy::Kiss64Random, AnnoyIndexThreadedBuildPolicy>;
|
|
|
|
public:
|
|
|
|
AnnoyIndexSerialize() = delete;
|
|
|
|
AnnoyIndexSerialize(const int dim) : Base::AnnoyIndex(dim) {}
|
|
|
|
void serialize(WriteBuffer& ostr) const;
|
|
|
|
void deserialize(ReadBuffer& istr);
|
|
|
|
};
|
|
|
|
}
|
|
|
|
|
|
|
|
struct MergeTreeIndexGranuleAnnoy final : public IMergeTreeIndexGranule
|
|
|
|
{
|
|
|
|
using AnnoyIndex = Annoy::AnnoyIndexSerialize<>;
|
|
|
|
using AnnoyIndexPtr = std::shared_ptr<AnnoyIndex>;
|
|
|
|
|
|
|
|
MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_);
|
|
|
|
MergeTreeIndexGranuleAnnoy(
|
|
|
|
const String & index_name_,
|
|
|
|
const Block & index_sample_block_,
|
|
|
|
AnnoyIndexPtr index_base_);
|
|
|
|
|
|
|
|
~MergeTreeIndexGranuleAnnoy() override = default;
|
|
|
|
|
|
|
|
void serializeBinary(WriteBuffer & ostr) const override;
|
|
|
|
void deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion version) override;
|
|
|
|
|
|
|
|
bool empty() const override;
|
|
|
|
|
|
|
|
String index_name;
|
|
|
|
Block index_sample_block;
|
|
|
|
AnnoyIndexPtr index_base;
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
struct MergeTreeIndexAggregatorAnnoy final : IMergeTreeIndexAggregator
|
|
|
|
{
|
|
|
|
using AnnoyIndex = Annoy::AnnoyIndexSerialize<>;
|
|
|
|
using AnnoyIndexPtr = std::shared_ptr<AnnoyIndex>;
|
|
|
|
|
|
|
|
MergeTreeIndexAggregatorAnnoy(const String & index_name_, const Block & index_sample_block);
|
|
|
|
~MergeTreeIndexAggregatorAnnoy() override = default;
|
|
|
|
|
|
|
|
bool empty() const override;
|
|
|
|
MergeTreeIndexGranulePtr getGranuleAndReset() override;
|
|
|
|
void update(const Block & block, size_t * pos, size_t limit) override;
|
|
|
|
|
|
|
|
String index_name;
|
|
|
|
Block index_sample_block;
|
|
|
|
AnnoyIndexPtr index_base;
|
|
|
|
};
|
|
|
|
|
|
|
|
|
2022-04-15 15:45:59 +00:00
|
|
|
class MergeTreeIndexConditionAnnoy final : public IMergeTreeIndexCondition
|
2022-04-01 18:47:48 +00:00
|
|
|
{
|
|
|
|
public:
|
2022-04-15 12:08:54 +00:00
|
|
|
MergeTreeIndexConditionAnnoy(
|
|
|
|
const IndexDescription & index,
|
|
|
|
const SelectQueryInfo & query,
|
|
|
|
ContextPtr context);
|
2022-04-01 18:47:48 +00:00
|
|
|
|
2022-04-15 12:08:54 +00:00
|
|
|
bool alwaysUnknownOrTrue() const override;
|
2022-04-01 18:47:48 +00:00
|
|
|
|
|
|
|
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const override;
|
|
|
|
|
|
|
|
~MergeTreeIndexConditionAnnoy() override = default;
|
|
|
|
private:
|
2022-04-15 12:08:54 +00:00
|
|
|
// Type of the vector to use as a target in the distance function
|
|
|
|
using Target = std::vector<float>;
|
|
|
|
|
|
|
|
// Extracted data from the query like WHERE L2Distance(column_name, target) < distance
|
|
|
|
struct ANNExpression {
|
|
|
|
Target target;
|
|
|
|
float distance;
|
|
|
|
};
|
|
|
|
|
|
|
|
using ANNExpressionOpt = std::optional<ANNExpression>;
|
|
|
|
|
|
|
|
// Item of the Reverse Polish notation
|
|
|
|
struct RPNElement
|
|
|
|
{
|
|
|
|
enum Function
|
|
|
|
{
|
|
|
|
// Atoms of an ANN expression
|
|
|
|
|
|
|
|
// Function like L2Distance
|
|
|
|
FUNCTION_DISTANCE,
|
|
|
|
|
|
|
|
// Function like tuple(...)
|
|
|
|
FUNCTION_TUPLE,
|
|
|
|
|
|
|
|
// Operator <
|
|
|
|
FUNCTION_LESS,
|
|
|
|
|
|
|
|
// Numeric float value
|
|
|
|
FUNCTION_FLOAT_LITERAL,
|
|
|
|
|
|
|
|
// Identifier of the column, e.g. L2Distance(number, target), number is a identifier of the column
|
|
|
|
FUNCTION_IDENTIFIER,
|
|
|
|
|
|
|
|
FUNCTION_UNKNOWN, /// Can take any value.
|
|
|
|
/// Operators of the logical expression.
|
|
|
|
FUNCTION_NOT,
|
|
|
|
FUNCTION_AND,
|
|
|
|
FUNCTION_OR,
|
|
|
|
};
|
|
|
|
|
|
|
|
explicit RPNElement(Function function_ = FUNCTION_UNKNOWN)
|
|
|
|
: function(function_)
|
|
|
|
{}
|
|
|
|
|
|
|
|
Function function;
|
|
|
|
|
|
|
|
// TODO: Use not optional, but variant
|
|
|
|
// Value for the FUNCTION_FLOAT_LITERAL
|
|
|
|
std::optional<float> literal;
|
|
|
|
|
|
|
|
// Value for the FUNCTION_IDENTIDIER
|
|
|
|
std::optional<String> identifier;
|
|
|
|
};
|
|
|
|
|
|
|
|
using RPN = std::vector<RPNElement>;
|
|
|
|
|
|
|
|
// Build RPN of the query, return with copy ellision
|
|
|
|
RPN buildRPN(const SelectQueryInfo & query, ContextPtr context);
|
|
|
|
|
|
|
|
// Util functions for the traversal of AST
|
|
|
|
void traverseAST(const ASTPtr & node, RPN & rpn);
|
|
|
|
// Return true if we can identify our node type
|
|
|
|
bool traverseAtomAST(const ASTPtr & node, RPNElement & out);
|
|
|
|
|
|
|
|
// Check that rpn matches the template rpn (TODO: put template RPN outside this function)
|
|
|
|
bool matchRPN(const RPN & rpn);
|
|
|
|
|
|
|
|
Block block_with_constants;
|
|
|
|
|
|
|
|
DataTypes index_data_types;
|
|
|
|
ANNExpressionOpt expression;
|
2022-04-01 18:47:48 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class MergeTreeIndexAnnoy : public IMergeTreeIndex
|
|
|
|
{
|
|
|
|
public:
|
|
|
|
explicit MergeTreeIndexAnnoy(const IndexDescription & index_)
|
|
|
|
: IMergeTreeIndex(index_)
|
|
|
|
{}
|
|
|
|
|
|
|
|
~MergeTreeIndexAnnoy() override = default;
|
|
|
|
|
|
|
|
MergeTreeIndexGranulePtr createIndexGranule() const override;
|
|
|
|
MergeTreeIndexAggregatorPtr createIndexAggregator() const override;
|
|
|
|
|
|
|
|
MergeTreeIndexConditionPtr createIndexCondition(
|
|
|
|
const SelectQueryInfo & query, ContextPtr context) const override;
|
|
|
|
|
|
|
|
bool mayBenefitFromIndexForIn(const ASTPtr & /*node*/) const override { return true; }
|
|
|
|
|
|
|
|
const char* getSerializedFileExtension() const override { return ".idx2"; }
|
|
|
|
MergeTreeIndexFormat getDeserializedFormat(const DiskPtr disk, const std::string & path_prefix) const override;
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
}
|