finish dev

This commit is contained in:
taiyang-li 2023-06-27 18:13:25 +08:00
parent bd0ce5fc0b
commit 0de5fcfbee

View File

@ -1,10 +1,14 @@
#include <Columns/ColumnConst.h>
#include <Columns/ColumnString.h>
#include <DataTypes/DataTypeString.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Functions/PositionImpl.h>
#include <Interpreters/Context_fwd.h>
#include <Columns/ColumnString.h>
#include "base/find_symbols.h"
#include <base/find_symbols.h>
#include <Common/UTF8Helpers.h>
#include <Common/register_objects.h>
namespace DB
{
@ -67,7 +71,7 @@ public:
return std::make_shared<DataTypeString>();
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override
{
ColumnPtr column_string = arguments[0].column;
ColumnPtr column_delim = arguments[1].column;
@ -85,7 +89,8 @@ public:
}
else
{
// TODO
if (UTF8::countCodePoints(reinterpret_cast<const UInt8 *>(delim.data()), delim.size()) != 1)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second argument to {} must be a single UTF-8 character", getName());
}
auto column_res = ColumnString::create();
@ -96,7 +101,7 @@ public:
if (column_string_const)
{
String str = column_string_const->getValue<String>();
constantVector(str, delim[0], column_index.get(), vec_res, offsets_res);
constantVector(str, delim, column_index.get(), vec_res, offsets_res);
}
else
{
@ -108,10 +113,10 @@ public:
if (is_index_const)
{
Int64 index = column_index->getInt(0);
vectorConstant(col_str->getChars(), col_str->getOffsets(), delim[0], index, vec_res, offsets_res);
vectorConstant(col_str->getChars(), col_str->getOffsets(), delim, index, vec_res, offsets_res);
}
else
vectorVector(col_str->getChars(), col_str->getOffsets(), delim[0], column_index.get(), vec_res, offsets_res);
vectorVector(col_str->getChars(), col_str->getOffsets(), delim, column_index.get(), vec_res, offsets_res);
}
}
@ -119,7 +124,7 @@ protected:
/// Kernel for the case where both the source strings and the indexes vary per row.
/// For every row it builds a StringRef into `str_data`, reads that row's index with
/// `index_column->getInt(i)`, picks the substring through either `substringIndex` or
/// `substringIndexUTF8` (selected by `is_utf8`), and appends it via `appendToResultColumn`.
///
/// This listing is a rendered diff: a replaced line and its replacement appear back to back
/// (e.g. the two `delim` parameter lines below) and the `@ ... @` line marks elided lines,
/// so the definition of `rows` is not visible here.
static void vectorVector(
const ColumnString::Chars & str_data,
const ColumnString::Offsets & str_offsets,
char delim,
const String & delim,
const IColumn * index_column,
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets)
@ -128,11 +133,15 @@ protected:
/// Capacity hints only: half of the input bytes, one offset per row.
res_data.reserve(str_data.size() / 2);
res_offsets.reserve(rows);
for (size_t i=0; i<rows; ++i)
/// The searcher is built once, before the loop, and only when `is_utf8` is set.
std::unique_ptr<PositionCaseSensitiveUTF8::SearcherInBigHaystack> searcher
= !is_utf8 ? nullptr : std::make_unique<PositionCaseSensitiveUTF8::SearcherInBigHaystack>(delim);
for (size_t i = 0; i < rows; ++i)
{
/// NOTE(review): the length is measured from str_offsets[i - 1], but the data pointer starts at
/// str_offsets[i]; row i presumably begins at str_offsets[i - 1] - confirm.
StringRef str_ref{&str_data[str_offsets[i]], str_offsets[i] - str_offsets[i - 1] - 1};
Int64 index = index_column->getInt(i);
StringRef res_ref = substringIndex<delim>(str_ref, index);
/// NOTE(review): `delim` is a function parameter here, so `delim[0]` is not a constant expression,
/// while `substringIndex` is declared with `template <char delim>` - confirm how the byte path
/// is meant to receive the delimiter.
StringRef res_ref
= !is_utf8 ? substringIndex<delim[0]>(str_ref, index) : substringIndexUTF8(searcher.get(), str_ref, delim, index);
appendToResultColumn(res_ref, res_data, res_offsets);
}
}
@ -140,7 +149,7 @@ protected:
/// Kernel for the case where the source strings vary per row but `index` is one value
/// shared by all rows. For every row it builds a StringRef into `str_data`, picks the
/// substring through either `substringIndex` or `substringIndexUTF8` (selected by `is_utf8`),
/// and appends it via `appendToResultColumn`.
///
/// This listing is a rendered diff: a replaced line and its replacement appear back to back
/// (e.g. the two `delim` parameter lines below) and the `@ ... @` line marks elided lines,
/// so the definition of `rows` is not visible here.
static void vectorConstant(
const ColumnString::Chars & str_data,
const ColumnString::Offsets & str_offsets,
char delim,
const String & delim,
Int64 index,
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets)
@ -149,17 +158,21 @@ protected:
/// Capacity hints only: half of the input bytes, one offset per row.
res_data.reserve(str_data.size() / 2);
res_offsets.reserve(rows);
/// The searcher is built once, before the loop, and only when `is_utf8` is set.
std::unique_ptr<PositionCaseSensitiveUTF8::SearcherInBigHaystack> searcher
= !is_utf8 ? nullptr : std::make_unique<PositionCaseSensitiveUTF8::SearcherInBigHaystack>(delim);
for (size_t i = 0; i<rows; ++i)
{
/// NOTE(review): the length is measured from str_offsets[i - 1], but the data pointer starts at
/// str_offsets[i]; row i presumably begins at str_offsets[i - 1] - confirm.
StringRef str_ref{&str_data[str_offsets[i]], str_offsets[i] - str_offsets[i - 1] - 1};
StringRef res_ref = substringIndex<delim>(str_ref, index);
/// NOTE(review): `delim` is a function parameter here, so `delim[0]` is not a constant expression,
/// while `substringIndex` is declared with `template <char delim>` - confirm how the byte path
/// is meant to receive the delimiter.
StringRef res_ref
= !is_utf8 ? substringIndex<delim[0]>(str_ref, index) : substringIndexUTF8(searcher.get(), str_ref, delim, index);
appendToResultColumn(res_ref, res_data, res_offsets);
}
}
/// Kernel for the case where the source string `str` is one value shared by all rows
/// while the index varies per row. A single StringRef over `str` is built before the loop;
/// each row reads its index with `index_column->getInt(i)`, picks the substring through either
/// `substringIndex` or `substringIndexUTF8` (selected by `is_utf8`), and appends it via
/// `appendToResultColumn`.
///
/// This listing is a rendered diff: a replaced line and its replacement appear back to back
/// (e.g. the two `delim` parameter lines below) and the `@ ... @` line marks elided lines,
/// so the definition of `rows` is not visible here.
static void constantVector(
const String & str,
char delim,
const String & delim,
const IColumn * index_column,
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets)
@ -168,11 +181,15 @@ protected:
/// Capacity hints only: half of the string length per row, one offset per row.
res_data.reserve(str.size() * rows / 2);
res_offsets.reserve(rows);
/// The searcher is built once, before the loop, and only when `is_utf8` is set.
std::unique_ptr<PositionCaseSensitiveUTF8::SearcherInBigHaystack> searcher
= !is_utf8 ? nullptr : std::make_unique<PositionCaseSensitiveUTF8::SearcherInBigHaystack>(delim);
StringRef str_ref{str.data(), str.size()};
for (size_t i=0; i<rows; ++i)
{
Int64 index = index_column->getInt(i);
StringRef res_ref = substringIndex<delim>(str_ref, index);
/// NOTE(review): `delim` is a function parameter here, so `delim[0]` is not a constant expression,
/// while `substringIndex` is declared with `template <char delim>` - confirm how the byte path
/// is meant to receive the delimiter.
StringRef res_ref
= !is_utf8 ? substringIndex<delim[0]>(str_ref, index) : substringIndexUTF8(searcher.get(), str_ref, delim, index);
appendToResultColumn(res_ref, res_data, res_offsets);
}
}
@ -190,18 +207,68 @@ protected:
res_offsets.emplace_back(res_offset);
}
/// Returns the part of `str_ref` that lies before or after a given occurrence of `delim`.
///
///  - index == 0: an empty reference pointing at the start of the string.
///  - index  > 0: everything to the left of the `index`-th occurrence counted from the left.
///  - index  < 0: everything to the right of the `|index|`-th occurrence counted from the right.
///
/// If the string holds fewer occurrences than requested, the whole string is returned.
///
/// @param searcher  Searcher prepared for `delim`; it is dereferenced whenever index != 0.
/// @param str_ref   Haystack; the result points into the same memory.
/// @param delim     Delimiter bytes; only its size is used here, matching is done by `searcher`.
/// @param index     Signed occurrence number as described above.
static StringRef substringIndexUTF8(
    const PositionCaseSensitiveUTF8::SearcherInBigHaystack * searcher, const StringRef & str_ref, const String & delim, Int64 index)
{
    if (index == 0)
        return {str_ref.data, 0};

    const auto * begin = reinterpret_cast<const UInt8 *>(str_ref.data);
    const auto * end = reinterpret_cast<const UInt8 *>(str_ref.data + str_ref.size);
    const auto * pos = begin;

    if (index > 0)
    {
        Int64 i = 0;
        while (i < index)
        {
            pos = searcher->search(pos, end - pos);
            if (pos == end)
                return str_ref; /// Fewer than `index` delimiters: keep the whole string.

            /// Step over the match so the next search starts behind it.
            pos += delim.size();
            ++i;
        }
        /// `pos` is just past the index-th delimiter; the result stops right before it.
        return {begin, static_cast<size_t>(pos - begin - delim.size())};
    }
    else
    {
        /// First pass: count all occurrences so the negative index can be converted
        /// into a count from the left.
        Int64 total = 0;
        while (pos < end && end != (pos = searcher->search(pos, end - pos)))
        {
            pos += delim.size();
            ++total;
        }

        if (total + index < 0)
            return str_ref; /// Fewer than |index| delimiters: keep the whole string.

        /// Second pass: skip exactly `index_from_left` delimiters. The counter is tested
        /// before searching; otherwise `pos` would be overwritten by one extra search and
        /// end up on the next delimiter (or on `end`) instead of right after the wanted one.
        const Int64 index_from_left = total + 1 + index;
        pos = begin;
        Int64 i = 0;
        while (i < index_from_left && pos < end && end != (pos = searcher->search(pos, end - pos)))
        {
            pos += delim.size();
            ++i;
        }
        return {pos, static_cast<size_t>(end - pos)};
    }
}
/// Single-byte delimiter variant: returns the part of the string before (index > 0) or
/// after (index < 0) an occurrence of the compile-time character `delim`; index == 0 yields
/// an empty reference at the start of the string. Both branches return the whole input
/// from their `else` path.
///
/// This listing is a rendered diff: a replaced line and its replacement appear back to back
/// (`str` is the old parameter name, `str_ref` the new one) and the `@ ... @` lines mark
/// elided lines, so parts of both loop bodies are not visible here.
template <char delim>
static StringRef substringIndex(
const StringRef & str,
const StringRef & str_ref,
Int64 index)
{
if (index == 0)
return {str.data, 0};
return {str_ref.data, 0};
if (index > 0)
{
const auto * end = str.data + str.size;
const auto * pos = str.data;
const auto * end = str_ref.data + str_ref.size;
const auto * pos = str_ref.data;
Int64 i = 0;
while (i < index)
{
@ -213,18 +280,18 @@ protected:
++i;
}
else
return str;
return str_ref;
}
/// The replacement line subtracts one more byte than the replaced line;
/// presumably `pos` sits just past the matched delimiter at this point - confirm against the elided lines.
return {str.data, static_cast<size_t>(pos - str.data)};
return {str_ref.data, static_cast<size_t>(pos - str_ref.data - 1)};
}
else
{
const auto * begin = str.data;
const auto * pos = str.data + str.size;
const auto * begin = str_ref.data;
const auto * pos = str_ref.data + str_ref.size;
/// NOTE(review): this branch is only reached with index < 0 and `i` starts at 0, so
/// `i < index` is already false on entry and the loop body is never executed; the return
/// below would then use `pos` equal to the end of the string - confirm the intended condition.
Int64 i = 0;
while (i < index)
{
/// Backward search for `delim` within [begin, pos).
const auto * next_pos = detail::find_last_symbols_sse2<true, detail::ReturnMode::End, delim>(begin, pos);
const auto * next_pos = ::detail::find_last_symbols_sse2<true, ::detail::ReturnMode::End, delim>(begin, pos);
if (next_pos != pos)
{
@ -232,14 +299,24 @@ protected:
++i;
}
else
return str;
return str_ref;
}
return {pos + 1, static_cast<size_t>(str.data + str.size - pos - 1)};
return {pos + 1, static_cast<size_t>(str_ref.data + str_ref.size - pos - 1)};
}
}
};
}
/// Registers both instantiations of FunctionSubstringIndex with the function factory
/// and exposes a case-insensitive alias "SUBSTRING_INDEX" for "substringIndex".
REGISTER_FUNCTION(SubstringIndex)
{
    using BytewiseVariant = FunctionSubstringIndex<false>; /// substringIndex
    using UTF8Variant = FunctionSubstringIndex<true>; /// substringIndexUTF8

    factory.registerFunction<BytewiseVariant>();
    factory.registerFunction<UTF8Variant>();

    /// The alias is added after the function it points to has been registered.
    factory.registerAlias("SUBSTRING_INDEX", "substringIndex", FunctionFactory::CaseInsensitive);
}
}