2021-07-30 13:30:30 +00:00
# if !defined(ARCADIA_BUILD)
# include "config_core.h"
# endif
# if USE_NLP
2021-06-03 02:20:42 +00:00
# include <Columns/ColumnArray.h>
# include <Columns/ColumnString.h>
# include <Columns/ColumnVector.h>
# include <DataTypes/DataTypeArray.h>
# include <DataTypes/DataTypeString.h>
# include <Functions/FunctionFactory.h>
# include <Functions/FunctionHelpers.h>
# include <Functions/IFunction.h>
# include <Interpreters/Context.h>
2021-06-19 17:57:46 +00:00
# include <Interpreters/SynonymsExtensions.h>
2021-06-03 02:20:42 +00:00
# include <string_view>
namespace DB
{
/// Error codes thrown by the code in this file. They are only declared here
/// (extern); the definitions are provided elsewhere.
namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int SUPPORT_IS_DISABLED;
}
class FunctionSynonyms : public IFunction
{
public :
static constexpr auto name = " synonyms " ;
2021-06-19 18:52:09 +00:00
static FunctionPtr create ( ContextPtr context )
{
2021-07-30 15:25:51 +00:00
if ( ! context - > getSettingsRef ( ) . allow_experimental_nlp_functions )
throw Exception ( ErrorCodes : : SUPPORT_IS_DISABLED , " Natural language processing function '{}' is experimental. Set `allow_experimental_nlp_functions` setting to enable it " , name ) ;
2021-06-03 02:20:42 +00:00
return std : : make_shared < FunctionSynonyms > ( context - > getSynonymsExtensions ( ) ) ;
}
private :
SynonymsExtensions & extensions ;
public :
2021-06-20 12:31:07 +00:00
explicit FunctionSynonyms ( SynonymsExtensions & extensions_ )
2021-06-04 21:52:44 +00:00
: extensions ( extensions_ ) { }
2021-06-03 02:20:42 +00:00
String getName ( ) const override { return name ; }
size_t getNumberOfArguments ( ) const override { return 2 ; }
2021-08-10 11:31:15 +00:00
bool isSuitableForShortCircuitArgumentsExecution ( const DataTypesWithConstInfo & /*arguments*/ ) const override { return true ; }
2021-06-03 02:20:42 +00:00
DataTypePtr getReturnTypeImpl ( const DataTypes & arguments ) const override
{
if ( ! isString ( arguments [ 0 ] ) )
throw Exception (
" Illegal type " + arguments [ 0 ] - > getName ( ) + " of argument of function " + getName ( ) , ErrorCodes : : ILLEGAL_TYPE_OF_ARGUMENT ) ;
if ( ! isString ( arguments [ 1 ] ) )
throw Exception (
" Illegal type " + arguments [ 1 ] - > getName ( ) + " of argument of function " + getName ( ) , ErrorCodes : : ILLEGAL_TYPE_OF_ARGUMENT ) ;
return std : : make_shared < DataTypeArray > ( std : : make_shared < DataTypeString > ( ) ) ;
}
bool useDefaultImplementationForConstants ( ) const override { return true ; }
ColumnNumbers getArgumentsThatAreAlwaysConstant ( ) const override { return { 0 } ; }
ColumnPtr executeImpl ( const ColumnsWithTypeAndName & arguments , const DataTypePtr & result_type , size_t input_rows_count ) const override
{
const auto & extcolumn = arguments [ 0 ] . column ;
const auto & strcolumn = arguments [ 1 ] . column ;
const ColumnConst * ext_col = checkAndGetColumn < ColumnConst > ( extcolumn . get ( ) ) ;
const ColumnString * word_col = checkAndGetColumn < ColumnString > ( strcolumn . get ( ) ) ;
if ( ! ext_col )
throw Exception (
" Illegal column " + arguments [ 0 ] . column - > getName ( ) + " of argument of function " + getName ( ) ,
ErrorCodes : : ILLEGAL_COLUMN ) ;
if ( ! word_col )
throw Exception (
" Illegal column " + arguments [ 1 ] . column - > getName ( ) + " of argument of function " + getName ( ) ,
ErrorCodes : : ILLEGAL_COLUMN ) ;
String ext_name = ext_col - > getValue < String > ( ) ;
auto extension = extensions . getExtension ( ext_name ) ;
/// Create and fill the result array.
const DataTypePtr & elem_type = static_cast < const DataTypeArray & > ( * result_type ) . getNestedType ( ) ;
auto out = ColumnArray : : create ( elem_type - > createColumn ( ) ) ;
IColumn & out_data = out - > getData ( ) ;
IColumn : : Offsets & out_offsets = out - > getOffsets ( ) ;
const ColumnString : : Chars & data = word_col - > getChars ( ) ;
const ColumnString : : Offsets & offsets = word_col - > getOffsets ( ) ;
out_data . reserve ( input_rows_count ) ;
out_offsets . resize ( input_rows_count ) ;
IColumn : : Offset current_offset = 0 ;
for ( size_t i = 0 ; i < offsets . size ( ) ; + + i )
{
std : : string_view word ( reinterpret_cast < const char * > ( data . data ( ) + offsets [ i - 1 ] ) , offsets [ i ] - offsets [ i - 1 ] - 1 ) ;
2021-06-20 12:31:07 +00:00
const auto * synset = extension - > getSynonyms ( word ) ;
2021-06-19 18:52:09 +00:00
2021-06-03 02:20:42 +00:00
if ( synset )
{
2021-07-07 11:07:20 +00:00
for ( const auto & token : * synset )
2021-06-03 02:20:42 +00:00
out_data . insert ( Field ( token . data ( ) , token . size ( ) ) ) ;
current_offset + = synset - > size ( ) ;
}
out_offsets [ i ] = current_offset ;
}
return out ;
}
} ;
/// Registers FunctionSynonyms in the given factory, passing the
/// FunctionFactory::CaseInsensitive flag to the registration call.
void registerFunctionSynonyms(FunctionFactory & factory)
{
    const auto case_sensitivity = FunctionFactory::CaseInsensitive;
    factory.registerFunction<FunctionSynonyms>(case_sensitivity);
}
}
2021-07-30 13:30:30 +00:00
# endif