diff --git a/src/Functions/FunctionStartsEndsWith.h b/src/Functions/FunctionStartsEndsWith.h index 5a3aba62f26..bbe1631fdf9 100644 --- a/src/Functions/FunctionStartsEndsWith.h +++ b/src/Functions/FunctionStartsEndsWith.h @@ -1,4 +1,6 @@ #pragma once +#include + #include #include #include @@ -7,7 +9,9 @@ #include #include #include +#include #include +#include namespace DB { @@ -17,6 +21,7 @@ using namespace GatherUtils; namespace ErrorCodes { extern const int ILLEGAL_COLUMN; + extern const int LOGICAL_ERROR; extern const int ILLEGAL_TYPE_OF_ARGUMENT; } @@ -59,16 +64,65 @@ public: DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { - if (!isStringOrFixedString(arguments[0])) - throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if (isStringOrFixedString(arguments[0]) && isStringOrFixedString(arguments[1])) + return std::make_shared(); - if (!isStringOrFixedString(arguments[1])) - throw Exception("Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if (isArray(arguments[0]) && isArray(arguments[1])) + return std::make_shared(); - return std::make_shared(); + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal types {} {} of arguments of function {}. Both must be String or Array", + arguments[0]->getName(), arguments[1]->getName(), getName()); } ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + auto data_type = arguments[0].type; + if (isStringOrFixedString(*data_type)) + return executeImplString(arguments, {}, input_rows_count); + if (isArray(data_type)) + return executeImplArray(arguments, {}, input_rows_count); + return {}; + } + +private: + ColumnPtr executeImplArray(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const + { + DataTypePtr common_type = getLeastSupertype(collections::map(arguments, [](auto & arg) { return arg.type; })); + + Columns preprocessed_columns(2); + for (size_t i = 0; i < 2; ++i) + preprocessed_columns[i] = castColumn(arguments[i], common_type); + + std::vector> sources; + for (auto & argument_column : preprocessed_columns) + { + bool is_const = false; + + if (const auto * argument_column_const = typeid_cast(argument_column.get())) + { + is_const = true; + argument_column = argument_column_const->getDataColumnPtr(); + } + + if (const auto * argument_column_array = typeid_cast(argument_column.get())) + sources.emplace_back(GatherUtils::createArraySource(*argument_column_array, is_const, input_rows_count)); + else + throw Exception{"Arguments for function " + getName() + " must be arrays.", ErrorCodes::LOGICAL_ERROR}; + } + + auto result_column = ColumnUInt8::create(input_rows_count); + auto * result_column_ptr = typeid_cast(result_column.get()); + + if constexpr (std::is_same_v) + GatherUtils::sliceHas(*sources[0], *sources[1], GatherUtils::ArraySearchType::StartsWith, *result_column_ptr); + else + GatherUtils::sliceHas(*sources[0], *sources[1], GatherUtils::ArraySearchType::EndsWith, *result_column_ptr); + + return result_column; + } + + ColumnPtr executeImplString(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const { const IColumn * haystack_column = arguments[0].column.get(); const IColumn * needle_column = arguments[1].column.get(); @@ -92,7 +146,6 @@ public: return col_res; } -private: template void dispatch(HaystackSource haystack_source, const IColumn * needle_column, PaddedPODArray & res_data) const { diff --git a/src/Functions/GatherUtils/Algorithms.h b/src/Functions/GatherUtils/Algorithms.h index 046e2dcf70f..2d4544b2167 100644 --- a/src/Functions/GatherUtils/Algorithms.h +++ b/src/Functions/GatherUtils/Algorithms.h @@ -496,6 +496,31 @@ bool sliceHasImplAnyAll(const FirstSliceType & first, const SecondSliceType & se return search_type == ArraySearchType::All; } +template < + ArraySearchType search_type, + typename FirstSliceType, + typename SecondSliceType, + bool (*isEqual)(const FirstSliceType &, const SecondSliceType &, size_t, size_t)> +bool sliceHasImplStartsEndsWith(const FirstSliceType & first, const SecondSliceType & second, const UInt8 * first_null_map, const UInt8 * second_null_map) +{ + const bool has_first_null_map = first_null_map != nullptr; + const bool has_second_null_map = second_null_map != nullptr; + + if (first.size < second.size) + return false; + + size_t first_index = (search_type == ArraySearchType::StartsWith) ? 0 : first.size - second.size; + for (size_t second_index = 0; second_index < second.size; ++second_index, ++first_index) + { + const bool is_first_null = has_first_null_map && first_null_map[first_index]; + const bool is_second_null = has_second_null_map && second_null_map[second_index]; + if (is_first_null != is_second_null) + return false; + if (!is_first_null && !is_second_null && !isEqual(first, second, first_index, second_index)) + return false; + } + return true; +} /// For details of Knuth-Morris-Pratt string matching algorithm see /// https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm. @@ -589,6 +614,8 @@ bool sliceHasImpl(const FirstSliceType & first, const SecondSliceType & second, { if constexpr (search_type == ArraySearchType::Substr) return sliceHasImplSubstr(first, second, first_null_map, second_null_map); + else if constexpr (search_type == ArraySearchType::StartsWith || search_type == ArraySearchType::EndsWith) + return sliceHasImplStartsEndsWith(first, second, first_null_map, second_null_map); else return sliceHasImplAnyAll(first, second, first_null_map, second_null_map); } diff --git a/src/Functions/GatherUtils/GatherUtils.h b/src/Functions/GatherUtils/GatherUtils.h index 8a623caa297..52a01b6ff62 100644 --- a/src/Functions/GatherUtils/GatherUtils.h +++ b/src/Functions/GatherUtils/GatherUtils.h @@ -34,7 +34,9 @@ enum class ArraySearchType { Any, // Corresponds to the hasAny array function All, // Corresponds to the hasAll array function - Substr // Corresponds to the hasSubstr array function + Substr, // Corresponds to the hasSubstr array function + StartsWith, + EndsWith }; std::unique_ptr createArraySource(const ColumnArray & col, bool is_const, size_t total_rows); @@ -58,6 +60,8 @@ ColumnArray::MutablePtr sliceFromRightDynamicLength(IArraySource & src, const IC void sliceHasAny(IArraySource & first, IArraySource & second, ColumnUInt8 & result); void sliceHasAll(IArraySource & first, IArraySource & second, ColumnUInt8 & result); void sliceHasSubstr(IArraySource & first, IArraySource & second, ColumnUInt8 & result); +void sliceHasStartsWith(IArraySource & first, IArraySource & second, ColumnUInt8 & result); +void sliceHasEndsWith(IArraySource & first, IArraySource & second, ColumnUInt8 & result); inline void sliceHas(IArraySource & first, IArraySource & second, ArraySearchType search_type, ColumnUInt8 & result) { @@ -72,7 +76,12 @@ inline void sliceHas(IArraySource & first, IArraySource & second, ArraySearchTyp case ArraySearchType::Substr: sliceHasSubstr(first, second, result); break; - + case ArraySearchType::StartsWith: + sliceHasStartsWith(first, second, result); + break; + case ArraySearchType::EndsWith: + sliceHasEndsWith(first, second, result); + break; } } diff --git a/src/Functions/GatherUtils/ends_with.cpp b/src/Functions/GatherUtils/ends_with.cpp new file mode 100644 index 00000000000..579d903005a --- /dev/null +++ b/src/Functions/GatherUtils/ends_with.cpp @@ -0,0 +1,71 @@ +#include "GatherUtils.h" +#include "Selectors.h" +#include "Algorithms.h" + +namespace DB::GatherUtils +{ + +namespace +{ + +struct ArrayEndsWithSelectArraySourcePair : public ArraySourcePairSelector +{ + template + static void callFunction(FirstSource && first, + bool is_second_const, bool is_second_nullable, SecondSource && second, + ColumnUInt8 & result) + { + using SourceType = typename std::decay::type; + + if (is_second_nullable) + { + using NullableSource = NullableArraySource; + + if (is_second_const) + arrayAllAny(first, static_cast &>(second), result); + else + arrayAllAny(first, static_cast(second), result); + } + else + { + if (is_second_const) + arrayAllAny(first, static_cast &>(second), result); + else + arrayAllAny(first, second, result); + } + } + + template + static void selectSourcePair(bool is_first_const, bool is_first_nullable, FirstSource && first, + bool is_second_const, bool is_second_nullable, SecondSource && second, + ColumnUInt8 & result) + { + using SourceType = typename std::decay::type; + + if (is_first_nullable) + { + using NullableSource = NullableArraySource; + + if (is_first_const) + callFunction(static_cast &>(first), is_second_const, is_second_nullable, second, result); + else + callFunction(static_cast(first), is_second_const, is_second_nullable, second, result); + } + else + { + if (is_first_const) + callFunction(static_cast &>(first), is_second_const, is_second_nullable, second, result); + else + callFunction(first, is_second_const, is_second_nullable, second, result); + } + } +}; + +} + +void sliceHasEndsWith(IArraySource & first, IArraySource & second, ColumnUInt8 & result) +{ + ArrayEndsWithSelectArraySourcePair::select(first, second, result); +} + +} diff --git a/src/Functions/GatherUtils/starts_with.cpp b/src/Functions/GatherUtils/starts_with.cpp new file mode 100644 index 00000000000..813294bc092 --- /dev/null +++ b/src/Functions/GatherUtils/starts_with.cpp @@ -0,0 +1,71 @@ +#include "GatherUtils.h" +#include "Selectors.h" +#include "Algorithms.h" + +namespace DB::GatherUtils +{ + +namespace +{ + +struct ArrayStartsWithSelectArraySourcePair : public ArraySourcePairSelector +{ + template + static void callFunction(FirstSource && first, + bool is_second_const, bool is_second_nullable, SecondSource && second, + ColumnUInt8 & result) + { + using SourceType = typename std::decay::type; + + if (is_second_nullable) + { + using NullableSource = NullableArraySource; + + if (is_second_const) + arrayAllAny(first, static_cast &>(second), result); + else + arrayAllAny(first, static_cast(second), result); + } + else + { + if (is_second_const) + arrayAllAny(first, static_cast &>(second), result); + else + arrayAllAny(first, second, result); + } + } + + template + static void selectSourcePair(bool is_first_const, bool is_first_nullable, FirstSource && first, + bool is_second_const, bool is_second_nullable, SecondSource && second, + ColumnUInt8 & result) + { + using SourceType = typename std::decay::type; + + if (is_first_nullable) + { + using NullableSource = NullableArraySource; + + if (is_first_const) + callFunction(static_cast &>(first), is_second_const, is_second_nullable, second, result); + else + callFunction(static_cast(first), is_second_const, is_second_nullable, second, result); + } + else + { + if (is_first_const) + callFunction(static_cast &>(first), is_second_const, is_second_nullable, second, result); + else + callFunction(first, is_second_const, is_second_nullable, second, result); + } + } +}; + +} + +void sliceHasStartsWith(IArraySource & first, IArraySource & second, ColumnUInt8 & result) +{ + ArrayStartsWithSelectArraySourcePair::select(first, second, result); +} + +} diff --git a/tests/queries/0_stateless/02206_array_starts_ends_with.reference b/tests/queries/0_stateless/02206_array_starts_ends_with.reference new file mode 100644 index 00000000000..e0dacfc06e0 --- /dev/null +++ b/tests/queries/0_stateless/02206_array_starts_ends_with.reference @@ -0,0 +1,30 @@ +1 +1 +0 +- +1 +1 +0 +1 +0 +- +1 +0 +1 +0 +- +1 +1 +0 +- +1 +1 +0 +1 +0 +- +1 +0 +- +1 +1 diff --git a/tests/queries/0_stateless/02206_array_starts_ends_with.sql b/tests/queries/0_stateless/02206_array_starts_ends_with.sql new file mode 100644 index 00000000000..39b02c29dc0 --- /dev/null +++ b/tests/queries/0_stateless/02206_array_starts_ends_with.sql @@ -0,0 +1,36 @@ +select startsWith([], []); +select startsWith([1], []); +select startsWith([], [1]); +select '-'; + +select startsWith([NULL], [NULL]); +select startsWith([NULL], []); +select startsWith([], [NULL]); +select startsWith([NULL, 1], [NULL]); +select startsWith([NULL, 1], [1]); +select '-'; + +select startsWith([1, 2, 3, 4], [1, 2, 3]); +select startsWith([1, 2, 3, 4], [1, 2, 4]); +select startsWith(['a', 'b', 'c'], ['a', 'b']); +select startsWith(['a', 'b', 'c'], ['b']); +select '-'; + +select endsWith([], []); +select endsWith([1], []); +select endsWith([], [1]); +select '-'; + +select endsWith([NULL], [NULL]); +select endsWith([NULL], []); +select endsWith([], [NULL]); +select endsWith([1, NULL], [NULL]); +select endsWith([NULL, 1], [NULL]); +select '-'; + +select endsWith([1, 2, 3, 4], [3, 4]); +select endsWith([1, 2, 3, 4], [3]); +select '-'; + +select startsWith([1], emptyArrayUInt8()); +select endsWith([1], emptyArrayUInt8());