Merge pull request #34368 from usurai/sw

Add startsWith & endsWith function for arrays
This commit is contained in:
Nikolay Degterinsky 2022-02-18 08:34:48 +03:00 committed by GitHub
commit c09275f0da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 305 additions and 8 deletions

View File

@ -1,4 +1,6 @@
#pragma once
#include <base/map.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/GatherUtils/GatherUtils.h>
#include <Functions/GatherUtils/Sources.h>
@ -7,7 +9,9 @@
#include <Functions/TargetSpecific.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/getLeastSupertype.h>
#include <Columns/ColumnString.h>
#include <Interpreters/castColumn.h>
namespace DB
{
@ -17,6 +21,7 @@ using namespace GatherUtils;
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int LOGICAL_ERROR;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
@ -59,16 +64,65 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isStringOrFixedString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (isStringOrFixedString(arguments[0]) && isStringOrFixedString(arguments[1]))
return std::make_shared<DataTypeUInt8>();
if (!isStringOrFixedString(arguments[1]))
throw Exception("Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (isArray(arguments[0]) && isArray(arguments[1]))
return std::make_shared<DataTypeUInt8>();
return std::make_shared<DataTypeUInt8>();
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal types {} {} of arguments of function {}. Both must be String or Array",
arguments[0]->getName(), arguments[1]->getName(), getName());
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
auto data_type = arguments[0].type;
if (isStringOrFixedString(*data_type))
return executeImplString(arguments, {}, input_rows_count);
if (isArray(data_type))
return executeImplArray(arguments, {}, input_rows_count);
return {};
}
private:
ColumnPtr executeImplArray(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const
{
DataTypePtr common_type = getLeastSupertype(collections::map(arguments, [](auto & arg) { return arg.type; }));
Columns preprocessed_columns(2);
for (size_t i = 0; i < 2; ++i)
preprocessed_columns[i] = castColumn(arguments[i], common_type);
std::vector<std::unique_ptr<GatherUtils::IArraySource>> sources;
for (auto & argument_column : preprocessed_columns)
{
bool is_const = false;
if (const auto * argument_column_const = typeid_cast<const ColumnConst *>(argument_column.get()))
{
is_const = true;
argument_column = argument_column_const->getDataColumnPtr();
}
if (const auto * argument_column_array = typeid_cast<const ColumnArray *>(argument_column.get()))
sources.emplace_back(GatherUtils::createArraySource(*argument_column_array, is_const, input_rows_count));
else
throw Exception{"Arguments for function " + getName() + " must be arrays.", ErrorCodes::LOGICAL_ERROR};
}
auto result_column = ColumnUInt8::create(input_rows_count);
auto * result_column_ptr = typeid_cast<ColumnUInt8 *>(result_column.get());
if constexpr (std::is_same_v<Name, NameStartsWith>)
GatherUtils::sliceHas(*sources[0], *sources[1], GatherUtils::ArraySearchType::StartsWith, *result_column_ptr);
else
GatherUtils::sliceHas(*sources[0], *sources[1], GatherUtils::ArraySearchType::EndsWith, *result_column_ptr);
return result_column;
}
ColumnPtr executeImplString(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const
{
const IColumn * haystack_column = arguments[0].column.get();
const IColumn * needle_column = arguments[1].column.get();
@ -92,7 +146,6 @@ public:
return col_res;
}
private:
template <typename HaystackSource>
void dispatch(HaystackSource haystack_source, const IColumn * needle_column, PaddedPODArray<UInt8> & res_data) const
{

View File

@ -496,6 +496,31 @@ bool sliceHasImplAnyAll(const FirstSliceType & first, const SecondSliceType & se
return search_type == ArraySearchType::All;
}
template <
ArraySearchType search_type,
typename FirstSliceType,
typename SecondSliceType,
bool (*isEqual)(const FirstSliceType &, const SecondSliceType &, size_t, size_t)>
bool sliceHasImplStartsEndsWith(const FirstSliceType & first, const SecondSliceType & second, const UInt8 * first_null_map, const UInt8 * second_null_map)
{
const bool has_first_null_map = first_null_map != nullptr;
const bool has_second_null_map = second_null_map != nullptr;
if (first.size < second.size)
return false;
size_t first_index = (search_type == ArraySearchType::StartsWith) ? 0 : first.size - second.size;
for (size_t second_index = 0; second_index < second.size; ++second_index, ++first_index)
{
const bool is_first_null = has_first_null_map && first_null_map[first_index];
const bool is_second_null = has_second_null_map && second_null_map[second_index];
if (is_first_null != is_second_null)
return false;
if (!is_first_null && !is_second_null && !isEqual(first, second, first_index, second_index))
return false;
}
return true;
}
/// For details of Knuth-Morris-Pratt string matching algorithm see
/// https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm.
@ -589,6 +614,8 @@ bool sliceHasImpl(const FirstSliceType & first, const SecondSliceType & second,
{
if constexpr (search_type == ArraySearchType::Substr)
return sliceHasImplSubstr<FirstSliceType, SecondSliceType, isEqual, isEqualSecond>(first, second, first_null_map, second_null_map);
else if constexpr (search_type == ArraySearchType::StartsWith || search_type == ArraySearchType::EndsWith)
return sliceHasImplStartsEndsWith<search_type, FirstSliceType, SecondSliceType, isEqual>(first, second, first_null_map, second_null_map);
else
return sliceHasImplAnyAll<search_type, FirstSliceType, SecondSliceType, isEqual>(first, second, first_null_map, second_null_map);
}

View File

@ -34,7 +34,9 @@ enum class ArraySearchType
{
Any, // Corresponds to the hasAny array function
All, // Corresponds to the hasAll array function
Substr // Corresponds to the hasSubstr array function
Substr, // Corresponds to the hasSubstr array function
StartsWith,
EndsWith
};
std::unique_ptr<IArraySource> createArraySource(const ColumnArray & col, bool is_const, size_t total_rows);
@ -58,6 +60,8 @@ ColumnArray::MutablePtr sliceFromRightDynamicLength(IArraySource & src, const IC
void sliceHasAny(IArraySource & first, IArraySource & second, ColumnUInt8 & result);
void sliceHasAll(IArraySource & first, IArraySource & second, ColumnUInt8 & result);
void sliceHasSubstr(IArraySource & first, IArraySource & second, ColumnUInt8 & result);
void sliceHasStartsWith(IArraySource & first, IArraySource & second, ColumnUInt8 & result);
void sliceHasEndsWith(IArraySource & first, IArraySource & second, ColumnUInt8 & result);
inline void sliceHas(IArraySource & first, IArraySource & second, ArraySearchType search_type, ColumnUInt8 & result)
{
@ -72,7 +76,12 @@ inline void sliceHas(IArraySource & first, IArraySource & second, ArraySearchTyp
case ArraySearchType::Substr:
sliceHasSubstr(first, second, result);
break;
case ArraySearchType::StartsWith:
sliceHasStartsWith(first, second, result);
break;
case ArraySearchType::EndsWith:
sliceHasEndsWith(first, second, result);
break;
}
}

View File

@ -0,0 +1,71 @@
#include "GatherUtils.h"
#include "Selectors.h"
#include "Algorithms.h"
namespace DB::GatherUtils
{
namespace
{
struct ArrayEndsWithSelectArraySourcePair : public ArraySourcePairSelector<ArrayEndsWithSelectArraySourcePair>
{
template <typename FirstSource, typename SecondSource>
static void callFunction(FirstSource && first,
bool is_second_const, bool is_second_nullable, SecondSource && second,
ColumnUInt8 & result)
{
using SourceType = typename std::decay<SecondSource>::type;
if (is_second_nullable)
{
using NullableSource = NullableArraySource<SourceType>;
if (is_second_const)
arrayAllAny<ArraySearchType::EndsWith>(first, static_cast<ConstSource<NullableSource> &>(second), result);
else
arrayAllAny<ArraySearchType::EndsWith>(first, static_cast<NullableSource &>(second), result);
}
else
{
if (is_second_const)
arrayAllAny<ArraySearchType::EndsWith>(first, static_cast<ConstSource<SourceType> &>(second), result);
else
arrayAllAny<ArraySearchType::EndsWith>(first, second, result);
}
}
template <typename FirstSource, typename SecondSource>
static void selectSourcePair(bool is_first_const, bool is_first_nullable, FirstSource && first,
bool is_second_const, bool is_second_nullable, SecondSource && second,
ColumnUInt8 & result)
{
using SourceType = typename std::decay<FirstSource>::type;
if (is_first_nullable)
{
using NullableSource = NullableArraySource<SourceType>;
if (is_first_const)
callFunction(static_cast<ConstSource<NullableSource> &>(first), is_second_const, is_second_nullable, second, result);
else
callFunction(static_cast<NullableSource &>(first), is_second_const, is_second_nullable, second, result);
}
else
{
if (is_first_const)
callFunction(static_cast<ConstSource<SourceType> &>(first), is_second_const, is_second_nullable, second, result);
else
callFunction(first, is_second_const, is_second_nullable, second, result);
}
}
};
}
void sliceHasEndsWith(IArraySource & first, IArraySource & second, ColumnUInt8 & result)
{
ArrayEndsWithSelectArraySourcePair::select(first, second, result);
}
}

View File

@ -0,0 +1,71 @@
#include "GatherUtils.h"
#include "Selectors.h"
#include "Algorithms.h"
namespace DB::GatherUtils
{
namespace
{
struct ArrayStartsWithSelectArraySourcePair : public ArraySourcePairSelector<ArrayStartsWithSelectArraySourcePair>
{
template <typename FirstSource, typename SecondSource>
static void callFunction(FirstSource && first,
bool is_second_const, bool is_second_nullable, SecondSource && second,
ColumnUInt8 & result)
{
using SourceType = typename std::decay<SecondSource>::type;
if (is_second_nullable)
{
using NullableSource = NullableArraySource<SourceType>;
if (is_second_const)
arrayAllAny<ArraySearchType::StartsWith>(first, static_cast<ConstSource<NullableSource> &>(second), result);
else
arrayAllAny<ArraySearchType::StartsWith>(first, static_cast<NullableSource &>(second), result);
}
else
{
if (is_second_const)
arrayAllAny<ArraySearchType::StartsWith>(first, static_cast<ConstSource<SourceType> &>(second), result);
else
arrayAllAny<ArraySearchType::StartsWith>(first, second, result);
}
}
template <typename FirstSource, typename SecondSource>
static void selectSourcePair(bool is_first_const, bool is_first_nullable, FirstSource && first,
bool is_second_const, bool is_second_nullable, SecondSource && second,
ColumnUInt8 & result)
{
using SourceType = typename std::decay<FirstSource>::type;
if (is_first_nullable)
{
using NullableSource = NullableArraySource<SourceType>;
if (is_first_const)
callFunction(static_cast<ConstSource<NullableSource> &>(first), is_second_const, is_second_nullable, second, result);
else
callFunction(static_cast<NullableSource &>(first), is_second_const, is_second_nullable, second, result);
}
else
{
if (is_first_const)
callFunction(static_cast<ConstSource<SourceType> &>(first), is_second_const, is_second_nullable, second, result);
else
callFunction(first, is_second_const, is_second_nullable, second, result);
}
}
};
}
void sliceHasStartsWith(IArraySource & first, IArraySource & second, ColumnUInt8 & result)
{
ArrayStartsWithSelectArraySourcePair::select(first, second, result);
}
}

View File

@ -0,0 +1,30 @@
1
1
0
-
1
1
0
1
0
-
1
0
1
0
-
1
1
0
-
1
1
0
1
0
-
1
0
-
1
1

View File

@ -0,0 +1,36 @@
select startsWith([], []);
select startsWith([1], []);
select startsWith([], [1]);
select '-';
select startsWith([NULL], [NULL]);
select startsWith([NULL], []);
select startsWith([], [NULL]);
select startsWith([NULL, 1], [NULL]);
select startsWith([NULL, 1], [1]);
select '-';
select startsWith([1, 2, 3, 4], [1, 2, 3]);
select startsWith([1, 2, 3, 4], [1, 2, 4]);
select startsWith(['a', 'b', 'c'], ['a', 'b']);
select startsWith(['a', 'b', 'c'], ['b']);
select '-';
select endsWith([], []);
select endsWith([1], []);
select endsWith([], [1]);
select '-';
select endsWith([NULL], [NULL]);
select endsWith([NULL], []);
select endsWith([], [NULL]);
select endsWith([1, NULL], [NULL]);
select endsWith([NULL, 1], [NULL]);
select '-';
select endsWith([1, 2, 3, 4], [3, 4]);
select endsWith([1, 2, 3, 4], [3]);
select '-';
select startsWith([1], emptyArrayUInt8());
select endsWith([1], emptyArrayUInt8());