2021-07-30 13:30:30 +00:00
# if !defined(ARCADIA_BUILD)
# include "config_core.h"
# endif
# if USE_NLP
2021-06-05 00:52:35 +00:00
# include <Columns/ColumnString.h>
# include <DataTypes/DataTypeString.h>
2021-05-04 12:47:34 +00:00
# include <Functions/FunctionFactory.h>
2021-06-05 00:52:35 +00:00
# include <Functions/FunctionHelpers.h>
# include <Functions/IFunction.h>
# include <Interpreters/Context.h>
2021-06-20 12:31:07 +00:00
# include <Interpreters/Lemmatizers.h>
2021-05-04 12:47:34 +00:00
namespace DB
{
2021-06-05 00:52:35 +00:00
namespace ErrorCodes
{
    /// Error codes thrown by the lemmatize function below; declared `extern` here,
    /// defined elsewhere in the project.
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int SUPPORT_IS_DISABLED;
}
2021-05-04 12:47:34 +00:00
namespace
{
2021-06-05 00:52:35 +00:00
struct LemmatizeImpl
2021-05-04 12:47:34 +00:00
{
2021-06-05 00:52:35 +00:00
static void vector (
const ColumnString : : Chars & data ,
const ColumnString : : Offsets & offsets ,
ColumnString : : Chars & res_data ,
ColumnString : : Offsets & res_offsets ,
Lemmatizers : : LemmPtr & lemmatizer )
{
res_data . resize ( data . size ( ) ) ;
res_offsets . assign ( offsets ) ;
UInt64 data_size = 0 ;
for ( UInt64 i = 0 ; i < offsets . size ( ) ; + + i )
{
2021-06-19 18:52:09 +00:00
/// lemmatize() uses the fact the fact that each string ends with '\0'
2021-06-05 00:52:35 +00:00
auto result = lemmatizer - > lemmatize ( reinterpret_cast < const char * > ( data . data ( ) + offsets [ i - 1 ] ) ) ;
size_t new_size = strlen ( result . get ( ) ) + 1 ;
2021-06-19 18:52:09 +00:00
2021-06-05 00:52:35 +00:00
if ( data_size + new_size > res_data . size ( ) )
res_data . resize ( data_size + new_size ) ;
memcpy ( res_data . data ( ) + data_size , reinterpret_cast < const unsigned char * > ( result . get ( ) ) , new_size ) ;
2021-06-19 18:52:09 +00:00
2021-06-05 00:52:35 +00:00
data_size + = new_size ;
res_offsets [ i ] = data_size ;
}
res_data . resize ( data_size ) ;
}
2021-05-04 12:47:34 +00:00
} ;
2021-06-05 00:52:35 +00:00
class FunctionLemmatize : public IFunction
{
public :
static constexpr auto name = " lemmatize " ;
2021-06-19 18:52:09 +00:00
static FunctionPtr create ( ContextPtr context )
{
2021-07-30 15:25:51 +00:00
if ( ! context - > getSettingsRef ( ) . allow_experimental_nlp_functions )
throw Exception ( ErrorCodes : : SUPPORT_IS_DISABLED , " Natural language processing function '{}' is experimental. Set `allow_experimental_nlp_functions` setting to enable it " , name ) ;
2021-06-05 00:52:35 +00:00
return std : : make_shared < FunctionLemmatize > ( context - > getLemmatizers ( ) ) ;
}
private :
Lemmatizers & lemmatizers ;
public :
2021-06-20 12:31:07 +00:00
explicit FunctionLemmatize ( Lemmatizers & lemmatizers_ )
2021-06-05 00:52:35 +00:00
: lemmatizers ( lemmatizers_ ) { }
String getName ( ) const override { return name ; }
size_t getNumberOfArguments ( ) const override { return 2 ; }
DataTypePtr getReturnTypeImpl ( const DataTypes & arguments ) const override
{
if ( ! isString ( arguments [ 0 ] ) )
throw Exception (
" Illegal type " + arguments [ 0 ] - > getName ( ) + " of argument of function " + getName ( ) , ErrorCodes : : ILLEGAL_TYPE_OF_ARGUMENT ) ;
if ( ! isString ( arguments [ 1 ] ) )
throw Exception (
" Illegal type " + arguments [ 1 ] - > getName ( ) + " of argument of function " + getName ( ) , ErrorCodes : : ILLEGAL_TYPE_OF_ARGUMENT ) ;
return arguments [ 1 ] ;
}
bool useDefaultImplementationForConstants ( ) const override { return true ; }
ColumnNumbers getArgumentsThatAreAlwaysConstant ( ) const override { return { 0 } ; }
ColumnPtr executeImpl ( const ColumnsWithTypeAndName & arguments , const DataTypePtr & , size_t ) const override
{
const auto & langcolumn = arguments [ 0 ] . column ;
const auto & strcolumn = arguments [ 1 ] . column ;
const ColumnConst * lang_col = checkAndGetColumn < ColumnConst > ( langcolumn . get ( ) ) ;
const ColumnString * words_col = checkAndGetColumn < ColumnString > ( strcolumn . get ( ) ) ;
2021-06-19 18:52:09 +00:00
if ( ! lang_col )
2021-06-05 00:52:35 +00:00
throw Exception (
" Illegal column " + arguments [ 0 ] . column - > getName ( ) + " of argument of function " + getName ( ) , ErrorCodes : : ILLEGAL_COLUMN ) ;
if ( ! words_col )
throw Exception (
" Illegal column " + arguments [ 1 ] . column - > getName ( ) + " of argument of function " + getName ( ) , ErrorCodes : : ILLEGAL_COLUMN ) ;
String language = lang_col - > getValue < String > ( ) ;
auto lemmatizer = lemmatizers . getLemmatizer ( language ) ;
auto col_res = ColumnString : : create ( ) ;
LemmatizeImpl : : vector ( words_col - > getChars ( ) , words_col - > getOffsets ( ) , col_res - > getChars ( ) , col_res - > getOffsets ( ) , lemmatizer ) ;
return col_res ;
}
} ;
2021-05-04 12:47:34 +00:00
}
/// Adds `lemmatize` to the function factory, registered with FunctionFactory::CaseInsensitive.
void registerFunctionLemmatize(FunctionFactory & factory)
{
    const auto case_sensitiveness = FunctionFactory::CaseInsensitive;
    factory.registerFunction<FunctionLemmatize>(case_sensitiveness);
}
}
2021-07-30 13:30:30 +00:00
# endif