2022-09-07 18:22:09 +00:00
# include "FST.h"
# include <algorithm>
2022-06-24 01:56:15 +00:00
# include <cassert>
# include <memory>
# include <vector>
# include <Common/Exception.h>
2023-01-10 16:26:27 +00:00
# include <city.h>
2022-06-24 01:56:15 +00:00
2023-01-11 21:40:20 +00:00
/// "paper" in the comments in this file refers to:
/// [Direct Construction of Minimal Acyclic Subsequential Transducers] by Stoyan Mihov and Denis Maurel, University of Tours, France
2022-06-24 01:56:15 +00:00
namespace DB
{
2023-01-20 09:32:36 +00:00
2022-06-24 01:56:15 +00:00
/// Error codes defined elsewhere in the project that are thrown from this file.
/// (A namespace body needs no trailing ';' — the stray one is an empty declaration flagged by -Wextra-semi.)
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
}
2022-09-07 18:22:09 +00:00
namespace FST
{
2023-01-20 09:32:36 +00:00
Arc : : Arc ( Output output_ , const StatePtr & target_ )
: output ( output_ )
, target ( target_ )
{ }
/// Write this arc to `write_buffer` and return the number of bytes written.
///
/// Encoding: one VarUInt holding (target->state_index << 1) + has_output, where has_output is 1 when the
/// output is non-zero. When has_output is 1, the output follows as a second VarUInt.
UInt64 Arc::serialize(WriteBuffer & write_buffer) const
{
    assert(target != nullptr);

    const bool has_output = (output != 0);
    const UInt64 header = (target->state_index << 1) + has_output;

    writeVarUInt(header, write_buffer);
    const UInt64 header_bytes = getLengthOfVarUInt(header);

    if (!has_output)
        return header_bytes;

    writeVarUInt(output, write_buffer);
    return header_bytes + getLengthOfVarUInt(output);
}
2022-09-25 23:29:30 +00:00
bool operator = = ( const Arc & arc1 , const Arc & arc2 )
2022-06-24 01:56:15 +00:00
{
2023-01-10 16:26:27 +00:00
assert ( arc1 . target ! = nullptr & & arc2 . target ! = nullptr ) ;
2022-09-25 23:29:30 +00:00
return ( arc1 . output = = arc2 . output & & arc1 . target - > id = = arc2 . target - > id ) ;
}
2023-01-10 16:26:27 +00:00
void LabelsAsBitmap : : addLabel ( char label )
2022-09-25 23:29:30 +00:00
{
UInt8 index = label ;
2022-06-24 01:56:15 +00:00
UInt256 bit_label = 1 ;
bit_label < < = index ;
data | = bit_label ;
}
2023-01-20 10:56:20 +00:00
bool LabelsAsBitmap : : hasLabel ( char label ) const
{
UInt8 index = label ;
UInt256 bit_label = 1 ;
bit_label < < = index ;
return ( ( data & bit_label ) ! = 0 ) ;
}
2023-01-10 16:26:27 +00:00
/// Return how many bits are set in the bitmap at positions 0..label, with the label read as an unsigned byte.
/// For a label that is present, this is its 1-based rank among all present labels.
/// FiniteStateTransducer::getOutput() uses that rank to skip to the label's arc.
UInt64 LabelsAsBitmap::getIndex(char label) const
{
    const UInt8 bit_position = static_cast<UInt8>(label);
    const size_t word = bit_position / 64;
    const UInt8 offset = bit_position % 64;

    UInt64 rank = 0;

    /// Every 64-bit word below the label's word counts in full.
    for (size_t w = 0; w < word; ++w)
        rank += std::popcount(data.items[w]);

    /// In the label's own word, count bits 0..offset inclusive.
    const UInt64 mask = (offset == 63) ? ~0ULL : (1ULL << (offset + 1)) - 1;
    rank += std::popcount(mask & data.items[word]);

    return rank;
}
2023-01-20 09:32:36 +00:00
/// Write the bitmap as four VarUInts (data.items[0] through data.items[3], in that order).
/// Returns the number of bytes written.
UInt64 LabelsAsBitmap::serialize(WriteBuffer & write_buffer)
{
    constexpr size_t word_count = 4;

    UInt64 written_bytes = 0;
    for (size_t i = 0; i < word_count; ++i)
    {
        writeVarUInt(data.items[i], write_buffer);
        written_bytes += getLengthOfVarUInt(data.items[i]);
    }
    return written_bytes;
}
2023-01-20 10:56:20 +00:00
/// Hash used by FstBuilder::findMinimized to look up already-minimized states.
/// The hash covers the state's finality and, for every outgoing arc, the label, the arc output and the id
/// of the target state. States that may be merged (same finality, same arcs) hash equally.
UInt64 State::hash() const
{
    std::vector<char> values;
    values.reserve(1 + arcs.size() * (sizeof(Output) + sizeof(UInt64) + 1));

    /// A final state and a non-final state with identical arcs are not equivalent. Hash the finality so that
    /// such states do not always collide.
    values.push_back(isFinal() ? 1 : 0);

    for (const auto & [label, arc] : arcs)
    {
        values.push_back(label);

        const auto * ptr = reinterpret_cast<const char *>(&arc.output);
        std::copy(ptr, ptr + sizeof(Output), std::back_inserter(values));

        ptr = reinterpret_cast<const char *>(&arc.target->id);
        std::copy(ptr, ptr + sizeof(UInt64), std::back_inserter(values));
    }

    return CityHash_v1_0_2::CityHash64(values.data(), values.size());
}
2023-01-20 09:32:36 +00:00
/// Return the arc leaving this state with `label`, or nullptr if there is none.
/// The returned arc is mutable even though the method is const. FstBuilder adjusts arc outputs in place
/// through it.
Arc * State::getArc(char label) const
{
    const auto found = arcs.find(label);
    return found != arcs.end() ? const_cast<Arc *>(&found->second) : nullptr;
}
void State : : addArc ( char label , Output output , StatePtr target )
{
arcs [ label ] = Arc ( output , target ) ;
}
2023-01-10 16:26:27 +00:00
void State : : clear ( )
{
id = 0 ;
state_index = 0 ;
arcs . clear ( ) ;
2023-01-20 09:32:36 +00:00
flag = 0 ;
2023-01-10 16:26:27 +00:00
}
2023-01-20 09:32:36 +00:00
/// Serialize the state to `write_buffer` and return the number of bytes written.
///
/// Layout: the flag byte, then one of two encodings depending on getEncodingMethod():
///   - Sequential: one byte with the number of labels, then the labels (one byte each), then one arc
///     (see Arc::serialize) per label, in the same order as the labels;
///   - otherwise (bitmap): the 256-bit label bitmap (see LabelsAsBitmap::serialize), then the arcs
///     in ascending order of the label's unsigned byte value.
UInt64 State::serialize(WriteBuffer & write_buffer)
{
    UInt64 written_bytes = 0;

    /// Serialize flag
    write_buffer.write(flag);
    written_bytes += 1;

    if (getEncodingMethod() == EncodingMethod::Sequential)
    {
        /// Serialize all labels
        std::vector<char> labels;
        labels.reserve(arcs.size());
        for (const auto & [label, arc] : arcs)
            labels.push_back(label);

        /// NOTE(review): the label count is stored in a single byte. A state with arcs for all 256 byte values
        /// would be written with a count of 0. Confirm this cannot happen for callers' data.
        UInt8 label_size = labels.size();
        write_buffer.write(label_size);
        written_bytes += 1;

        write_buffer.write(labels.data(), labels.size());
        written_bytes += labels.size();

        /// Serialize all arcs, in the same order as the labels above.
        for (char label : labels)
        {
            Arc * arc = getArc(label);
            assert(arc != nullptr);
            written_bytes += arc->serialize(write_buffer);
        }
    }
    else
    {
        /// Serialize bitmap
        LabelsAsBitmap bmp;
        std::vector<UInt8> labels;
        labels.reserve(arcs.size());
        for (const auto & [label, arc] : arcs)
        {
            bmp.addLabel(label);
            labels.push_back(static_cast<UInt8>(label));
        }
        written_bytes += bmp.serialize(write_buffer);

        /// Serialize all arcs in ascending order of the label's unsigned byte value.
        /// The reader (FiniteStateTransducer::getOutput) finds an arc via LabelsAsBitmap::getIndex, which counts
        /// the bitmap bits at or below that unsigned value, so the arcs must be written in exactly that order.
        /// Iterating `arcs` directly does not guarantee this. For example, with a container ordered by `char`
        /// where char is signed, labels >= 0x80 would come first.
        std::sort(labels.begin(), labels.end());
        for (UInt8 label : labels)
        {
            Arc * arc = getArc(static_cast<char>(label));
            assert(arc != nullptr);
            written_bytes += arc->serialize(write_buffer);
        }
    }

    return written_bytes;
}
2023-01-20 10:56:20 +00:00
bool operator = = ( const State & state1 , const State & state2 )
{
if ( state1 . arcs . size ( ) ! = state2 . arcs . size ( ) )
return false ;
for ( const auto & [ label , arc ] : state1 . arcs )
{
const auto it = state2 . arcs . find ( label ) ;
if ( it = = state2 . arcs . end ( ) )
return false ;
if ( it - > second ! = arc )
return false ;
}
return true ;
}
2023-01-20 09:32:36 +00:00
void State : : readFlag ( ReadBuffer & read_buffer )
{
read_buffer . readStrict ( reinterpret_cast < char & > ( flag ) ) ;
}
/// Create a builder that writes serialized states to `write_buffer_`.
/// Every temporary-state slot gets its own fresh, empty State up front.
FstBuilder::FstBuilder(WriteBuffer & write_buffer_)
    : write_buffer(write_buffer_)
{
    std::generate(std::begin(temp_states), std::end(temp_states), [] { return std::make_shared<State>(); });
}
2023-01-11 21:40:20 +00:00
/// See FindMinimized in the paper pseudo code l11-l21.
2023-01-20 09:32:36 +00:00
StatePtr FstBuilder : : findMinimized ( const State & state , bool & found )
2022-06-24 01:56:15 +00:00
{
found = false ;
auto hash = state . hash ( ) ;
2023-01-11 21:40:20 +00:00
/// MEMBER: in the paper pseudo code l15
2022-09-25 23:29:30 +00:00
auto it = minimized_states . find ( hash ) ;
2022-06-24 01:56:15 +00:00
2023-01-20 09:32:36 +00:00
if ( it ! = minimized_states . end ( ) & & * it - > second = = state )
2022-06-24 01:56:15 +00:00
{
found = true ;
return it - > second ;
}
2023-01-11 21:40:20 +00:00
/// COPY_STATE: in the paper pseudo code l17
2022-09-25 23:29:30 +00:00
StatePtr p = std : : make_shared < State > ( state ) ;
2023-01-11 21:40:20 +00:00
/// INSERT: in the paper pseudo code l18
2022-06-24 01:56:15 +00:00
minimized_states [ hash ] = p ;
return p ;
}
2023-01-20 09:32:36 +00:00
namespace
{

/// Return the length of the longest common prefix of word1 and word2 (see the paper pseudo code l33-34).
size_t getCommonPrefixLength(std::string_view word1, std::string_view word2)
{
    const size_t limit = std::min(word1.size(), word2.size());
    const auto first_difference = std::mismatch(word1.begin(), word1.begin() + limit, word2.begin()).first;
    return static_cast<size_t>(first_difference - word1.begin());
}

}
2023-01-11 21:40:20 +00:00
/// See the paper pseudo code l33-39 and l70-72(when down_to is 0).
2023-01-20 09:32:36 +00:00
void FstBuilder : : minimizePreviousWordSuffix ( Int64 down_to )
2022-06-24 01:56:15 +00:00
{
2022-09-25 23:29:30 +00:00
for ( Int64 i = static_cast < Int64 > ( previous_word . size ( ) ) ; i > = down_to ; - - i )
2022-06-24 01:56:15 +00:00
{
2023-01-10 16:26:27 +00:00
bool found = false ;
auto minimized_state = findMinimized ( * temp_states [ i ] , found ) ;
2022-06-24 01:56:15 +00:00
if ( i ! = 0 )
{
Output output = 0 ;
2023-01-20 09:32:36 +00:00
Arc * arc = temp_states [ i - 1 ] - > getArc ( previous_word [ i - 1 ] ) ;
2022-06-24 01:56:15 +00:00
if ( arc )
output = arc - > output ;
2023-01-11 21:40:20 +00:00
/// SET_TRANSITION
2022-06-24 01:56:15 +00:00
temp_states [ i - 1 ] - > addArc ( previous_word [ i - 1 ] , output , minimized_state ) ;
}
if ( minimized_state - > id = = 0 )
minimized_state - > id = next_id + + ;
if ( i > 0 & & temp_states [ i - 1 ] - > id = = 0 )
temp_states [ i - 1 ] - > id = next_id + + ;
if ( ! found )
2022-07-03 12:18:51 +00:00
{
2022-06-24 01:56:15 +00:00
minimized_state - > state_index = previous_state_index ;
previous_written_bytes = minimized_state - > serialize ( write_buffer ) ;
previous_state_index + = previous_written_bytes ;
}
}
}
2023-01-20 09:32:36 +00:00
void FstBuilder : : add ( std : : string_view current_word , Output current_output )
2022-06-24 01:56:15 +00:00
{
2022-09-25 23:29:30 +00:00
/// We assume word size is no greater than MAX_TERM_LENGTH(256).
/// FSTs without word size limitation would be inefficient and easy to cause memory bloat
/// Note that when using "split" tokenizer, if a granule has tokens which are longer than
/// MAX_TERM_LENGTH, the granule cannot be dropped and will be fully-scanned. It doesn't affect "ngram" tokenizers.
/// Another limitation is that if the query string has tokens which exceed this length
/// it will fallback to default searching when using "split" tokenizers.
2023-01-20 09:32:36 +00:00
size_t current_word_len = current_word . size ( ) ;
2022-06-24 01:56:15 +00:00
if ( current_word_len > MAX_TERM_LENGTH )
2023-01-20 09:32:36 +00:00
throw DB : : Exception ( DB : : ErrorCodes : : BAD_ARGUMENTS , " Cannot build inverted index: The maximum term length is {}, this is exceeded by term {} " , MAX_TERM_LENGTH , current_word_len ) ;
2022-06-24 01:56:15 +00:00
2023-01-10 16:26:27 +00:00
size_t prefix_length_plus1 = getCommonPrefixLength ( current_word , previous_word ) + 1 ;
2022-06-24 01:56:15 +00:00
minimizePreviousWordSuffix ( prefix_length_plus1 ) ;
2023-01-11 21:40:20 +00:00
/// Initialize the tail state, see paper pseudo code l39-43
2023-01-10 16:26:27 +00:00
for ( size_t i = prefix_length_plus1 ; i < = current_word . size ( ) ; + + i )
2022-06-24 01:56:15 +00:00
{
2023-01-11 21:40:20 +00:00
/// CLEAR_STATE: l41
2022-06-24 01:56:15 +00:00
temp_states [ i ] - > clear ( ) ;
2023-01-11 21:40:20 +00:00
/// SET_TRANSITION: l42
2022-09-25 23:29:30 +00:00
temp_states [ i - 1 ] - > addArc ( current_word [ i - 1 ] , 0 , temp_states [ i ] ) ;
2022-06-24 01:56:15 +00:00
}
/// We assume the current word is different with previous word
2023-01-11 21:40:20 +00:00
/// See paper pseudo code l44-47
2022-09-25 23:29:30 +00:00
temp_states [ current_word_len ] - > setFinal ( true ) ;
2023-01-11 21:40:20 +00:00
2022-06-24 01:56:15 +00:00
/// Adjust outputs on the arcs
2023-01-11 21:40:20 +00:00
/// See paper pseudo code l48-63
2023-01-10 16:26:27 +00:00
for ( size_t i = 1 ; i < = prefix_length_plus1 - 1 ; + + i )
2022-06-24 01:56:15 +00:00
{
2023-01-10 16:26:27 +00:00
Arc * arc_ptr = temp_states [ i - 1 ] - > getArc ( current_word [ i - 1 ] ) ;
2022-06-24 01:56:15 +00:00
assert ( arc_ptr ! = nullptr ) ;
2022-09-25 23:29:30 +00:00
Output common_prefix = std : : min ( arc_ptr - > output , current_output ) ;
Output word_suffix = arc_ptr - > output - common_prefix ;
2022-06-24 01:56:15 +00:00
arc_ptr - > output = common_prefix ;
/// For each arc, adjust its output
if ( word_suffix ! = 0 )
{
2023-01-10 16:26:27 +00:00
for ( auto & [ label , arc ] : temp_states [ i ] - > arcs )
2022-06-24 01:56:15 +00:00
arc . output + = word_suffix ;
}
/// Reduce current_output
current_output - = common_prefix ;
}
/// Set last temp state's output
2023-01-11 21:40:20 +00:00
/// paper pseudo code l66-67 (assuming CurrentWord != PreviousWorld)
2022-09-25 23:29:30 +00:00
Arc * arc = temp_states [ prefix_length_plus1 - 1 ] - > getArc ( current_word [ prefix_length_plus1 - 1 ] ) ;
2022-06-24 01:56:15 +00:00
assert ( arc ! = nullptr ) ;
arc - > output = current_output ;
previous_word = current_word ;
}
2023-01-20 09:32:36 +00:00
/// Finish the FST.
/// First minimize and write the remaining states of the last word (paper l70-72). Then append the initial
/// state's index as a VarUInt, followed by one byte holding that VarUInt's length. This trailer is where
/// FiniteStateTransducer::getOutput() starts reading.
/// Returns the total number of bytes written for the FST.
UInt64 FstBuilder::build()
{
    minimizePreviousWordSuffix(0);

    /// Save initial state index. The root (depth 0) is processed last above.
    /// This assumes the root was newly written by that call, so it starts previous_written_bytes before the end.
    previous_state_index -= previous_written_bytes;

    const UInt8 index_length = getLengthOfVarUInt(previous_state_index);
    writeVarUInt(previous_state_index, write_buffer);
    write_buffer.write(index_length);

    const UInt64 states_bytes = previous_state_index + previous_written_bytes;
    return states_bytes + index_length + 1;
}
2023-01-20 09:32:36 +00:00
/// Wrap serialized FST bytes (in the layout written by FstBuilder) for lookups via getOutput().
/// Takes ownership of data_.
FiniteStateTransducer::FiniteStateTransducer(std::vector<UInt8> data_)
    : data(std::move(data_))
{}
2022-09-07 18:22:09 +00:00
void FiniteStateTransducer : : clear ( )
2022-06-24 01:56:15 +00:00
{
data . clear ( ) ;
}
2023-01-20 09:32:36 +00:00
std : : pair < UInt64 , bool > FiniteStateTransducer : : getOutput ( std : : string_view term )
2022-06-24 01:56:15 +00:00
{
2023-01-20 09:32:36 +00:00
std : : pair < UInt64 , bool > result ( 0 , false ) ;
2023-01-10 16:26:27 +00:00
2022-06-24 01:56:15 +00:00
/// Read index of initial state
ReadBufferFromMemory read_buffer ( data . data ( ) , data . size ( ) ) ;
2023-01-20 09:32:36 +00:00
read_buffer . seek ( data . size ( ) - 1 , SEEK_SET ) ;
2022-06-24 01:56:15 +00:00
2023-01-20 09:32:36 +00:00
UInt8 length = 0 ;
read_buffer . readStrict ( reinterpret_cast < char & > ( length ) ) ;
2023-01-10 16:26:27 +00:00
/// FST contains no terms
2022-09-25 23:29:30 +00:00
if ( length = = 0 )
2023-01-20 09:32:36 +00:00
return { 0 , false } ;
2022-06-24 01:56:15 +00:00
read_buffer . seek ( data . size ( ) - 1 - length , SEEK_SET ) ;
2023-01-20 09:32:36 +00:00
UInt64 state_index = 0 ;
2022-06-24 01:56:15 +00:00
readVarUInt ( state_index , read_buffer ) ;
for ( size_t i = 0 ; i < = term . size ( ) ; + + i )
{
2023-01-20 09:32:36 +00:00
UInt64 arc_output = 0 ;
2022-06-24 01:56:15 +00:00
/// Read flag
State temp_state ;
read_buffer . seek ( state_index , SEEK_SET ) ;
2022-09-25 23:29:30 +00:00
temp_state . readFlag ( read_buffer ) ;
2022-06-24 01:56:15 +00:00
if ( i = = term . size ( ) )
{
2023-01-10 16:26:27 +00:00
result . second = temp_state . isFinal ( ) ;
2022-06-24 01:56:15 +00:00
break ;
}
UInt8 label = term [ i ] ;
2023-01-10 16:26:27 +00:00
if ( temp_state . getEncodingMethod ( ) = = State : : EncodingMethod : : Sequential )
2022-06-24 01:56:15 +00:00
{
/// Read number of labels
2023-01-20 09:32:36 +00:00
UInt8 label_num = 0 ;
read_buffer . readStrict ( reinterpret_cast < char & > ( label_num ) ) ;
2022-06-24 01:56:15 +00:00
2022-07-03 12:18:51 +00:00
if ( label_num = = 0 )
2023-01-20 09:32:36 +00:00
return { 0 , false } ;
2022-06-24 01:56:15 +00:00
auto labels_position = read_buffer . getPosition ( ) ;
/// Find the index of the label from "labels" bytes
2023-01-20 09:32:36 +00:00
auto begin_it = data . begin ( ) + labels_position ;
auto end_it = data . begin ( ) + labels_position + label_num ;
2022-06-24 01:56:15 +00:00
auto pos = std : : find ( begin_it , end_it , label ) ;
if ( pos = = end_it )
2023-01-20 09:32:36 +00:00
return { 0 , false } ;
2022-06-24 01:56:15 +00:00
/// Read the arc for the label
UInt64 arc_index = ( pos - begin_it ) ;
auto arcs_start_postion = labels_position + label_num ;
read_buffer . seek ( arcs_start_postion , SEEK_SET ) ;
for ( size_t j = 0 ; j < = arc_index ; j + + )
{
state_index = 0 ;
arc_output = 0 ;
readVarUInt ( state_index , read_buffer ) ;
if ( state_index & 0x1 ) // output is followed
readVarUInt ( arc_output , read_buffer ) ;
state_index > > = 1 ;
}
}
else
{
2023-01-10 16:26:27 +00:00
LabelsAsBitmap bmp ;
2022-06-24 01:56:15 +00:00
readVarUInt ( bmp . data . items [ 0 ] , read_buffer ) ;
readVarUInt ( bmp . data . items [ 1 ] , read_buffer ) ;
readVarUInt ( bmp . data . items [ 2 ] , read_buffer ) ;
readVarUInt ( bmp . data . items [ 3 ] , read_buffer ) ;
2023-01-10 16:26:27 +00:00
if ( ! bmp . hasLabel ( label ) )
2023-01-20 09:32:36 +00:00
return { 0 , false } ;
2022-06-24 01:56:15 +00:00
/// Read the arc for the label
size_t arc_index = bmp . getIndex ( label ) ;
for ( size_t j = 0 ; j < arc_index ; j + + )
{
state_index = 0 ;
arc_output = 0 ;
readVarUInt ( state_index , read_buffer ) ;
if ( state_index & 0x1 ) // output is followed
readVarUInt ( arc_output , read_buffer ) ;
state_index > > = 1 ;
}
}
/// Accumulate the output value
2023-01-10 16:26:27 +00:00
result . first + = arc_output ;
2022-06-24 01:56:15 +00:00
}
2023-01-10 16:26:27 +00:00
return result ;
2022-06-24 01:56:15 +00:00
}
2023-01-20 09:32:36 +00:00
2022-06-24 01:56:15 +00:00
}
2023-01-20 09:32:36 +00:00
2022-09-07 18:22:09 +00:00
}