Introduce non-throwing variants of hasToken

This commit is contained in:
ltrk2 2023-01-17 05:27:41 -08:00
parent 136e4ec1b3
commit 65b9c69c90
6 changed files with 96 additions and 37 deletions

View File

@ -1,10 +1,12 @@
#pragma once
#include <Columns/ColumnConst.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
@ -58,14 +60,31 @@ namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int LOGICAL_ERROR;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
template <typename Impl>
enum class ExecutionErrorPolicy
{
Null,
Throw
};
template <typename Impl, ExecutionErrorPolicy execution_error_policy = ExecutionErrorPolicy::Throw>
class FunctionsStringSearch : public IFunction
{
public:
static constexpr auto name = Impl::name;
inline static auto name = std::invoke(
[]() -> std::string
{
if (execution_error_policy == ExecutionErrorPolicy::Null)
return std::string(Impl::name) + "OrNull";
else if (execution_error_policy == ExecutionErrorPolicy::Throw)
return Impl::name;
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unhandled execution error policy");
});
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionsStringSearch>(); }
String getName() const override { return name; }
@ -117,7 +136,11 @@ public:
arguments[2]->getName(), getName());
}
return std::make_shared<DataTypeNumber<typename Impl::ResultType>>();
auto return_type = std::make_shared<DataTypeNumber<typename Impl::ResultType>>();
if constexpr (execution_error_policy == ExecutionErrorPolicy::Null)
return makeNullable(return_type);
return return_type;
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /*input_rows_count*/) const override
@ -133,14 +156,13 @@ public:
const ColumnConst * col_needle_const = typeid_cast<const ColumnConst *>(&*column_needle);
using ResultType = typename Impl::ResultType;
auto col_res = ColumnVector<ResultType>::create();
auto & vec_res = col_res->getData();
if constexpr (!Impl::use_default_implementation_for_constants)
{
bool is_col_start_pos_const = column_start_pos == nullptr || isColumnConst(*column_start_pos);
if (col_haystack_const && col_needle_const)
{
auto col_res = ColumnVector<ResultType>::create();
typename ColumnVector<ResultType>::Container & vec_res = col_res->getData();
const auto is_col_start_pos_const = !column_start_pos || isColumnConst(*column_start_pos);
vec_res.resize(is_col_start_pos_const ? 1 : column_start_pos->size());
Impl::constantConstant(
@ -156,9 +178,6 @@ public:
}
}
auto col_res = ColumnVector<ResultType>::create();
typename ColumnVector<ResultType>::Container & vec_res = col_res->getData();
vec_res.resize(column_haystack->size());
const ColumnString * col_haystack_vector = checkAndGetColumn<ColumnString>(&*column_haystack);
@ -174,12 +193,28 @@ public:
column_start_pos,
vec_res);
else if (col_haystack_vector && col_needle_const)
Impl::vectorConstant(
col_haystack_vector->getChars(),
col_haystack_vector->getOffsets(),
col_needle_const->getValue<String>(),
column_start_pos,
vec_res);
{
if constexpr (execution_error_policy == ExecutionErrorPolicy::Null)
{
auto null_map = ColumnUInt8::create();
Impl::vectorConstant(
col_haystack_vector->getChars(),
col_haystack_vector->getOffsets(),
col_needle_const->getValue<String>(),
column_start_pos,
vec_res,
null_map.get());
return ColumnNullable::create(std::move(col_res), std::move(null_map));
}
else
Impl::vectorConstant(
col_haystack_vector->getChars(),
col_haystack_vector->getOffsets(),
col_needle_const->getValue<String>(),
column_start_pos,
vec_res);
}
else if (col_haystack_vector_fixed && col_needle_vector)
Impl::vectorFixedVector(
col_haystack_vector_fixed->getChars(),

View File

@ -31,7 +31,8 @@ struct HasTokenImpl
const ColumnString::Offsets & haystack_offsets,
const std::string & pattern,
const ColumnPtr & start_pos,
PaddedPODArray<UInt8> & res)
PaddedPODArray<UInt8> & res,
ColumnUInt8* null_map = nullptr)
{
if (start_pos != nullptr)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function '{}' does not support start_pos argument", name);
@ -46,10 +47,24 @@ struct HasTokenImpl
/// The current index in the array of strings.
size_t i = 0;
TokenSearcher searcher(pattern.data(), pattern.size(), end - pos);
std::optional<TokenSearcher> searcher;
try
{
searcher.emplace(pattern.data(), pattern.size(), end - pos);
if (null_map)
null_map->getData().resize_fill(haystack_offsets.size(), false);
}
catch (...)
{
if (!null_map)
throw;
null_map->getData().resize_fill(haystack_offsets.size(), true);
return;
}
/// We will search for the next occurrence in all rows at once.
while (pos < end && end != (pos = searcher.search(pos, end - pos)))
while (pos < end && end != (pos = searcher->search(pos, end - pos)))
{
/// Let's determine which index it refers to.
while (begin + haystack_offsets[i] <= pos)

View File

@ -1,26 +1,22 @@
#include "FunctionFactory.h"
#include "FunctionsStringSearch.h"
#include <Functions/FunctionFactory.h>
#include "HasTokenImpl.h"
#include <Common/Volnitsky.h>
namespace DB
{
namespace
{
struct NameHasToken
{
static constexpr auto name = "hasToken";
};
using FunctionHasToken = FunctionsStringSearch<HasTokenImpl<NameHasToken, VolnitskyCaseSensitiveToken, false>>;
using FunctionHasToken = DB::FunctionsStringSearch<DB::HasTokenImpl<NameHasToken, DB::VolnitskyCaseSensitiveToken, false>>;
using FunctionHasTokenOrNull = DB::FunctionsStringSearch<DB::HasTokenImpl<NameHasToken, DB::VolnitskyCaseSensitiveToken, false>, DB::ExecutionErrorPolicy::Null>;
}
REGISTER_FUNCTION(HasToken)
{
factory.registerFunction<FunctionHasToken>();
}
factory.registerFunction<FunctionHasTokenOrNull>();
}

View File

@ -1,27 +1,25 @@
#include "FunctionFactory.h"
#include "FunctionsStringSearch.h"
#include <Functions/FunctionFactory.h>
#include "HasTokenImpl.h"
#include <Common/Volnitsky.h>
namespace DB
{
namespace
{
struct NameHasTokenCaseInsensitive
{
static constexpr auto name = "hasTokenCaseInsensitive";
};
using FunctionHasTokenCaseInsensitive
= FunctionsStringSearch<HasTokenImpl<NameHasTokenCaseInsensitive, VolnitskyCaseInsensitiveToken, false>>;
= DB::FunctionsStringSearch<DB::HasTokenImpl<NameHasTokenCaseInsensitive, DB::VolnitskyCaseInsensitiveToken, false>>;
using FunctionHasTokenCaseInsensitiveOrNull = DB::FunctionsStringSearch<
DB::HasTokenImpl<NameHasTokenCaseInsensitive, DB::VolnitskyCaseInsensitiveToken, false>,
DB::ExecutionErrorPolicy::Null>;
}
REGISTER_FUNCTION(HasTokenCaseInsensitive)
{
factory.registerFunction<FunctionHasTokenCaseInsensitive>();
}
factory.registerFunction<FunctionHasTokenCaseInsensitiveOrNull>();
}

View File

@ -1,3 +1,8 @@
0
0
2007
2007
2007
0
2007
2007

View File

@ -13,9 +13,19 @@ insert into bloom_filter select number+2000, 'abc,def,zzz' from numbers(8);
insert into bloom_filter select number+3000, 'yyy,uuu' from numbers(1024);
insert into bloom_filter select number+3000, 'abcdefzzz' from numbers(1024);
SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'abc,def,zzz'); -- { serverError BAD_ARGUMENTS }
SELECT max(id) FROM bloom_filter WHERE hasTokenCaseInsensitive(s, 'abc,def,zzz'); -- { serverError BAD_ARGUMENTS }
SELECT max(id) FROM bloom_filter WHERE hasTokenOrNull(s, 'abc,def,zzz');
SELECT max(id) FROM bloom_filter WHERE hasTokenCaseInsensitiveOrNull(s, 'abc,def,zzz');
select max(id) from bloom_filter where hasTokenCaseInsensitive(s, 'ABC');
select max(id) from bloom_filter where hasTokenCaseInsensitive(s, 'zZz');
set max_rows_to_read = 16;
SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'abc');
SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'ABC');
SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'def');
SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'zzz');