diff --git a/src/Functions/FunctionStartsEndsWith.h b/src/Functions/FunctionStartsEndsWith.h
index 5a3aba62f26..bbe1631fdf9 100644
--- a/src/Functions/FunctionStartsEndsWith.h
+++ b/src/Functions/FunctionStartsEndsWith.h
@@ -1,4 +1,6 @@
#pragma once
+#include
+
#include
#include
#include
@@ -7,7 +9,9 @@
#include
#include
#include
+#include
#include
+#include
namespace DB
{
@@ -17,6 +21,7 @@ using namespace GatherUtils;
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
+ extern const int LOGICAL_ERROR;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
@@ -59,16 +64,65 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
- if (!isStringOrFixedString(arguments[0]))
- throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
+ if (isStringOrFixedString(arguments[0]) && isStringOrFixedString(arguments[1]))
+ return std::make_shared();
- if (!isStringOrFixedString(arguments[1]))
- throw Exception("Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
+ if (isArray(arguments[0]) && isArray(arguments[1]))
+ return std::make_shared();
- return std::make_shared();
+ throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+ "Illegal types {} {} of arguments of function {}. Both must be String or Array",
+ arguments[0]->getName(), arguments[1]->getName(), getName());
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
+ {
+ auto data_type = arguments[0].type;
+ if (isStringOrFixedString(*data_type))
+ return executeImplString(arguments, {}, input_rows_count);
+ if (isArray(data_type))
+ return executeImplArray(arguments, {}, input_rows_count);
+ return {};
+ }
+
+private:
+ ColumnPtr executeImplArray(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const
+ {
+ DataTypePtr common_type = getLeastSupertype(collections::map(arguments, [](auto & arg) { return arg.type; }));
+
+ Columns preprocessed_columns(2);
+ for (size_t i = 0; i < 2; ++i)
+ preprocessed_columns[i] = castColumn(arguments[i], common_type);
+
+ std::vector> sources;
+ for (auto & argument_column : preprocessed_columns)
+ {
+ bool is_const = false;
+
+ if (const auto * argument_column_const = typeid_cast(argument_column.get()))
+ {
+ is_const = true;
+ argument_column = argument_column_const->getDataColumnPtr();
+ }
+
+ if (const auto * argument_column_array = typeid_cast(argument_column.get()))
+ sources.emplace_back(GatherUtils::createArraySource(*argument_column_array, is_const, input_rows_count));
+ else
+ throw Exception{"Arguments for function " + getName() + " must be arrays.", ErrorCodes::LOGICAL_ERROR};
+ }
+
+ auto result_column = ColumnUInt8::create(input_rows_count);
+ auto * result_column_ptr = typeid_cast(result_column.get());
+
+ if constexpr (std::is_same_v)
+ GatherUtils::sliceHas(*sources[0], *sources[1], GatherUtils::ArraySearchType::StartsWith, *result_column_ptr);
+ else
+ GatherUtils::sliceHas(*sources[0], *sources[1], GatherUtils::ArraySearchType::EndsWith, *result_column_ptr);
+
+ return result_column;
+ }
+
+ ColumnPtr executeImplString(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const
{
const IColumn * haystack_column = arguments[0].column.get();
const IColumn * needle_column = arguments[1].column.get();
@@ -92,7 +146,6 @@ public:
return col_res;
}
-private:
template
void dispatch(HaystackSource haystack_source, const IColumn * needle_column, PaddedPODArray & res_data) const
{
diff --git a/src/Functions/GatherUtils/Algorithms.h b/src/Functions/GatherUtils/Algorithms.h
index 046e2dcf70f..2d4544b2167 100644
--- a/src/Functions/GatherUtils/Algorithms.h
+++ b/src/Functions/GatherUtils/Algorithms.h
@@ -496,6 +496,31 @@ bool sliceHasImplAnyAll(const FirstSliceType & first, const SecondSliceType & se
return search_type == ArraySearchType::All;
}
+template <
+ ArraySearchType search_type,
+ typename FirstSliceType,
+ typename SecondSliceType,
+ bool (*isEqual)(const FirstSliceType &, const SecondSliceType &, size_t, size_t)>
+bool sliceHasImplStartsEndsWith(const FirstSliceType & first, const SecondSliceType & second, const UInt8 * first_null_map, const UInt8 * second_null_map)
+{
+ const bool has_first_null_map = first_null_map != nullptr;
+ const bool has_second_null_map = second_null_map != nullptr;
+
+ if (first.size < second.size)
+ return false;
+
+ size_t first_index = (search_type == ArraySearchType::StartsWith) ? 0 : first.size - second.size;
+ for (size_t second_index = 0; second_index < second.size; ++second_index, ++first_index)
+ {
+ const bool is_first_null = has_first_null_map && first_null_map[first_index];
+ const bool is_second_null = has_second_null_map && second_null_map[second_index];
+ if (is_first_null != is_second_null)
+ return false;
+ if (!is_first_null && !is_second_null && !isEqual(first, second, first_index, second_index))
+ return false;
+ }
+ return true;
+}
/// For details of Knuth-Morris-Pratt string matching algorithm see
/// https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm.
@@ -589,6 +614,8 @@ bool sliceHasImpl(const FirstSliceType & first, const SecondSliceType & second,
{
if constexpr (search_type == ArraySearchType::Substr)
return sliceHasImplSubstr(first, second, first_null_map, second_null_map);
+ else if constexpr (search_type == ArraySearchType::StartsWith || search_type == ArraySearchType::EndsWith)
+ return sliceHasImplStartsEndsWith(first, second, first_null_map, second_null_map);
else
return sliceHasImplAnyAll(first, second, first_null_map, second_null_map);
}
diff --git a/src/Functions/GatherUtils/GatherUtils.h b/src/Functions/GatherUtils/GatherUtils.h
index 8a623caa297..52a01b6ff62 100644
--- a/src/Functions/GatherUtils/GatherUtils.h
+++ b/src/Functions/GatherUtils/GatherUtils.h
@@ -34,7 +34,9 @@ enum class ArraySearchType
{
Any, // Corresponds to the hasAny array function
All, // Corresponds to the hasAll array function
- Substr // Corresponds to the hasSubstr array function
+ Substr, // Corresponds to the hasSubstr array function
+ StartsWith,
+ EndsWith
};
std::unique_ptr createArraySource(const ColumnArray & col, bool is_const, size_t total_rows);
@@ -58,6 +60,8 @@ ColumnArray::MutablePtr sliceFromRightDynamicLength(IArraySource & src, const IC
void sliceHasAny(IArraySource & first, IArraySource & second, ColumnUInt8 & result);
void sliceHasAll(IArraySource & first, IArraySource & second, ColumnUInt8 & result);
void sliceHasSubstr(IArraySource & first, IArraySource & second, ColumnUInt8 & result);
+void sliceHasStartsWith(IArraySource & first, IArraySource & second, ColumnUInt8 & result);
+void sliceHasEndsWith(IArraySource & first, IArraySource & second, ColumnUInt8 & result);
inline void sliceHas(IArraySource & first, IArraySource & second, ArraySearchType search_type, ColumnUInt8 & result)
{
@@ -72,7 +76,12 @@ inline void sliceHas(IArraySource & first, IArraySource & second, ArraySearchTyp
case ArraySearchType::Substr:
sliceHasSubstr(first, second, result);
break;
-
+ case ArraySearchType::StartsWith:
+ sliceHasStartsWith(first, second, result);
+ break;
+ case ArraySearchType::EndsWith:
+ sliceHasEndsWith(first, second, result);
+ break;
}
}
diff --git a/src/Functions/GatherUtils/ends_with.cpp b/src/Functions/GatherUtils/ends_with.cpp
new file mode 100644
index 00000000000..579d903005a
--- /dev/null
+++ b/src/Functions/GatherUtils/ends_with.cpp
@@ -0,0 +1,71 @@
+#include "GatherUtils.h"
+#include "Selectors.h"
+#include "Algorithms.h"
+
+namespace DB::GatherUtils
+{
+
+namespace
+{
+
+struct ArrayEndsWithSelectArraySourcePair : public ArraySourcePairSelector
+{
+ template
+ static void callFunction(FirstSource && first,
+ bool is_second_const, bool is_second_nullable, SecondSource && second,
+ ColumnUInt8 & result)
+ {
+ using SourceType = typename std::decay::type;
+
+ if (is_second_nullable)
+ {
+ using NullableSource = NullableArraySource;
+
+ if (is_second_const)
+ arrayAllAny(first, static_cast &>(second), result);
+ else
+ arrayAllAny(first, static_cast(second), result);
+ }
+ else
+ {
+ if (is_second_const)
+ arrayAllAny(first, static_cast &>(second), result);
+ else
+ arrayAllAny(first, second, result);
+ }
+ }
+
+ template
+ static void selectSourcePair(bool is_first_const, bool is_first_nullable, FirstSource && first,
+ bool is_second_const, bool is_second_nullable, SecondSource && second,
+ ColumnUInt8 & result)
+ {
+ using SourceType = typename std::decay::type;
+
+ if (is_first_nullable)
+ {
+ using NullableSource = NullableArraySource;
+
+ if (is_first_const)
+ callFunction(static_cast &>(first), is_second_const, is_second_nullable, second, result);
+ else
+ callFunction(static_cast(first), is_second_const, is_second_nullable, second, result);
+ }
+ else
+ {
+ if (is_first_const)
+ callFunction(static_cast &>(first), is_second_const, is_second_nullable, second, result);
+ else
+ callFunction(first, is_second_const, is_second_nullable, second, result);
+ }
+ }
+};
+
+}
+
+void sliceHasEndsWith(IArraySource & first, IArraySource & second, ColumnUInt8 & result)
+{
+ ArrayEndsWithSelectArraySourcePair::select(first, second, result);
+}
+
+}
diff --git a/src/Functions/GatherUtils/starts_with.cpp b/src/Functions/GatherUtils/starts_with.cpp
new file mode 100644
index 00000000000..813294bc092
--- /dev/null
+++ b/src/Functions/GatherUtils/starts_with.cpp
@@ -0,0 +1,71 @@
+#include "GatherUtils.h"
+#include "Selectors.h"
+#include "Algorithms.h"
+
+namespace DB::GatherUtils
+{
+
+namespace
+{
+
+struct ArrayStartsWithSelectArraySourcePair : public ArraySourcePairSelector
+{
+ template
+ static void callFunction(FirstSource && first,
+ bool is_second_const, bool is_second_nullable, SecondSource && second,
+ ColumnUInt8 & result)
+ {
+ using SourceType = typename std::decay::type;
+
+ if (is_second_nullable)
+ {
+ using NullableSource = NullableArraySource;
+
+ if (is_second_const)
+ arrayAllAny(first, static_cast &>(second), result);
+ else
+ arrayAllAny(first, static_cast(second), result);
+ }
+ else
+ {
+ if (is_second_const)
+ arrayAllAny(first, static_cast &>(second), result);
+ else
+ arrayAllAny(first, second, result);
+ }
+ }
+
+ template
+ static void selectSourcePair(bool is_first_const, bool is_first_nullable, FirstSource && first,
+ bool is_second_const, bool is_second_nullable, SecondSource && second,
+ ColumnUInt8 & result)
+ {
+ using SourceType = typename std::decay::type;
+
+ if (is_first_nullable)
+ {
+ using NullableSource = NullableArraySource;
+
+ if (is_first_const)
+ callFunction(static_cast &>(first), is_second_const, is_second_nullable, second, result);
+ else
+ callFunction(static_cast(first), is_second_const, is_second_nullable, second, result);
+ }
+ else
+ {
+ if (is_first_const)
+ callFunction(static_cast &>(first), is_second_const, is_second_nullable, second, result);
+ else
+ callFunction(first, is_second_const, is_second_nullable, second, result);
+ }
+ }
+};
+
+}
+
+void sliceHasStartsWith(IArraySource & first, IArraySource & second, ColumnUInt8 & result)
+{
+ ArrayStartsWithSelectArraySourcePair::select(first, second, result);
+}
+
+}
diff --git a/tests/queries/0_stateless/02206_array_starts_ends_with.reference b/tests/queries/0_stateless/02206_array_starts_ends_with.reference
new file mode 100644
index 00000000000..e0dacfc06e0
--- /dev/null
+++ b/tests/queries/0_stateless/02206_array_starts_ends_with.reference
@@ -0,0 +1,30 @@
+1
+1
+0
+-
+1
+1
+0
+1
+0
+-
+1
+0
+1
+0
+-
+1
+1
+0
+-
+1
+1
+0
+1
+0
+-
+1
+0
+-
+1
+1
diff --git a/tests/queries/0_stateless/02206_array_starts_ends_with.sql b/tests/queries/0_stateless/02206_array_starts_ends_with.sql
new file mode 100644
index 00000000000..39b02c29dc0
--- /dev/null
+++ b/tests/queries/0_stateless/02206_array_starts_ends_with.sql
@@ -0,0 +1,36 @@
+select startsWith([], []);
+select startsWith([1], []);
+select startsWith([], [1]);
+select '-';
+
+select startsWith([NULL], [NULL]);
+select startsWith([NULL], []);
+select startsWith([], [NULL]);
+select startsWith([NULL, 1], [NULL]);
+select startsWith([NULL, 1], [1]);
+select '-';
+
+select startsWith([1, 2, 3, 4], [1, 2, 3]);
+select startsWith([1, 2, 3, 4], [1, 2, 4]);
+select startsWith(['a', 'b', 'c'], ['a', 'b']);
+select startsWith(['a', 'b', 'c'], ['b']);
+select '-';
+
+select endsWith([], []);
+select endsWith([1], []);
+select endsWith([], [1]);
+select '-';
+
+select endsWith([NULL], [NULL]);
+select endsWith([NULL], []);
+select endsWith([], [NULL]);
+select endsWith([1, NULL], [NULL]);
+select endsWith([NULL, 1], [NULL]);
+select '-';
+
+select endsWith([1, 2, 3, 4], [3, 4]);
+select endsWith([1, 2, 3, 4], [3]);
+select '-';
+
+select startsWith([1], emptyArrayUInt8());
+select endsWith([1], emptyArrayUInt8());