From 65b9c69c90a5816798b6adf25d64f6eaed98e658 Mon Sep 17 00:00:00 2001 From: ltrk2 <107155950+ltrk2@users.noreply.github.com> Date: Tue, 17 Jan 2023 05:27:41 -0800 Subject: [PATCH] Introduce non-throwing variants of hasToken --- src/Functions/FunctionsStringSearch.h | 67 ++++++++++++++----- src/Functions/HasTokenImpl.h | 21 +++++- src/Functions/hasToken.cpp | 14 ++-- src/Functions/hasTokenCaseInsensitive.cpp | 16 ++--- .../00990_hasToken_and_tokenbf.reference | 5 ++ .../00990_hasToken_and_tokenbf.sql | 10 +++ 6 files changed, 96 insertions(+), 37 deletions(-) diff --git a/src/Functions/FunctionsStringSearch.h b/src/Functions/FunctionsStringSearch.h index d8da525e63a..be5756579bc 100644 --- a/src/Functions/FunctionsStringSearch.h +++ b/src/Functions/FunctionsStringSearch.h @@ -1,10 +1,12 @@ #pragma once #include +#include #include #include #include #include +#include #include #include #include @@ -58,14 +60,31 @@ namespace ErrorCodes { extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int LOGICAL_ERROR; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } -template +enum class ExecutionErrorPolicy +{ + Null, + Throw +}; + +template class FunctionsStringSearch : public IFunction { public: - static constexpr auto name = Impl::name; + inline static auto name = std::invoke( + []() -> std::string + { + if (execution_error_policy == ExecutionErrorPolicy::Null) + return std::string(Impl::name) + "OrNull"; + else if (execution_error_policy == ExecutionErrorPolicy::Throw) + return Impl::name; + + throw Exception(ErrorCodes::LOGICAL_ERROR, "Unhandled execution error policy"); + }); + static FunctionPtr create(ContextPtr) { return std::make_shared(); } String getName() const override { return name; } @@ -117,7 +136,11 @@ public: arguments[2]->getName(), getName()); } - return std::make_shared>(); + auto return_type = std::make_shared>(); + if constexpr (execution_error_policy == ExecutionErrorPolicy::Null) + return makeNullable(return_type); + + return return_type; } ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /*input_rows_count*/) const override @@ -133,14 +156,13 @@ public: const ColumnConst * col_needle_const = typeid_cast(&*column_needle); using ResultType = typename Impl::ResultType; - + auto col_res = ColumnVector::create(); + auto & vec_res = col_res->getData(); if constexpr (!Impl::use_default_implementation_for_constants) { - bool is_col_start_pos_const = column_start_pos == nullptr || isColumnConst(*column_start_pos); if (col_haystack_const && col_needle_const) { - auto col_res = ColumnVector::create(); - typename ColumnVector::Container & vec_res = col_res->getData(); + const auto is_col_start_pos_const = !column_start_pos || isColumnConst(*column_start_pos); vec_res.resize(is_col_start_pos_const ? 1 : column_start_pos->size()); Impl::constantConstant( @@ -156,9 +178,6 @@ public: } } - auto col_res = ColumnVector::create(); - - typename ColumnVector::Container & vec_res = col_res->getData(); vec_res.resize(column_haystack->size()); const ColumnString * col_haystack_vector = checkAndGetColumn(&*column_haystack); @@ -174,12 +193,28 @@ public: column_start_pos, vec_res); else if (col_haystack_vector && col_needle_const) - Impl::vectorConstant( - col_haystack_vector->getChars(), - col_haystack_vector->getOffsets(), - col_needle_const->getValue(), - column_start_pos, - vec_res); + { + if constexpr (execution_error_policy == ExecutionErrorPolicy::Null) + { + auto null_map = ColumnUInt8::create(); + Impl::vectorConstant( + col_haystack_vector->getChars(), + col_haystack_vector->getOffsets(), + col_needle_const->getValue(), + column_start_pos, + vec_res, + null_map.get()); + + return ColumnNullable::create(std::move(col_res), std::move(null_map)); + } + else + Impl::vectorConstant( + col_haystack_vector->getChars(), + col_haystack_vector->getOffsets(), + col_needle_const->getValue(), + column_start_pos, + vec_res); + } else if (col_haystack_vector_fixed && col_needle_vector) Impl::vectorFixedVector( col_haystack_vector_fixed->getChars(), diff --git a/src/Functions/HasTokenImpl.h b/src/Functions/HasTokenImpl.h index 9328bd99139..196349f92a2 100644 --- a/src/Functions/HasTokenImpl.h +++ b/src/Functions/HasTokenImpl.h @@ -31,7 +31,8 @@ struct HasTokenImpl const ColumnString::Offsets & haystack_offsets, const std::string & pattern, const ColumnPtr & start_pos, - PaddedPODArray & res) + PaddedPODArray & res, + ColumnUInt8* null_map = nullptr) { if (start_pos != nullptr) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function '{}' does not support start_pos argument", name); @@ -46,10 +47,24 @@ struct HasTokenImpl /// The current index in the array of strings. size_t i = 0; - TokenSearcher searcher(pattern.data(), pattern.size(), end - pos); + std::optional searcher; + try + { + searcher.emplace(pattern.data(), pattern.size(), end - pos); + if (null_map) + null_map->getData().resize_fill(haystack_offsets.size(), false); + } + catch (...) + { + if (!null_map) + throw; + + null_map->getData().resize_fill(haystack_offsets.size(), true); + return; + } /// We will search for the next occurrence in all rows at once. - while (pos < end && end != (pos = searcher.search(pos, end - pos))) + while (pos < end && end != (pos = searcher->search(pos, end - pos))) { /// Let's determine which index it refers to. while (begin + haystack_offsets[i] <= pos) diff --git a/src/Functions/hasToken.cpp b/src/Functions/hasToken.cpp index 646ff0b54f7..189ef08051c 100644 --- a/src/Functions/hasToken.cpp +++ b/src/Functions/hasToken.cpp @@ -1,26 +1,22 @@ +#include "FunctionFactory.h" #include "FunctionsStringSearch.h" -#include #include "HasTokenImpl.h" + #include - -namespace DB -{ namespace { - struct NameHasToken { static constexpr auto name = "hasToken"; }; -using FunctionHasToken = FunctionsStringSearch>; - +using FunctionHasToken = DB::FunctionsStringSearch>; +using FunctionHasTokenOrNull = DB::FunctionsStringSearch, DB::ExecutionErrorPolicy::Null>; } REGISTER_FUNCTION(HasToken) { factory.registerFunction(); -} - + factory.registerFunction(); } diff --git a/src/Functions/hasTokenCaseInsensitive.cpp b/src/Functions/hasTokenCaseInsensitive.cpp index 0012ea3e148..ad61bfaa419 100644 --- a/src/Functions/hasTokenCaseInsensitive.cpp +++ b/src/Functions/hasTokenCaseInsensitive.cpp @@ -1,27 +1,25 @@ +#include "FunctionFactory.h" #include "FunctionsStringSearch.h" -#include #include "HasTokenImpl.h" + #include - -namespace DB -{ namespace { - struct NameHasTokenCaseInsensitive { static constexpr auto name = "hasTokenCaseInsensitive"; }; using FunctionHasTokenCaseInsensitive - = FunctionsStringSearch>; - + = DB::FunctionsStringSearch>; +using FunctionHasTokenCaseInsensitiveOrNull = DB::FunctionsStringSearch< + DB::HasTokenImpl, + DB::ExecutionErrorPolicy::Null>; } REGISTER_FUNCTION(HasTokenCaseInsensitive) { factory.registerFunction(); -} - + factory.registerFunction(); } diff --git a/tests/queries/0_stateless/00990_hasToken_and_tokenbf.reference b/tests/queries/0_stateless/00990_hasToken_and_tokenbf.reference index 10e8f0d2c59..4b3beccf5f1 100644 --- a/tests/queries/0_stateless/00990_hasToken_and_tokenbf.reference +++ b/tests/queries/0_stateless/00990_hasToken_and_tokenbf.reference @@ -1,3 +1,8 @@ +0 +0 2007 2007 2007 +0 +2007 +2007 diff --git a/tests/queries/0_stateless/00990_hasToken_and_tokenbf.sql b/tests/queries/0_stateless/00990_hasToken_and_tokenbf.sql index ad50420b6ae..bfcda1657d1 100644 --- a/tests/queries/0_stateless/00990_hasToken_and_tokenbf.sql +++ b/tests/queries/0_stateless/00990_hasToken_and_tokenbf.sql @@ -13,9 +13,19 @@ insert into bloom_filter select number+2000, 'abc,def,zzz' from numbers(8); insert into bloom_filter select number+3000, 'yyy,uuu' from numbers(1024); insert into bloom_filter select number+3000, 'abcdefzzz' from numbers(1024); +SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'abc,def,zzz'); -- { serverError BAD_ARGUMENTS } +SELECT max(id) FROM bloom_filter WHERE hasTokenCaseInsensitive(s, 'abc,def,zzz'); -- { serverError BAD_ARGUMENTS } + +SELECT max(id) FROM bloom_filter WHERE hasTokenOrNull(s, 'abc,def,zzz'); +SELECT max(id) FROM bloom_filter WHERE hasTokenCaseInsensitiveOrNull(s, 'abc,def,zzz'); + +select max(id) from bloom_filter where hasTokenCaseInsensitive(s, 'ABC'); +select max(id) from bloom_filter where hasTokenCaseInsensitive(s, 'zZz'); + set max_rows_to_read = 16; SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'abc'); +SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'ABC'); SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'def'); SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'zzz');