diff --git a/src/Functions/ngrams.cpp b/src/Functions/ngrams.cpp deleted file mode 100644 index c5ce65537cb..00000000000 --- a/src/Functions/ngrams.cpp +++ /dev/null @@ -1,126 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - - -namespace DB -{ - -namespace ErrorCodes -{ - extern const int BAD_ARGUMENTS; -} - -class FunctionNgrams : public IFunction -{ -public: - - static constexpr auto name = "ngrams"; - - static FunctionPtr create(ContextPtr) - { - return std::make_shared(); - } - - String getName() const override { return name; } - - size_t getNumberOfArguments() const override { return 2; } - bool isVariadic() const override { return false; } - ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } - - bool useDefaultImplementationForNulls() const override { return true; } - bool useDefaultImplementationForConstants() const override { return true; } - bool useDefaultImplementationForLowCardinalityColumns() const override { return true; } - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } - - DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override - { - auto ngram_input_argument_type = WhichDataType(arguments[0].type); - if (!ngram_input_argument_type.isStringOrFixedString()) - throw Exception(ErrorCodes::BAD_ARGUMENTS, - "Function {} second argument type should be String or FixedString. Actual {}", - getName(), - arguments[0].type->getName()); - - const auto & column_with_type = arguments[1]; - const auto & ngram_argument_column = arguments[1].column; - auto ngram_argument_type = WhichDataType(column_with_type.type); - - if (!ngram_argument_type.isNativeUInt() || !ngram_argument_column || !isColumnConst(*ngram_argument_column)) - throw Exception(ErrorCodes::BAD_ARGUMENTS, - "Function {} second argument type should be constant UInt. Actual {}", - getName(), - arguments[1].type->getName()); - - Field ngram_argument_value; - ngram_argument_column->get(0, ngram_argument_value); - auto ngram_value = ngram_argument_value.safeGet(); - - return std::make_shared(std::make_shared(ngram_value)); - } - - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override - { - Field ngram_argument_value; - arguments[1].column->get(0, ngram_argument_value); - auto ngram_value = ngram_argument_value.safeGet(); - - NgramTokenExtractor extractor(ngram_value); - - auto result_column_fixed_string = ColumnFixedString::create(ngram_value); - auto column_offsets = ColumnArray::ColumnOffsets::create(); - - auto input_column = arguments[0].column; - if (const auto * column_string = checkAndGetColumn(input_column.get())) - executeImpl(extractor, *column_string, *result_column_fixed_string, *column_offsets); - else if (const auto * column_fixed_string = checkAndGetColumn(input_column.get())) - executeImpl(extractor, *column_fixed_string, *result_column_fixed_string, *column_offsets); - - return ColumnArray::create(std::move(result_column_fixed_string), std::move(column_offsets)); - } - -private: - - template - inline void executeImpl(const NgramTokenExtractor & extractor, StringColumnType & input_data_column, ColumnFixedString & result_data_column, ColumnArray::ColumnOffsets & offsets_column) const - { - size_t current_tokens_size = 0; - auto & offsets_data = offsets_column.getData(); - - size_t column_size = input_data_column.size(); - offsets_data.resize(column_size); - - for (size_t i = 0; i < column_size; ++i) - { - auto data = input_data_column.getDataAt(i); - - size_t cur = 0; - size_t token_start = 0; - size_t token_length = 0; - - while (cur < data.size && extractor.nextInString(data.data, data.size, &cur, &token_start, &token_length)) - { - result_data_column.insertData(data.data + token_start, token_length); - ++current_tokens_size; - } - - offsets_data[i] = current_tokens_size; - } - } -}; - -void registerFunctionNgrams(FunctionFactory & factory) -{ - factory.registerFunction(); -} - -} - - diff --git a/src/Functions/registerFunctions.cpp b/src/Functions/registerFunctions.cpp index b2f038240aa..94dfcdb4fda 100644 --- a/src/Functions/registerFunctions.cpp +++ b/src/Functions/registerFunctions.cpp @@ -37,7 +37,7 @@ void registerFunctionsStringArray(FunctionFactory &); void registerFunctionsStringSearch(FunctionFactory &); void registerFunctionsStringRegexp(FunctionFactory &); void registerFunctionsStringSimilarity(FunctionFactory &); -void registerFunctionNgrams(FunctionFactory &); +void registerFunctionsStringTokenExtractor(FunctionFactory &); void registerFunctionsURL(FunctionFactory &); void registerFunctionsVisitParam(FunctionFactory &); void registerFunctionsMath(FunctionFactory &); @@ -101,7 +101,7 @@ void registerFunctions() registerFunctionsStringSearch(factory); registerFunctionsStringRegexp(factory); registerFunctionsStringSimilarity(factory); - registerFunctionNgrams(factory); + registerFunctionsStringTokenExtractor(factory); registerFunctionsURL(factory); registerFunctionsVisitParam(factory); registerFunctionsMath(factory); diff --git a/src/Functions/tokenExtractors.cpp b/src/Functions/tokenExtractors.cpp new file mode 100644 index 00000000000..f15cfb4cb62 --- /dev/null +++ b/src/Functions/tokenExtractors.cpp @@ -0,0 +1,166 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} + +enum TokenExtractorStrategy +{ + ngrams, + tokens +}; + +template +class FunctionTokenExtractor : public IFunction +{ +public: + + static constexpr auto name = strategy == ngrams ? "ngrams" : "tokens"; + + static FunctionPtr create(ContextPtr) + { + return std::make_shared(); + } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return strategy == ngrams ? 2 : 1; } + bool isVariadic() const override { return false; } + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return strategy == ngrams ? ColumnNumbers{1} : ColumnNumbers{}; } + + bool useDefaultImplementationForNulls() const override { return true; } + bool useDefaultImplementationForConstants() const override { return true; } + bool useDefaultImplementationForLowCardinalityColumns() const override { return true; } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override + { + auto ngram_input_argument_type = WhichDataType(arguments[0].type); + if (!ngram_input_argument_type.isStringOrFixedString()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Function {} second argument type should be String or FixedString. Actual {}", + getName(), + arguments[0].type->getName()); + + if constexpr (strategy == ngrams) + { + const auto & column_with_type = arguments[1]; + const auto & ngram_argument_column = arguments[1].column; + auto ngram_argument_type = WhichDataType(column_with_type.type); + + if (!ngram_argument_type.isNativeUInt() || !ngram_argument_column || !isColumnConst(*ngram_argument_column)) + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Function {} second argument type should be constant UInt. Actual {}", + getName(), + arguments[1].type->getName()); + + Field ngram_argument_value; + ngram_argument_column->get(0, ngram_argument_value); + auto ngram_value = ngram_argument_value.safeGet(); + + return std::make_shared(std::make_shared(ngram_value)); + } + else + { + return std::make_shared(std::make_shared()); + } + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override + { + auto column_offsets = ColumnArray::ColumnOffsets::create(); + + if constexpr (strategy == TokenExtractorStrategy::ngrams) + { + Field ngram_argument_value; + arguments[1].column->get(0, ngram_argument_value); + auto ngram_value = ngram_argument_value.safeGet(); + + NgramTokenExtractor extractor(ngram_value); + + auto result_column_fixed_string = ColumnFixedString::create(ngram_value); + + auto input_column = arguments[0].column; + + if (const auto * column_string = checkAndGetColumn(input_column.get())) + executeImpl(extractor, *column_string, *result_column_fixed_string, *column_offsets); + else if (const auto * column_fixed_string = checkAndGetColumn(input_column.get())) + executeImpl(extractor, *column_fixed_string, *result_column_fixed_string, *column_offsets); + + return ColumnArray::create(std::move(result_column_fixed_string), std::move(column_offsets)); + } + else + { + SplitTokenExtractor extractor; + + auto result_column_string = ColumnString::create(); + + auto input_column = arguments[0].column; + + if (const auto * column_string = checkAndGetColumn(input_column.get())) + executeImpl(extractor, *column_string, *result_column_string, *column_offsets); + else if (const auto * column_fixed_string = checkAndGetColumn(input_column.get())) + executeImpl(extractor, *column_fixed_string, *result_column_string, *column_offsets); + + return ColumnArray::create(std::move(result_column_string), std::move(column_offsets)); + } + } + +private: + + template + inline void executeImpl( + const ExtractorType & extractor, + StringColumnType & input_data_column, + ResultStringColumnType & result_data_column, + ColumnArray::ColumnOffsets & offsets_column) const + { + size_t current_tokens_size = 0; + auto & offsets_data = offsets_column.getData(); + + size_t column_size = input_data_column.size(); + offsets_data.resize(column_size); + + for (size_t i = 0; i < column_size; ++i) + { + auto data = input_data_column.getDataAt(i); + + size_t cur = 0; + size_t token_start = 0; + size_t token_length = 0; + + while (cur < data.size && extractor.nextInString(data.data, data.size, &cur, &token_start, &token_length)) + { + result_data_column.insertData(data.data + token_start, token_length); + ++current_tokens_size; + } + + offsets_data[i] = current_tokens_size; + } + } +}; + +void registerFunctionsStringTokenExtractor(FunctionFactory & factory) +{ + factory.registerFunction>(); + factory.registerFunction>(); +} + +} + + diff --git a/tests/queries/0_stateless/2028_tokens.reference b/tests/queries/0_stateless/2028_tokens.reference new file mode 100644 index 00000000000..0c23a7598b6 --- /dev/null +++ b/tests/queries/0_stateless/2028_tokens.reference @@ -0,0 +1,8 @@ +['test'] +['test1','test2','test3'] +['test1','test2','test3','test4'] +['test1','test2','test3','test4'] +['test'] +['test1','test2','test3'] +['test1','test2','test3','test4'] +['test1','test2','test3','test4'] diff --git a/tests/queries/0_stateless/2028_tokens.sql b/tests/queries/0_stateless/2028_tokens.sql new file mode 100644 index 00000000000..835ec140302 --- /dev/null +++ b/tests/queries/0_stateless/2028_tokens.sql @@ -0,0 +1,9 @@ +SELECT tokens('test'); +SELECT tokens('test1, test2, test3'); +SELECT tokens('test1, test2, test3, test4'); +SELECT tokens('test1,;\ test2,;\ test3,;\ test4'); + +SELECT tokens(materialize('test')); +SELECT tokens(materialize('test1, test2, test3')); +SELECT tokens(materialize('test1, test2, test3, test4')); +SELECT tokens(materialize('test1,;\ test2,;\ test3,;\ test4'));