diff --git a/docs/en/sql-reference/functions/string-search-functions.md b/docs/en/sql-reference/functions/string-search-functions.md
index dba8a6e275c..3cde7dd71d4 100644
--- a/docs/en/sql-reference/functions/string-search-functions.md
+++ b/docs/en/sql-reference/functions/string-search-functions.md
@@ -591,3 +591,7 @@ Result:
 ```
 
 [Original article](https://clickhouse.tech/docs/en/query_language/functions/string_search_functions/)
+
+## countMatches(haystack, pattern) {#countmatcheshaystack-pattern}
+
+Returns the number of non-overlapping regular expression matches for a `pattern` in a `haystack`. The `pattern` must be a constant string.
diff --git a/src/Functions/countMatches.cpp b/src/Functions/countMatches.cpp
new file mode 100644
index 00000000000..935b9fb9904
--- /dev/null
+++ b/src/Functions/countMatches.cpp
@@ -0,0 +1,13 @@
+#include "FunctionFactory.h"
+#include "countMatchesImpl.h"
+
+
+namespace DB
+{
+
+void registerFunctionCountMatches(FunctionFactory & factory)
+{
+    factory.registerFunction(FunctionFactory::CaseInsensitive);
+}
+
+}
diff --git a/src/Functions/countMatchesImpl.h b/src/Functions/countMatchesImpl.h
new file mode 100644
index 00000000000..86e25a252ce
--- /dev/null
+++ b/src/Functions/countMatchesImpl.h
@@ -0,0 +1,131 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+    extern const int ILLEGAL_COLUMN;
+}
+
+using Pos = const char *;
+
+/** countMatches(haystack, pattern)
+  * Counts non-overlapping matches of the regular expression `pattern` in `haystack`.
+  * `pattern` must be a constant string; `haystack` may be a String column or a constant string column.
+  * Counting stops at the first empty match, because an empty match does not advance the search position.
+  * The result is UInt64.
+  */
+class FunctionCountMatches : public IFunction
+{
+public:
+    static constexpr auto name = "countMatches";
+    static FunctionPtr create(const Context &) { return std::make_shared(); }
+
+    String getName() const override
+    {
+        return name;
+    }
+
+    size_t getNumberOfArguments() const override { return 2; }
+
+    /// Only the pattern type is validated here; the haystack column is checked in executeImpl.
+    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
+    {
+        if (!isString(arguments[1]))
+            throw Exception("Illegal type " + arguments[1]->getName() + " of second argument of function " + getName() + ". Must be String.",
+                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
+
+        return std::make_shared();
+    }
+
+    void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) const override
+    {
+        const auto & haystack_column = block.getByPosition(arguments[0]).column;
+        const auto & pattern_column = block.getByPosition(arguments[1]).column;
+
+        /// The regular expression is obtained once, before iterating rows, so the pattern column must be constant.
+        const ColumnConst * col = checkAndGetColumnConstStringOrFixedString(pattern_column.get());
+        if (!col)
+            throw Exception("Illegal column " + pattern_column->getName()
+                + " of second argument of function " + getName() + ". Must be constant string.",
+                ErrorCodes::ILLEGAL_COLUMN);
+
+        Regexps::Pool::Pointer re = Regexps::get(col->getValue());
+        size_t capture = re->getNumberOfSubpatterns() > 0 ? 1 : 0;
+        OptimizedRegularExpression::MatchVec matches;
+        matches.resize(capture + 1);
+
+        const ColumnString * col_str = checkAndGetColumn(haystack_column.get());
+        const ColumnConst * col_const_str = checkAndGetColumnConstStringOrFixedString(haystack_column.get());
+
+        if (col_str)
+        {
+            const ColumnString::Chars & src_chars = col_str->getChars();
+            const ColumnString::Offsets & src_offsets = col_str->getOffsets();
+
+            auto col_res = ColumnUInt64::create();
+            ColumnUInt64::Container & vec_res = col_res->getData();
+            vec_res.resize(src_offsets.size());
+
+            ColumnString::Offset current_src_offset = 0;
+            for (size_t i = 0; i < src_offsets.size(); ++i)
+            {
+                Pos pos = reinterpret_cast(&src_chars[current_src_offset]);
+                current_src_offset = src_offsets[i];
+                /// NOTE(review): the -1 presumably skips a per-row terminating byte stored by ColumnString - confirm.
+                Pos end = reinterpret_cast(&src_chars[current_src_offset]) - 1;
+
+                vec_res[i] = countMatchesInRange(pos, end, re, matches);
+            }
+
+            block.getByPosition(result).column = std::move(col_res);
+        }
+        else if (col_const_str)
+        {
+            String src = col_const_str->getValue();
+
+            Pos pos = reinterpret_cast(src.data());
+            Pos end = reinterpret_cast(src.data() + src.size());
+
+            /// A constant haystack yields a constant result column of the same size.
+            uint64_t match_count = countMatchesInRange(pos, end, re, matches);
+            block.getByPosition(result).column = DataTypeUInt64().createColumnConst(col_const_str->size(), match_count);
+        }
+        else
+            throw Exception("Illegal columns " + haystack_column->getName()
+                + ", " + pattern_column->getName()
+                + " of arguments of function " + getName(),
+                ErrorCodes::ILLEGAL_COLUMN);
+    }
+
+private:
+    /// Counts matches of `re` between `pos` and `end`, resuming the search right after each match.
+    /// Stops at the first failed or empty match: an empty match would not move `pos` forward.
+    static uint64_t countMatchesInRange(Pos pos, Pos end, const Regexps::Pool::Pointer & re, OptimizedRegularExpression::MatchVec & matches)
+    {
+        uint64_t match_count = 0;
+        while (pos && pos <= end)
+        {
+            if (!re->match(pos, end - pos, matches) || !matches[0].length)
+                break;
+            pos += matches[0].offset + matches[0].length;
+            ++match_count;
+        }
+        return match_count;
+    }
+};
+
+}
diff --git a/src/Functions/registerFunctionsString.cpp b/src/Functions/registerFunctionsString.cpp
index 5d4c165e1e3..647f63fe910 100644
--- a/src/Functions/registerFunctionsString.cpp
+++ b/src/Functions/registerFunctionsString.cpp
@@ -32,6 +32,7 @@ void registerFunctionTrim(FunctionFactory &);
 void registerFunctionRegexpQuoteMeta(FunctionFactory &);
 void registerFunctionNormalizeQuery(FunctionFactory &);
 void registerFunctionNormalizedQueryHash(FunctionFactory &);
+void registerFunctionCountMatches(FunctionFactory &);
 
 #if USE_BASE64
 void registerFunctionBase64Encode(FunctionFactory &);
@@ -66,6 +67,7 @@ void registerFunctionsString(FunctionFactory & factory)
     registerFunctionRegexpQuoteMeta(factory);
     registerFunctionNormalizeQuery(factory);
     registerFunctionNormalizedQueryHash(factory);
+    registerFunctionCountMatches(factory);
 #if USE_BASE64
     registerFunctionBase64Encode(factory);
     registerFunctionBase64Decode(factory);