ClickHouse/src/Common/FST.cpp

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

487 lines
13 KiB
C++
Raw Normal View History

#include "FST.h"
#include <algorithm>
2022-06-24 01:56:15 +00:00
#include <cassert>
#include <iostream>
#include <memory>
#include <vector>
#include <Common/Exception.h>
2023-01-10 16:26:27 +00:00
#include <city.h>
2022-06-24 01:56:15 +00:00
2023-01-11 21:40:20 +00:00
/// "paper" in the comments in this file refers to:
/// [Direct Construction of Minimal Acyclic Subsequential Transducers] by Stoyan Mihov and Denis Maurel, University of Tours, France
2022-06-24 01:56:15 +00:00
namespace DB
{
2023-01-20 09:32:36 +00:00
2022-06-24 01:56:15 +00:00
/// Error codes referenced in this file. They are only declared here (`extern`);
/// the definitions are provided by another translation unit.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
}
namespace FST
{
2023-01-20 09:32:36 +00:00
/// Builds an arc that carries `output_` and leads to the state `target_`.
Arc::Arc(Output output_, const StatePtr & target_) : output(output_), target(target_)
{
}
/// Writes this arc into `write_buffer` and returns the number of bytes written.
///
/// Layout:
///   VarUInt: (target->state_index << 1) + has_output
///   VarUInt: output   -- present only when output is non-zero
UInt64 Arc::serialize(WriteBuffer & write_buffer) const
{
    assert(target != nullptr);

    const bool carries_output = (output != 0);

    /// The lowest bit of the header tells the reader whether an output value follows.
    const UInt64 header = ((target->state_index) << 1) + carries_output;
    writeVarUInt(header, write_buffer);
    UInt64 total_bytes = getLengthOfVarUInt(header);

    if (!carries_output)
        return total_bytes;

    writeVarUInt(output, write_buffer);
    total_bytes += getLengthOfVarUInt(output);
    return total_bytes;
}
bool operator==(const Arc & arc1, const Arc & arc2)
2022-06-24 01:56:15 +00:00
{
2023-01-10 16:26:27 +00:00
assert(arc1.target != nullptr && arc2.target != nullptr);
return (arc1.output == arc2.output && arc1.target->id == arc2.target->id);
}
2023-01-10 16:26:27 +00:00
void LabelsAsBitmap::addLabel(char label)
{
UInt8 index = label;
2022-06-24 01:56:15 +00:00
UInt256 bit_label = 1;
bit_label <<= index;
data |= bit_label;
}
2023-01-10 16:26:27 +00:00
/// Returns the number of bits set in `data` at positions 0..label (inclusive),
/// where the label is reinterpreted as an unsigned byte.
UInt64 LabelsAsBitmap::getIndex(char label) const
{
    const UInt8 position = label;
    const int last_word = position / 64;
    const UInt8 bit_in_word = position % 64;

    UInt64 rank = 0;

    /// All 64-bit words that lie entirely below the label's position are counted in full.
    for (int word = 0; word < last_word; ++word)
        rank += std::popcount(data.items[word]);

    /// In the word containing the label, count only the bits up to and including the label's bit.
    /// Bit 63 is special-cased because shifting a 64-bit value by 64 is not allowed.
    const UInt64 mask = (bit_in_word == 63) ? (-1) : (1ULL << (bit_in_word + 1)) - 1;
    rank += std::popcount(mask & data.items[last_word]);

    return rank;
}
2023-01-20 09:32:36 +00:00
/// Writes the four 64-bit items of the bitmap as VarUInts (item 0 first) into `write_buffer`
/// and returns the total number of bytes written.
UInt64 LabelsAsBitmap::serialize(WriteBuffer & write_buffer)
{
    UInt64 written_bytes = 0;

    for (size_t item = 0; item < 4; ++item)
    {
        writeVarUInt(data.items[item], write_buffer);
        written_bytes += getLengthOfVarUInt(data.items[item]);
    }

    return written_bytes;
}
bool LabelsAsBitmap::hasLabel(char label) const
2022-06-24 01:56:15 +00:00
{
UInt8 index = label;
2022-06-24 01:56:15 +00:00
UInt256 bit_label = 1;
bit_label <<= index;
return ((data & bit_label) != 0);
}
2023-01-20 09:32:36 +00:00
/// Returns a pointer to the arc stored under `label`, or nullptr if this state has no such arc.
/// The returned pointer is non-const even though the method is const.
Arc * State::getArc(char label) const
{
    if (auto found = arcs.find(label); found != arcs.end())
        return const_cast<Arc *>(&found->second);

    return nullptr;
}
void State::addArc(char label, Output output, StatePtr target)
{
arcs[label] = Arc(output, target);
}
2023-01-10 16:26:27 +00:00
void State::clear()
{
id = 0;
state_index = 0;
arcs.clear();
2023-01-20 09:32:36 +00:00
flag = 0;
2023-01-10 16:26:27 +00:00
}
2022-06-24 01:56:15 +00:00
/// Computes a CityHash64 over the state's arcs.
/// For every arc (in the iteration order of `arcs`) the hashed bytes are:
/// the label, the raw bytes of the arc output, and the raw bytes of the target's `id`.
UInt64 State::hash() const
{
    constexpr size_t bytes_per_arc = sizeof(Output) + sizeof(UInt64) + 1;

    std::vector<char> buffer;
    buffer.reserve(arcs.size() * bytes_per_arc);

    for (const auto & [label, arc] : arcs)
    {
        buffer.push_back(label);

        const auto * output_bytes = reinterpret_cast<const char *>(&arc.output);
        buffer.insert(buffer.end(), output_bytes, output_bytes + sizeof(Output));

        const auto * id_bytes = reinterpret_cast<const char *>(&arc.target->id);
        buffer.insert(buffer.end(), id_bytes, id_bytes + sizeof(UInt64));
    }

    return CityHash_v1_0_2::CityHash64(buffer.data(), buffer.size());
}
2023-01-20 09:32:36 +00:00
bool operator==(const State & state1, const State & state2)
2022-06-24 01:56:15 +00:00
{
if (state1.arcs.size() != state2.arcs.size())
return false;
for (const auto & [label, arc] : state1.arcs)
2022-06-24 01:56:15 +00:00
{
const auto it = state2.arcs.find(label);
2023-01-20 09:32:36 +00:00
if (it == state2.arcs.end())
2022-06-24 01:56:15 +00:00
return false;
if (it->second != arc)
2022-06-24 01:56:15 +00:00
return false;
}
return true;
}
2023-01-20 09:32:36 +00:00
/// Writes this state into `write_buffer` and returns the number of bytes written.
///
/// Layout, sequential encoding:
///   1 byte   flag
///   1 byte   number of labels
///   N bytes  labels
///   N arcs   in the same order as the labels
///
/// Layout, bitmap encoding:
///   1 byte      flag
///   4 VarUInts  label bitmap
///   N arcs      in ascending order of the label taken as an unsigned byte
///
/// The arc order of the bitmap encoding matters: the reader finds an arc by counting the set bits
/// up to the label (LabelsAsBitmap::getIndex), so the k-th serialized arc must belong to the k-th set bit.
/// The iteration order of `arcs` is not relied upon for that; labels are sorted explicitly.
UInt64 State::serialize(WriteBuffer & write_buffer)
{
    UInt64 written_bytes = 0;

    /// Serialize flag
    write_buffer.write(flag);
    written_bytes += 1;

    if (getEncodingMethod() == EncodingMethod::Sequential)
    {
        /// Serialize all labels
        std::vector<char> labels;
        labels.reserve(arcs.size());

        for (const auto & [label, arc] : arcs)
            labels.push_back(label);

        UInt8 label_size = labels.size();
        write_buffer.write(label_size);
        written_bytes += 1;

        write_buffer.write(labels.data(), labels.size());
        written_bytes += labels.size();

        /// Serialize all arcs, in the order in which their labels were written
        for (char label : labels)
        {
            Arc * arc = getArc(label);
            assert(arc != nullptr);
            written_bytes += arc->serialize(write_buffer);
        }
    }
    else
    {
        /// Serialize bitmap
        LabelsAsBitmap bmp;
        std::vector<UInt8> sorted_labels;
        sorted_labels.reserve(arcs.size());

        for (const auto & [label, arc] : arcs)
        {
            bmp.addLabel(label);
            sorted_labels.push_back(static_cast<UInt8>(label));
        }
        written_bytes += bmp.serialize(write_buffer);

        /// Serialize all arcs in bit order of the bitmap (ascending unsigned label),
        /// which is the order the reader assumes.
        std::sort(sorted_labels.begin(), sorted_labels.end());
        for (UInt8 label : sorted_labels)
        {
            Arc * arc = getArc(static_cast<char>(label));
            assert(arc != nullptr);
            written_bytes += arc->serialize(write_buffer);
        }
    }

    return written_bytes;
}
2023-01-20 09:32:36 +00:00
/// Reads one byte from `read_buffer` directly into `flag`.
void State::readFlag(ReadBuffer & read_buffer)
{
    char & flag_byte = reinterpret_cast<char &>(flag);
    read_buffer.readStrict(flag_byte);
}
/// Remembers the output buffer and pre-allocates a fresh State for every slot of `temp_states`.
FstBuilder::FstBuilder(WriteBuffer & write_buffer_) : write_buffer(write_buffer_)
{
    std::generate(std::begin(temp_states), std::end(temp_states), [] { return std::make_shared<State>(); });
}
2023-01-11 21:40:20 +00:00
/// See FindMinimized in the paper pseudo code l11-l21.
2023-01-20 09:32:36 +00:00
StatePtr FstBuilder::findMinimized(const State & state, bool & found)
2022-06-24 01:56:15 +00:00
{
found = false;
auto hash = state.hash();
2023-01-11 21:40:20 +00:00
/// MEMBER: in the paper pseudo code l15
auto it = minimized_states.find(hash);
2022-06-24 01:56:15 +00:00
2023-01-20 09:32:36 +00:00
if (it != minimized_states.end() && *it->second == state)
2022-06-24 01:56:15 +00:00
{
found = true;
return it->second;
}
2023-01-11 21:40:20 +00:00
/// COPY_STATE: in the paper pseudo code l17
StatePtr p = std::make_shared<State>(state);
2023-01-11 21:40:20 +00:00
/// INSERT: in the paper pseudo code l18
2022-06-24 01:56:15 +00:00
minimized_states[hash] = p;
return p;
}
2023-01-20 09:32:36 +00:00
namespace
{

/// See the paper pseudo code l33-34.
/// Returns the number of leading characters that `word1` and `word2` have in common.
size_t getCommonPrefixLength(std::string_view word1, std::string_view word2)
{
    const auto first_difference = std::mismatch(word1.begin(), word1.end(), word2.begin(), word2.end()).first;
    return static_cast<size_t>(first_difference - word1.begin());
}

}
2023-01-11 21:40:20 +00:00
/// See the paper pseudo code l33-39 and l70-72(when down_to is 0).
2023-01-20 09:32:36 +00:00
void FstBuilder::minimizePreviousWordSuffix(Int64 down_to)
2022-06-24 01:56:15 +00:00
{
for (Int64 i = static_cast<Int64>(previous_word.size()); i >= down_to; --i)
2022-06-24 01:56:15 +00:00
{
2023-01-10 16:26:27 +00:00
bool found = false;
auto minimized_state = findMinimized(*temp_states[i], found);
2022-06-24 01:56:15 +00:00
if (i != 0)
{
Output output = 0;
2023-01-20 09:32:36 +00:00
Arc * arc = temp_states[i - 1]->getArc(previous_word[i - 1]);
2022-06-24 01:56:15 +00:00
if (arc)
output = arc->output;
2023-01-11 21:40:20 +00:00
/// SET_TRANSITION
2022-06-24 01:56:15 +00:00
temp_states[i - 1]->addArc(previous_word[i - 1], output, minimized_state);
}
if (minimized_state->id == 0)
minimized_state->id = next_id++;
if (i > 0 && temp_states[i - 1]->id == 0)
temp_states[i - 1]->id = next_id++;
if (!found)
2022-07-03 12:18:51 +00:00
{
2022-06-24 01:56:15 +00:00
minimized_state->state_index = previous_state_index;
previous_written_bytes = minimized_state->serialize(write_buffer);
previous_state_index += previous_written_bytes;
}
}
}
2023-01-20 09:32:36 +00:00
void FstBuilder::add(std::string_view current_word, Output current_output)
2022-06-24 01:56:15 +00:00
{
/// We assume word size is no greater than MAX_TERM_LENGTH(256).
/// FSTs without word size limitation would be inefficient and easy to cause memory bloat
/// Note that when using "split" tokenizer, if a granule has tokens which are longer than
/// MAX_TERM_LENGTH, the granule cannot be dropped and will be fully-scanned. It doesn't affect "ngram" tokenizers.
/// Another limitation is that if the query string has tokens which exceed this length
/// it will fallback to default searching when using "split" tokenizers.
2023-01-20 09:32:36 +00:00
size_t current_word_len = current_word.size();
2022-06-24 01:56:15 +00:00
if (current_word_len > MAX_TERM_LENGTH)
2023-01-20 09:32:36 +00:00
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Cannot build inverted index: The maximum term length is {}, this is exceeded by term {}", MAX_TERM_LENGTH, current_word_len);
2022-06-24 01:56:15 +00:00
2023-01-10 16:26:27 +00:00
size_t prefix_length_plus1 = getCommonPrefixLength(current_word, previous_word) + 1;
2022-06-24 01:56:15 +00:00
minimizePreviousWordSuffix(prefix_length_plus1);
2023-01-11 21:40:20 +00:00
/// Initialize the tail state, see paper pseudo code l39-43
2023-01-10 16:26:27 +00:00
for (size_t i = prefix_length_plus1; i <= current_word.size(); ++i)
2022-06-24 01:56:15 +00:00
{
2023-01-11 21:40:20 +00:00
/// CLEAR_STATE: l41
2022-06-24 01:56:15 +00:00
temp_states[i]->clear();
2023-01-11 21:40:20 +00:00
/// SET_TRANSITION: l42
temp_states[i - 1]->addArc(current_word[i - 1], 0, temp_states[i]);
2022-06-24 01:56:15 +00:00
}
/// We assume the current word is different with previous word
2023-01-11 21:40:20 +00:00
/// See paper pseudo code l44-47
temp_states[current_word_len]->setFinal(true);
2023-01-11 21:40:20 +00:00
2022-06-24 01:56:15 +00:00
/// Adjust outputs on the arcs
2023-01-11 21:40:20 +00:00
/// See paper pseudo code l48-63
2023-01-10 16:26:27 +00:00
for (size_t i = 1; i <= prefix_length_plus1 - 1; ++i)
2022-06-24 01:56:15 +00:00
{
2023-01-10 16:26:27 +00:00
Arc * arc_ptr = temp_states[i - 1]->getArc(current_word[i - 1]);
2022-06-24 01:56:15 +00:00
assert(arc_ptr != nullptr);
Output common_prefix = std::min(arc_ptr->output, current_output);
Output word_suffix = arc_ptr->output - common_prefix;
2022-06-24 01:56:15 +00:00
arc_ptr->output = common_prefix;
/// For each arc, adjust its output
if (word_suffix != 0)
{
2023-01-10 16:26:27 +00:00
for (auto & [label, arc] : temp_states[i]->arcs)
2022-06-24 01:56:15 +00:00
arc.output += word_suffix;
}
/// Reduce current_output
current_output -= common_prefix;
}
/// Set last temp state's output
2023-01-11 21:40:20 +00:00
/// paper pseudo code l66-67 (assuming CurrentWord != PreviousWorld)
Arc * arc = temp_states[prefix_length_plus1 - 1]->getArc(current_word[prefix_length_plus1 - 1]);
2022-06-24 01:56:15 +00:00
assert(arc != nullptr);
arc->output = current_output;
previous_word = current_word;
}
2023-01-20 09:32:36 +00:00
/// Finishes the FST: minimizes and writes all remaining states of the previous word (down to index 0),
/// then appends the index of the most recently serialized state as a VarUInt followed by one byte
/// holding the length of that VarUInt. Returns the sum of the state bytes, that length and 1.
UInt64 FstBuilder::build()
{
    minimizePreviousWordSuffix(0);

    /// Save initial state index.
    /// previous_state_index was advanced past the last serialized state; step back to where that state starts.
    previous_state_index -= previous_written_bytes;

    const UInt8 index_length = getLengthOfVarUInt(previous_state_index);
    writeVarUInt(previous_state_index, write_buffer);
    write_buffer.write(index_length);

    const UInt64 states_size = previous_state_index + previous_written_bytes;
    return states_size + index_length + 1;
}
2023-01-20 09:32:36 +00:00
/// Takes ownership of the serialized FST bytes `data_`.
FiniteStateTransducer::FiniteStateTransducer(std::vector<UInt8> data_) : data(std::move(data_))
{
}
void FiniteStateTransducer::clear()
2022-06-24 01:56:15 +00:00
{
data.clear();
}
2023-01-20 09:32:36 +00:00
std::pair<UInt64, bool> FiniteStateTransducer::getOutput(std::string_view term)
2022-06-24 01:56:15 +00:00
{
2023-01-20 09:32:36 +00:00
std::pair<UInt64, bool> result(0, false);
2023-01-10 16:26:27 +00:00
2022-06-24 01:56:15 +00:00
/// Read index of initial state
ReadBufferFromMemory read_buffer(data.data(), data.size());
2023-01-20 09:32:36 +00:00
read_buffer.seek(data.size() - 1, SEEK_SET);
2022-06-24 01:56:15 +00:00
2023-01-20 09:32:36 +00:00
UInt8 length = 0;
read_buffer.readStrict(reinterpret_cast<char &>(length));
2023-01-10 16:26:27 +00:00
/// FST contains no terms
if (length == 0)
2023-01-20 09:32:36 +00:00
return {0, false};
2022-06-24 01:56:15 +00:00
read_buffer.seek(data.size() - 1 - length, SEEK_SET);
2023-01-20 09:32:36 +00:00
UInt64 state_index = 0;
2022-06-24 01:56:15 +00:00
readVarUInt(state_index, read_buffer);
for (size_t i = 0; i <= term.size(); ++i)
{
2023-01-20 09:32:36 +00:00
UInt64 arc_output = 0;
2022-06-24 01:56:15 +00:00
/// Read flag
State temp_state;
read_buffer.seek(state_index, SEEK_SET);
temp_state.readFlag(read_buffer);
2022-06-24 01:56:15 +00:00
if (i == term.size())
{
2023-01-10 16:26:27 +00:00
result.second = temp_state.isFinal();
2022-06-24 01:56:15 +00:00
break;
}
UInt8 label = term[i];
2023-01-10 16:26:27 +00:00
if (temp_state.getEncodingMethod() == State::EncodingMethod::Sequential)
2022-06-24 01:56:15 +00:00
{
/// Read number of labels
2023-01-20 09:32:36 +00:00
UInt8 label_num = 0;
read_buffer.readStrict(reinterpret_cast<char &>(label_num));
2022-06-24 01:56:15 +00:00
2022-07-03 12:18:51 +00:00
if (label_num == 0)
2023-01-20 09:32:36 +00:00
return {0, false};
2022-06-24 01:56:15 +00:00
auto labels_position = read_buffer.getPosition();
/// Find the index of the label from "labels" bytes
2023-01-20 09:32:36 +00:00
auto begin_it = data.begin() + labels_position;
auto end_it = data.begin() + labels_position + label_num;
2022-06-24 01:56:15 +00:00
auto pos = std::find(begin_it, end_it, label);
if (pos == end_it)
2023-01-20 09:32:36 +00:00
return {0, false};
2022-06-24 01:56:15 +00:00
/// Read the arc for the label
UInt64 arc_index = (pos - begin_it);
auto arcs_start_postion = labels_position + label_num;
read_buffer.seek(arcs_start_postion, SEEK_SET);
for (size_t j = 0; j <= arc_index; j++)
{
state_index = 0;
arc_output = 0;
readVarUInt(state_index, read_buffer);
if (state_index & 0x1) // output is followed
readVarUInt(arc_output, read_buffer);
state_index >>= 1;
}
}
else
{
2023-01-10 16:26:27 +00:00
LabelsAsBitmap bmp;
2022-06-24 01:56:15 +00:00
readVarUInt(bmp.data.items[0], read_buffer);
readVarUInt(bmp.data.items[1], read_buffer);
readVarUInt(bmp.data.items[2], read_buffer);
readVarUInt(bmp.data.items[3], read_buffer);
2023-01-10 16:26:27 +00:00
if (!bmp.hasLabel(label))
2023-01-20 09:32:36 +00:00
return {0, false};
2022-06-24 01:56:15 +00:00
/// Read the arc for the label
size_t arc_index = bmp.getIndex(label);
for (size_t j = 0; j < arc_index; j++)
{
state_index = 0;
arc_output = 0;
readVarUInt(state_index, read_buffer);
if (state_index & 0x1) // output is followed
readVarUInt(arc_output, read_buffer);
state_index >>= 1;
}
}
/// Accumulate the output value
2023-01-10 16:26:27 +00:00
result.first += arc_output;
2022-06-24 01:56:15 +00:00
}
2023-01-10 16:26:27 +00:00
return result;
2022-06-24 01:56:15 +00:00
}
2023-01-20 09:32:36 +00:00
2022-06-24 01:56:15 +00:00
}
2023-01-20 09:32:36 +00:00
}