This commit is contained in:
Nikolay Degterinsky 2022-01-10 15:36:32 +00:00
parent fce10091a9
commit 85b8985df2
12 changed files with 241 additions and 507 deletions

View File

@ -36,4 +36,4 @@ endif ()
set (USE_NLP 1) set (USE_NLP 1)
message (STATUS "Using Libraries for NLP functions: contrib/wordnet-blast, contrib/libstemmer_c, contrib/lemmagen-c") message (STATUS "Using Libraries for NLP functions: contrib/wordnet-blast, contrib/libstemmer_c, contrib/lemmagen-c, contrib/cld2")

2
contrib/nlp-data vendored

@ -1 +1 @@
Subproject commit 3bc8aef8440b66823186f47a74996cbdad66c04f Subproject commit 5591f91f5e748cba8fb9ef81564176feae774853

View File

@ -0,0 +1,14 @@
include(${ClickHouse_SOURCE_DIR}/cmake/embed_binary.cmake)
set(LIBRARY_DIR "${ClickHouse_SOURCE_DIR}/contrib/nlp-data")
add_library (nlp_data INTERFACE)
clickhouse_embed_binaries(
TARGET nlp_dicts
RESOURCE_DIR "${LIBRARY_DIR}"
RESOURCES charset.zst tonality_ru.zst programming.zst
)
add_dependencies(nlp_data nlp_dicts)
target_link_libraries(nlp_data INTERFACE "-Wl,${WHOLE_ARCHIVE} $<TARGET_FILE:nlp_dicts> -Wl,${NO_WHOLE_ARCHIVE}")

View File

@ -1,22 +1,20 @@
#pragma once #pragma once
#include <Common/Arena.h>
#include <Common/getResource.h>
#include <Common/HashTable/HashMap.h>
#include <Common/StringUtils/StringUtils.h> #include <Common/StringUtils/StringUtils.h>
#include <IO/ReadBufferFromFile.h> #include <IO/ReadBufferFromFile.h>
#include <IO/ReadBufferFromString.h> #include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <IO/readFloatText.h> #include <IO/readFloatText.h>
#include <IO/Operators.h>
#include <IO/ZstdInflatingReadBuffer.h> #include <IO/ZstdInflatingReadBuffer.h>
#include <Common/Arena.h>
#include <base/StringRef.h> #include <base/StringRef.h>
#include <Common/HashTable/HashMap.h> #include <base/logger_useful.h>
#include <string_view> #include <string_view>
#include <string>
#include <cstring>
#include <unordered_map> #include <unordered_map>
#include <base/logger_useful.h>
#include <Common/getResource.h>
namespace DB namespace DB
{ {
@ -26,6 +24,14 @@ namespace ErrorCodes
extern const int FILE_DOESNT_EXIST; extern const int FILE_DOESNT_EXIST;
} }
/// FrequencyHolder class is responsible for storing and loading dictionaries
/// needed for text classification functions:
///
/// 1. detectLanguageUnknown
/// 2. detectCharset
/// 3. detectTonality
/// 4. detectProgrammingLanguage
class FrequencyHolder class FrequencyHolder
{ {
@ -39,6 +45,7 @@ public:
struct Encoding struct Encoding
{ {
String name; String name;
String lang;
HashMap<UInt16, Float64> map; HashMap<UInt16, Float64> map;
}; };
@ -54,14 +61,13 @@ public:
return instance; return instance;
} }
void loadEncodingsFrequency() void loadEncodingsFrequency()
{ {
Poco::Logger * log = &Poco::Logger::get("EncodingsFrequency"); Poco::Logger * log = &Poco::Logger::get("EncodingsFrequency");
LOG_TRACE(log, "Loading embedded charset frequencies"); LOG_TRACE(log, "Loading embedded charset frequencies");
auto resource = getResource("charset_freq.txt.zst"); auto resource = getResource("charset.zst");
if (resource.empty()) if (resource.empty())
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded charset frequencies"); throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded charset frequencies");
@ -71,12 +77,12 @@ public:
String charset_name; String charset_name;
auto buf = std::make_unique<ReadBufferFromMemory>(resource.data(), resource.size()); auto buf = std::make_unique<ReadBufferFromMemory>(resource.data(), resource.size());
std::unique_ptr<ReadBuffer> in = std::make_unique<ZstdInflatingReadBuffer>(std::move(buf)); ZstdInflatingReadBuffer in(std::move(buf));
while (!in->eof()) while (!in.eof())
{ {
readString(line, *in); readString(line, in);
++in->position(); in.ignore();
if (line.empty()) if (line.empty())
continue; continue;
@ -84,13 +90,21 @@ public:
ReadBufferFromString buf_line(line); ReadBufferFromString buf_line(line);
// Start loading a new charset // Start loading a new charset
if (line.starts_with("//")) if (line.starts_with("// "))
{ {
// Skip "// "
buf_line.ignore(3); buf_line.ignore(3);
readString(charset_name, buf_line); readString(charset_name, buf_line);
/* In our dictionary we have lines with form: <Language>_<Charset>
* If we need to find language of data, we return <Language>
* If we need to find charset of data, we return <Charset>.
*/
size_t sep = charset_name.find('_');
Encoding enc; Encoding enc;
enc.name = charset_name; enc.lang = charset_name.substr(0, sep);
enc.name = charset_name.substr(sep + 1);
encodings_freq.push_back(std::move(enc)); encodings_freq.push_back(std::move(enc));
} }
else else
@ -109,9 +123,9 @@ public:
void loadEmotionalDict() void loadEmotionalDict()
{ {
Poco::Logger * log = &Poco::Logger::get("EmotionalDict"); Poco::Logger * log = &Poco::Logger::get("EmotionalDict");
LOG_TRACE(log, "Loading embedded emotional dictionary (RU)"); LOG_TRACE(log, "Loading embedded emotional dictionary");
auto resource = getResource("emotional_dictionary_rus.txt.zst"); auto resource = getResource("tonality_ru.zst");
if (resource.empty()) if (resource.empty())
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded emotional dictionary"); throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded emotional dictionary");
@ -121,12 +135,12 @@ public:
size_t count = 0; size_t count = 0;
auto buf = std::make_unique<ReadBufferFromMemory>(resource.data(), resource.size()); auto buf = std::make_unique<ReadBufferFromMemory>(resource.data(), resource.size());
std::unique_ptr<ReadBuffer> in = std::make_unique<ZstdInflatingReadBuffer>(std::move(buf)); ZstdInflatingReadBuffer in(std::move(buf));
while (!in->eof()) while (!in.eof())
{ {
readString(line, *in); readString(line, in);
++in->position(); in.ignore();
if (line.empty()) if (line.empty())
continue; continue;
@ -151,7 +165,7 @@ public:
LOG_TRACE(log, "Loading embedded programming languages frequencies loading"); LOG_TRACE(log, "Loading embedded programming languages frequencies loading");
auto resource = getResource("prog_freq.txt.zst"); auto resource = getResource("programming.zst");
if (resource.empty()) if (resource.empty())
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded programming languages frequencies"); throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded programming languages frequencies");
@ -161,12 +175,12 @@ public:
String programming_language; String programming_language;
auto buf = std::make_unique<ReadBufferFromMemory>(resource.data(), resource.size()); auto buf = std::make_unique<ReadBufferFromMemory>(resource.data(), resource.size());
std::unique_ptr<ReadBuffer> in = std::make_unique<ZstdInflatingReadBuffer>(std::move(buf)); ZstdInflatingReadBuffer in(std::move(buf));
while (!in->eof()) while (!in.eof())
{ {
readString(line, *in); readString(line, in);
++in->position(); in.ignore();
if (line.empty()) if (line.empty())
continue; continue;
@ -174,8 +188,9 @@ public:
ReadBufferFromString buf_line(line); ReadBufferFromString buf_line(line);
// Start loading a new language // Start loading a new language
if (line.starts_with("//")) if (line.starts_with("// "))
{ {
// Skip "// "
buf_line.ignore(3); buf_line.ignore(3);
readString(programming_language, buf_line); readString(programming_language, buf_line);

View File

@ -1,18 +1,18 @@
#include <Functions/FunctionsTextClassification.h>
#include <Common/FrequencyHolder.h> #include <Common/FrequencyHolder.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Common/UTF8Helpers.h> #include <Functions/FunctionStringToString.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <cstring>
#include <cmath>
#include <unordered_map>
#include <memory> #include <memory>
#include <utility> #include <unordered_map>
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
}
/* Determine language and charset of text data. For each text, we build the distribution of bigrams bytes. /* Determine language and charset of text data. For each text, we build the distribution of bigrams bytes.
* Then we use marked-up dictionaries with distributions of bigram bytes of various languages and charsets. * Then we use marked-up dictionaries with distributions of bigram bytes of various languages and charsets.
* Using a naive Bayesian classifier, find the most likely charset and language and return it * Using a naive Bayesian classifier, find the most likely charset and language and return it
@ -21,10 +21,6 @@ namespace DB
template <size_t N, bool detect_language> template <size_t N, bool detect_language>
struct CharsetClassificationImpl struct CharsetClassificationImpl
{ {
using ResultType = String;
using CodePoint = UInt8;
/* We need to solve zero-frequency problem for Naive Bayes Classifier /* We need to solve zero-frequency problem for Naive Bayes Classifier
* If the bigram is not found in the text, we assume that the probability of its meeting is 1e-06. * If the bigram is not found in the text, we assume that the probability of its meeting is 1e-06.
* 1e-06 is minimal value in our marked-up dictionary. * 1e-06 is minimal value in our marked-up dictionary.
@ -34,28 +30,22 @@ struct CharsetClassificationImpl
/// If the data size is bigger than this, behaviour is unspecified for this function. /// If the data size is bigger than this, behaviour is unspecified for this function.
static constexpr size_t max_string_size = 1u << 15; static constexpr size_t max_string_size = 1u << 15;
/// Default padding to read safely. static ALWAYS_INLINE inline Float64 naiveBayes(
static constexpr size_t default_padding = 16; const FrequencyHolder::EncodingMap & standard,
const HashMap<UInt16, Float64> & model,
/// Max codepoints to store at once. 16 is for batching usage and PODArray has this padding.
static constexpr size_t simultaneously_codepoints_num = default_padding + N - 1;
static ALWAYS_INLINE inline Float64 naiveBayes(const FrequencyHolder::EncodingMap & standard,
std::unordered_map<UInt16, Float64> & model,
Float64 max_result) Float64 max_result)
{ {
Float64 res = 0; Float64 res = 0;
for (auto & el : model) for (const auto & el : model)
{ {
/// Try to find bigram in the dictionary. /// Try to find bigram in the dictionary.
auto it = standard.find(el.first); const auto * it = standard.find(el.getKey());
if (it != standard.end()) if (it != standard.end())
{ {
res += el.second * log(it->getMapped()); res += el.getMapped() * log(it->getMapped());
} else } else
{ {
res += el.second * log(zero_frequency); res += el.getMapped() * log(zero_frequency);
} }
/// If at some step the result has become less than the current maximum, then it makes no sense to count it fully. /// If at some step the result has become less than the current maximum, then it makes no sense to count it fully.
if (res < max_result) if (res < max_result)
@ -66,95 +56,21 @@ struct CharsetClassificationImpl
return res; return res;
} }
static ALWAYS_INLINE size_t readCodePoints(CodePoint * code_points, const char *& pos, const char * end)
{
constexpr size_t padding_offset = default_padding - N + 1;
memcpy(code_points, code_points + padding_offset, roundUpToPowerOfTwoOrZero(N - 1) * sizeof(CodePoint));
memcpy(code_points + (N - 1), pos, default_padding * sizeof(CodePoint));
pos += padding_offset;
if (pos > end)
return default_padding - (pos - end);
return default_padding;
}
/// Сount how many times each bigram occurs in the text. /// Сount how many times each bigram occurs in the text.
static ALWAYS_INLINE inline size_t calculateStats( static ALWAYS_INLINE inline void calculateStats(
const char * data, const UInt8 * data,
const size_t size, const size_t size,
size_t (*read_code_points)(CodePoint *, const char *&, const char *), HashMap<UInt16, Float64> & model)
std::unordered_map<UInt16, Float64>& model)
{ {
UInt16 hash = 0;
const char * start = data; for (size_t i = 0; i < size; ++i)
const char * end = data + size;
CodePoint cp[simultaneously_codepoints_num] = {};
/// read_code_points returns the position of cp where it stopped reading codepoints.
size_t found = read_code_points(cp, start, end);
/// We need to start for the first time here, because first N - 1 codepoints mean nothing.
size_t i = N - 1;
size_t len = 0;
do
{ {
for (; i + N <= found; ++i) hash <<= 8;
{ hash += *(data + i);
UInt32 hash = 0; ++model[hash];
for (size_t j = 0; j < N; ++j)
{
hash <<= 8;
hash += *(cp + i + j);
}
if (model[hash] == 0)
{
model[hash] = 1;
++len;
}
++model[hash];
}
i = 0;
} while (start < end && (found = read_code_points(cp, start, end)));
return len;
}
static void constant(String data, String & res)
{
const auto & encodings_freq = FrequencyHolder::getInstance().getEncodingsFrequency();
std::unordered_map<UInt16, Float64> model;
calculateStats(data.data(), data.size(), readCodePoints, model);
Float64 max_result = log(zero_frequency) * (max_string_size);
String poss_ans;
/// Go through the dictionary and find the charset with the highest weight
for (auto& item : encodings_freq)
{
Float64 score = naiveBayes(item.map, model, max_result);
if (max_result < score)
{
poss_ans = item.name;
max_result = score;
}
}
/* In our dictionary we have lines with form: <Language>_<Charset>
* If we need to find language of data, we return <Language>
* If we need to find charset of data, we return <Charset>.
*/
size_t sep = poss_ans.find('_');
if (detect_language)
{
res = poss_ans.erase(0, sep + 1);
}
else
{
res = poss_ans.erase(sep, poss_ans.size() - sep);
} }
} }
static void vector( static void vector(
const ColumnString::Chars & data, const ColumnString::Chars & data,
const ColumnString::Offsets & offsets, const ColumnString::Offsets & offsets,
@ -163,64 +79,53 @@ struct CharsetClassificationImpl
{ {
const auto & encodings_freq = FrequencyHolder::getInstance().getEncodingsFrequency(); const auto & encodings_freq = FrequencyHolder::getInstance().getEncodingsFrequency();
res_data.reserve(1024); if (detect_language)
/// 2 chars for ISO code + 1 zero byte
res_data.reserve(offsets.size() * 3);
else
/// Mean charset length is 8
res_data.reserve(offsets.size() * 8);
res_offsets.resize(offsets.size()); res_offsets.resize(offsets.size());
size_t prev_offset = 0;
size_t res_offset = 0; size_t res_offset = 0;
for (size_t i = 0; i < offsets.size(); ++i) for (size_t i = 0; i < offsets.size(); ++i)
{ {
const char * haystack = reinterpret_cast<const char *>(&data[prev_offset]); const UInt8 * str = data.data() + offsets[i - 1];
String str = haystack; const size_t str_len = offsets[i] - offsets[i - 1] - 1;
String poss_ans; std::string_view res;
std::unordered_map<UInt16, Float64> model; HashMap<UInt16, Float64> model;
calculateStats(str.data(), str.size(), readCodePoints, model); calculateStats(str, str_len, model);
Float64 max_result = log(zero_frequency) * (max_string_size); /// Go through the dictionary and find the charset with the highest weight
for (auto& item : encodings_freq) Float64 max_result = log(zero_frequency) * (max_string_size);
for (const auto & item : encodings_freq)
{ {
Float64 score = naiveBayes(item.map, model, max_result); Float64 score = naiveBayes(item.map, model, max_result);
if (max_result < score) if (max_result < score)
{ {
max_result = score; max_result = score;
poss_ans = item.name; res = detect_language ? item.lang : item.name;
} }
} }
size_t sep = poss_ans.find('_'); res_data.resize(res_offset + res.size() + 1);
String ans_str; memcpy(&res_data[res_offset], res.data(), res.size());
if (detect_language) res_data[res_offset + res.size()] = 0;
{ res_offset += res.size() + 1;
ans_str = poss_ans.erase(0, sep + 1);
}
else
{
ans_str = poss_ans.erase(sep, poss_ans.size() - sep);
}
ans_str = poss_ans;
const auto res = ans_str.c_str();
size_t cur_offset = offsets[i];
size_t ans_size = strlen(res);
res_data.resize(res_offset + ans_size + 1);
memcpy(&res_data[res_offset], res, ans_size);
res_offset += ans_size;
res_data[res_offset] = 0;
++res_offset;
res_offsets[i] = res_offset; res_offsets[i] = res_offset;
prev_offset = cur_offset;
} }
} }
[[noreturn]] static void vectorFixed(const ColumnString::Chars &, size_t, ColumnString::Chars &)
{
throw Exception("Cannot apply function detectProgrammingLanguage to fixed string.", ErrorCodes::ILLEGAL_COLUMN);
}
}; };
@ -235,8 +140,8 @@ struct NameLanguageDetect
}; };
using FunctionCharsetDetect = FunctionsTextClassification<CharsetClassificationImpl<2, true>, NameCharsetDetect>; using FunctionCharsetDetect = FunctionStringToString<CharsetClassificationImpl<2, false>, NameCharsetDetect, false>;
using FunctionLanguageDetect = FunctionsTextClassification<CharsetClassificationImpl<2, false>, NameLanguageDetect>; using FunctionLanguageDetect = FunctionStringToString<CharsetClassificationImpl<2, true>, NameLanguageDetect, false>;
void registerFunctionsCharsetClassification(FunctionFactory & factory) void registerFunctionsCharsetClassification(FunctionFactory & factory)
{ {

View File

@ -4,7 +4,7 @@
#if USE_NLP #if USE_NLP
#include <Functions/FunctionsTextClassification.h> #include <Functions/FunctionStringToString.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <DataTypes/DataTypeMap.h> #include <DataTypes/DataTypeMap.h>
@ -34,9 +34,7 @@ extern const int ILLEGAL_COLUMN;
struct LanguageClassificationImpl struct LanguageClassificationImpl
{ {
using ResultType = String; static std::string_view codeISO(std::string_view code_string)
static String codeISO(std::string_view code_string)
{ {
if (code_string.ends_with("-Latn")) if (code_string.ends_with("-Latn"))
code_string.remove_suffix(code_string.size() - 5); code_string.remove_suffix(code_string.size() - 5);
@ -61,54 +59,44 @@ struct LanguageClassificationImpl
if (code_string.size() != 2) if (code_string.size() != 2)
return "other"; return "other";
return String(code_string); return code_string;
} }
static void constant(const String & data, String & res)
{
bool is_reliable = true;
const char * str = data.c_str();
auto lang = CLD2::DetectLanguage(str, strlen(str), true, &is_reliable);
res = codeISO(LanguageCode(lang));
}
static void vector( static void vector(
const ColumnString::Chars & data, const ColumnString::Chars & data,
const ColumnString::Offsets & offsets, const ColumnString::Offsets & offsets,
ColumnString::Chars & res_data, ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets) ColumnString::Offsets & res_offsets)
{ {
res_data.reserve(1024); /// Constant 3 is based on the fact that in general we need 2 characters for ISO code + 1 zero byte
res_data.reserve(offsets.size() * 3);
res_offsets.resize(offsets.size()); res_offsets.resize(offsets.size());
size_t prev_offset = 0; bool is_reliable = true;
size_t res_offset = 0; size_t res_offset = 0;
for (size_t i = 0; i < offsets.size(); ++i) for (size_t i = 0; i < offsets.size(); ++i)
{ {
const char * str = reinterpret_cast<const char *>(&data[prev_offset]); const char * str = reinterpret_cast<const char *>(data.data() + offsets[i - 1]);
String res; const size_t str_len = offsets[i] - offsets[i - 1] - 1;
bool is_reliable = true;
auto lang = CLD2::DetectLanguage(str, strlen(str), true, &is_reliable); auto lang = CLD2::DetectLanguage(str, str_len, true, &is_reliable);
res = codeISO(LanguageCode(lang)); auto res = codeISO(LanguageCode(lang));
size_t cur_offset = offsets[i];
res_data.resize(res_offset + res.size() + 1); res_data.resize(res_offset + res.size() + 1);
memcpy(&res_data[res_offset], res.data(), res.size()); memcpy(&res_data[res_offset], res.data(), res.size());
res_offset += res.size();
res_data[res_offset] = 0; res_data[res_offset + res.size()] = 0;
++res_offset; res_offset += res.size() + 1;
res_offsets[i] = res_offset; res_offsets[i] = res_offset;
prev_offset = cur_offset;
} }
} }
[[noreturn]] static void vectorFixed(const ColumnString::Chars &, size_t, ColumnString::Chars &)
{
throw Exception("Cannot apply function detectProgrammingLanguage to fixed string.", ErrorCodes::ILLEGAL_COLUMN);
}
}; };
class LanguageClassificationMixedDetect : public IFunction class LanguageClassificationMixedDetect : public IFunction
@ -116,6 +104,9 @@ class LanguageClassificationMixedDetect : public IFunction
public: public:
static constexpr auto name = "detectLanguageMixed"; static constexpr auto name = "detectLanguageMixed";
/// Number of top results
static constexpr auto top_N = 3;
static FunctionPtr create(ContextPtr) { return std::make_shared<LanguageClassificationMixedDetect>(); } static FunctionPtr create(ContextPtr) { return std::make_shared<LanguageClassificationMixedDetect>(); }
String getName() const override { return name; } String getName() const override { return name; }
@ -132,7 +123,7 @@ public:
throw Exception( throw Exception(
"Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeMap>(std::make_shared<DataTypeString>(), std::make_shared<DataTypeInt32>()); return std::make_shared<DataTypeMap>(std::make_shared<DataTypeString>(), std::make_shared<DataTypeFloat32>());
} }
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
@ -145,8 +136,8 @@ public:
"Illegal columns " + arguments[0].column->getName() + " of arguments of function " + getName(), "Illegal columns " + arguments[0].column->getName() + " of arguments of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN); ErrorCodes::ILLEGAL_COLUMN);
auto & input_data = col->getChars(); const auto & input_data = col->getChars();
auto & input_offsets = col->getOffsets(); const auto & input_offsets = col->getOffsets();
/// Create and fill the result map. /// Create and fill the result map.
@ -158,15 +149,15 @@ public:
MutableColumnPtr values_data = value_type->createColumn(); MutableColumnPtr values_data = value_type->createColumn();
MutableColumnPtr offsets = DataTypeNumber<IColumn::Offset>().createColumn(); MutableColumnPtr offsets = DataTypeNumber<IColumn::Offset>().createColumn();
size_t total_elements = input_rows_count * 3; size_t total_elements = input_rows_count * top_N;
keys_data->reserve(total_elements); keys_data->reserve(total_elements);
values_data->reserve(total_elements); values_data->reserve(total_elements);
offsets->reserve(input_rows_count); offsets->reserve(input_rows_count);
bool is_reliable = true; bool is_reliable = true;
CLD2::Language result_lang_top3[3]; CLD2::Language result_lang_top3[top_N];
int32_t pc[3]; int32_t pc[top_N];
int bytes[3]; int bytes[top_N];
IColumn::Offset current_offset = 0; IColumn::Offset current_offset = 0;
for (size_t i = 0; i < input_rows_count; ++i) for (size_t i = 0; i < input_rows_count; ++i)
@ -176,16 +167,16 @@ public:
CLD2::DetectLanguageSummary(str, str_len, true, result_lang_top3, pc, bytes, &is_reliable); CLD2::DetectLanguageSummary(str, str_len, true, result_lang_top3, pc, bytes, &is_reliable);
for (size_t j = 0; j < 3; ++j) for (size_t j = 0; j < top_N; ++j)
{ {
auto res_str = LanguageClassificationImpl::codeISO(LanguageCode(result_lang_top3[j])); auto res_str = LanguageClassificationImpl::codeISO(LanguageCode(result_lang_top3[j]));
int32_t res_int = static_cast<int>(pc[j]); Float32 res_float = static_cast<Float32>(pc[j]) / 100;
keys_data->insertData(res_str.data(), res_str.size()); keys_data->insertData(res_str.data(), res_str.size());
values_data->insertData(reinterpret_cast<const char *>(&res_int), sizeof(res_int)); values_data->insertData(reinterpret_cast<const char *>(&res_float), sizeof(res_float));
} }
current_offset += 3; current_offset += top_N;
offsets->insert(current_offset); offsets->insert(current_offset);
} }
@ -203,7 +194,7 @@ struct NameLanguageUTF8Detect
}; };
using FunctionLanguageUTF8Detect = FunctionsTextClassification<LanguageClassificationImpl, NameLanguageUTF8Detect>; using FunctionLanguageUTF8Detect = FunctionStringToString<LanguageClassificationImpl, NameLanguageUTF8Detect, false>;
void registerFunctionLanguageDetectUTF8(FunctionFactory & factory) void registerFunctionLanguageDetectUTF8(FunctionFactory & factory)
{ {

View File

@ -1,12 +1,18 @@
#include <Functions/FunctionsTextClassification.h>
#include <Common/FrequencyHolder.h> #include <Common/FrequencyHolder.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <IO/ReadHelpers.h> #include <Functions/FunctionStringToString.h>
#include <unordered_map> #include <unordered_map>
#include <string_view>
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
}
/** /**
* Determine the programming language from the source code. * Determine the programming language from the source code.
* We calculate all the unigrams and bigrams of commands in the source code. * We calculate all the unigrams and bigrams of commands in the source code.
@ -15,16 +21,16 @@ namespace DB
*/ */
struct ProgrammingClassificationImpl struct ProgrammingClassificationImpl
{ {
using ResultType = String;
/// Calculate total weight /// Calculate total weight
static ALWAYS_INLINE inline Float64 stateMachine(const FrequencyHolder::Map & standard, std::unordered_map<String, Float64> & model) static ALWAYS_INLINE inline Float64 stateMachine(
const FrequencyHolder::Map & standard,
const std::unordered_map<String, Float64> & model)
{ {
Float64 res = 0; Float64 res = 0;
for (auto & el : model) for (const auto & el : model)
{ {
/// Try to find each n-gram in dictionary /// Try to find each n-gram in dictionary
auto it = standard.find(el.first); const auto * it = standard.find(el.first);
if (it != standard.end()) if (it != standard.end())
{ {
res += el.second * it->getMapped(); res += el.second * it->getMapped();
@ -33,104 +39,44 @@ struct ProgrammingClassificationImpl
return res; return res;
} }
static void constant(String data, String & res)
{
auto & programming_freq = FrequencyHolder::getInstance().getProgrammingFrequency();
std::unordered_map<String, Float64> data_freq;
String prev_command;
String command;
/// Select all commands from the string
for (size_t i = 0; i < data.size();)
{
/// Assume that all commands are split by spaces
if (!isspace(data[i]))
{
command.push_back(data[i]);
++i;
while ((i < data.size()) && (!isspace(data[i])))
{
command.push_back(data[i]);
++i;
}
if (prev_command == "")
{
prev_command = command;
}
else
{
data_freq[prev_command + command] += 1;
data_freq[prev_command] += 1;
prev_command = command;
}
command = "";
}
else
{
++i;
}
}
String most_liked;
Float64 max_result = 0;
/// Iterate over all programming languages and find the language with the highest weight
for (auto& item : programming_freq)
{
Float64 result = stateMachine(item.map, data_freq);
if (result > max_result)
{
max_result = result;
most_liked = item.name;
}
}
/// If all weights are zero, then we assume that the language is undefined
if (most_liked == "")
{
most_liked = "Undefined";
}
res = most_liked;
}
static void vector( static void vector(
const ColumnString::Chars & data, const ColumnString::Chars & data,
const ColumnString::Offsets & offsets, const ColumnString::Offsets & offsets,
ColumnString::Chars & res_data, ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets) ColumnString::Offsets & res_offsets)
{ {
auto & programming_freq = FrequencyHolder::getInstance().getProgrammingFrequency(); const auto & programming_freq = FrequencyHolder::getInstance().getProgrammingFrequency();
res_data.reserve(1024); /// Constant 5 is arbitrary
res_data.reserve(offsets.size() * 5);
res_offsets.resize(offsets.size()); res_offsets.resize(offsets.size());
size_t prev_offset = 0;
size_t res_offset = 0; size_t res_offset = 0;
for (size_t i = 0; i < offsets.size(); ++i) for (size_t i = 0; i < offsets.size(); ++i)
{ {
const char * haystack = reinterpret_cast<const char *>(&data[prev_offset]); const UInt8 * str = data.data() + offsets[i - 1];
const size_t str_len = offsets[i] - offsets[i - 1] - 1;
std::unordered_map<String, Float64> data_freq; std::unordered_map<String, Float64> data_freq;
String str_data = haystack;
String prev_command; String prev_command;
String command; String command;
/// Select all commands from the string /// Select all commands from the string
for (size_t ind = 0; ind < str_data.size();) for (size_t ind = 0; ind < str_len;)
{ {
/// Assume that all commands are split by spaces /// Assume that all commands are split by spaces
if (!isspace(str_data[ind])) if (!isspace(str[ind]))
{ {
command.push_back(str_data[ind]); command.push_back(str[ind]);
++ind; ++ind;
while ((ind < str_data.size()) && (!isspace(str_data[ind]))) while ((ind < str_len) && (!isspace(str[ind])))
{ {
command.push_back(str_data[ind]); command.push_back(str[ind]);
++ind; ++ind;
} }
if (prev_command == "") if (prev_command.empty())
{ {
prev_command = command; prev_command = command;
} }
@ -148,39 +94,36 @@ struct ProgrammingClassificationImpl
} }
} }
String most_liked; String res;
Float64 max_result = 0; Float64 max_result = 0;
/// Iterate over all programming languages and find the language with the highest weight /// Iterate over all programming languages and find the language with the highest weight
for (auto& item : programming_freq) for (const auto & item : programming_freq)
{ {
Float64 result = stateMachine(item.map, data_freq); Float64 result = stateMachine(item.map, data_freq);
if (result > max_result) if (result > max_result)
{ {
max_result = result; max_result = result;
most_liked = item.name; res = item.name;
} }
} }
/// If all weights are zero, then we assume that the language is undefined /// If all weights are zero, then we assume that the language is undefined
if (most_liked == "") if (res.empty())
{ res = "Undefined";
most_liked = "Undefined";
}
const auto res = most_liked.c_str(); res_data.resize(res_offset + res.size() + 1);
size_t cur_offset = offsets[i]; memcpy(&res_data[res_offset], res.data(), res.size());
size_t ans_size = strlen(res);
res_data.resize(res_offset + ans_size + 1);
memcpy(&res_data[res_offset], res, ans_size);
res_offset += ans_size;
res_data[res_offset] = 0; res_data[res_offset + res.size()] = 0;
++res_offset; res_offset += res.size() + 1;
res_offsets[i] = res_offset; res_offsets[i] = res_offset;
prev_offset = cur_offset;
} }
} }
[[noreturn]] static void vectorFixed(const ColumnString::Chars &, size_t, ColumnString::Chars &)
{
throw Exception("Cannot apply function detectProgrammingLanguage to fixed string.", ErrorCodes::ILLEGAL_COLUMN);
}
}; };
struct NameGetProgramming struct NameGetProgramming
@ -189,7 +132,7 @@ struct NameGetProgramming
}; };
using FunctionGetProgramming = FunctionsTextClassification<ProgrammingClassificationImpl, NameGetProgramming>; using FunctionGetProgramming = FunctionStringToString<ProgrammingClassificationImpl, NameGetProgramming, false>;
void registerFunctionsProgrammingClassification(FunctionFactory & factory) void registerFunctionsProgrammingClassification(FunctionFactory & factory)
{ {

View File

@ -1,86 +0,0 @@
#pragma once
#include <Columns/ColumnConst.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Interpreters/Context_fwd.h>
namespace DB
{
/** Functions for text classification:
*
* detectCharset(string data) - detect charset of data.
* Returns string name of most likely charset.
*
* detectLanguage(string data) - detect language of data in various encodings (not UTF-8)
*
* getTonality(string data) - defines the emotional coloring of the text.
* Returns NEG if text is negative, POS if text is positive or NEUT if text is neutral.
*
* getProgrammingLanguage(string data) - detect programming language
*/
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
}
template <typename Impl, typename Name>
class FunctionsTextClassification : public IFunction
{
public:
static constexpr auto name = Name::name;
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionsTextClassification>(); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 1; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isString(arguments[0]))
throw Exception(
"Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return arguments[0];
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /*input_rows_count*/) const override
{
using ResultType = typename Impl::ResultType;
const ColumnPtr & column = arguments[0].column;
const ColumnConst * col_const = typeid_cast<const ColumnConst *>(&*column);
if (col_const)
{
ResultType res;
Impl::constant(col_const->getValue<String>(), res);
return result_type->createColumnConst(col_const->size(), toField(res));
}
if (const ColumnString * col = checkAndGetColumn<ColumnString>(column.get()))
{
auto col_res = ColumnString::create();
ColumnString::Chars & vec_res = col_res->getChars();
ColumnString::Offsets & offsets_res = col_res->getOffsets();
Impl::vector(col->getChars(), col->getOffsets(), vec_res, offsets_res);
return col_res;
}
else
{
throw Exception(
"Illegal columns " + arguments[0].column->getName() + " of arguments of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
}
}
};
}

View File

@ -1,163 +1,115 @@
#include <Functions/FunctionsTextClassification.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/FrequencyHolder.h> #include <Common/FrequencyHolder.h>
#include <Common/StringUtils/StringUtils.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Common/UTF8Helpers.h> #include <Functions/FunctionStringOrArrayToT.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <unordered_map> #include <unordered_map>
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/** /**
* Determines the sentiment of text data. * Determines the sentiment of text data.
* Uses a marked-up sentiment dictionary, each word has a tonality ranging from -3 to 3. * Uses a marked-up sentiment dictionary, each word has a tonality ranging from -12 to 6.
* For each text, calculate the average sentiment value of its words and return NEG, POS or NEUT * For each text, calculate the average sentiment value of its words and return NEG, POS or NEUT
*/ */
struct TonalityClassificationImpl struct TonalityClassificationImpl
{ {
static Float32 detectTonality(const UInt8 * str, const size_t str_len, const FrequencyHolder::Map & emotional_dict)
using ResultType = String;
static String get_tonality(const Float64 & tonality_level)
{ {
if (tonality_level < 0.15) { return "NEG"; }
if (tonality_level > 0.45) { return "POS"; }
return "NEUT";
}
static void constant(String data, String & res)
{
const auto & emotional_dict = FrequencyHolder::getInstance().getEmotionalDict();
Float64 weight = 0; Float64 weight = 0;
Float64 count_words = 0; UInt64 count_words = 0;
String answer;
String word; String word;
/// Select all Russian words from the string /// Select all Russian words from the string
for (size_t i = 0; i < data.size();) for (size_t ind = 0; ind < str_len;)
{ {
/// Assume that all non-Ascii characters are Russian letters /// Assume that all non-ASCII characters are Russian letters
if (!isASCII(data[i])) if (!isASCII(str[ind]))
{ {
word.push_back(data[i]); word.push_back(str[ind]);
++i; ++ind;
while ((i < data.size()) && (!isASCII(data[i]))) while ((ind < str_len) && (!isASCII(str[ind])))
{ {
word.push_back(data[i]); word.push_back(str[ind]);
++i; ++ind;
} }
/// Try to find a russian word in the tonality dictionary /// Try to find a russian word in the tonality dictionary
auto it = emotional_dict.find(word); const auto * it = emotional_dict.find(word);
if (it != emotional_dict.end()) if (it != emotional_dict.end())
{ {
count_words += 1; count_words += 1;
weight += it->getMapped(); weight += it->getMapped();
} }
word = ""; word.clear();
} }
else else
{ {
++i; ++ind;
} }
} }
/// Calculate average value of tonality /// Calculate average value of tonality.
Float64 total_tonality = weight / count_words; /// Convert values -12..6 to -1..1
res += get_tonality(total_tonality); return std::max(weight / count_words / 6, -1.0);
} }
/// If the function will return constant value for FixedString data type.
static constexpr auto is_fixed_to_constant = false;
static void vector( static void vector(
const ColumnString::Chars & data, const ColumnString::Chars & data,
const ColumnString::Offsets & offsets, const ColumnString::Offsets & offsets,
ColumnString::Chars & res_data, PaddedPODArray<Float32> & res)
ColumnString::Offsets & res_offsets)
{ {
const auto & emotional_dict = FrequencyHolder::getInstance().getEmotionalDict(); const auto & emotional_dict = FrequencyHolder::getInstance().getEmotionalDict();
res_data.reserve(1024); size_t size = offsets.size();
res_offsets.resize(offsets.size());
size_t prev_offset = 0; size_t prev_offset = 0;
size_t res_offset = 0; for (size_t i = 0; i < size; ++i)
for (size_t i = 0; i < offsets.size(); ++i)
{ {
const char * haystack = reinterpret_cast<const char *>(&data[prev_offset]); res[i] = detectTonality(data.data() + prev_offset, offsets[i] - 1 - prev_offset, emotional_dict);
String str = haystack; prev_offset = offsets[i];
String buf;
Float64 weight = 0;
Float64 count_words = 0;
String answer;
String word;
/// Select all Russian words from the string
for (size_t ind = 0; ind < str.size();)
{
if (!isASCII(str[ind]))
{
word.push_back(str[ind]);
++ind;
while ((ind < str.size()) && (!isASCII(str[ind])))
{
word.push_back(str[ind]);
++ind;
}
/// Try to find a russian word in the tonality dictionary
auto it = emotional_dict.find(word);
if (it != emotional_dict.end())
{
count_words += 1;
weight += it->getMapped();
}
word = "";
}
else
{
++ind;
}
}
/// Calculate average value of tonality
Float64 total_tonality = weight / count_words;
buf = get_tonality(total_tonality);
const auto res = buf.c_str();
size_t cur_offset = offsets[i];
size_t ans_size = strlen(res);
res_data.resize(res_offset + ans_size + 1);
memcpy(&res_data[res_offset], res, ans_size);
res_offset += ans_size;
res_data[res_offset] = 0;
++res_offset;
res_offsets[i] = res_offset;
prev_offset = cur_offset;
} }
} }
static void vectorFixedToConstant(const ColumnString::Chars & /*data*/, size_t /*n*/, Float32 & /*res*/) {}
static void vectorFixedToVector(const ColumnString::Chars & data, size_t n, PaddedPODArray<Float32> & res)
{
const auto & emotional_dict = FrequencyHolder::getInstance().getEmotionalDict();
size_t size = data.size() / n;
for (size_t i = 0; i < size; ++i)
res[i] = detectTonality(data.data() + i * n, n, emotional_dict);
}
[[noreturn]] static void array(const ColumnString::Offsets &, PaddedPODArray<Float32> &)
{
throw Exception("Cannot apply function detectTonality to Array argument", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
[[noreturn]] static void uuid(const ColumnUUID::Container &, size_t &, PaddedPODArray<Float32> &)
{
throw Exception("Cannot apply function detectTonality to UUID argument", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
}; };
struct NameGetTonality struct NameDetectTonality
{ {
static constexpr auto name = "detectTonality"; static constexpr auto name = "detectTonality";
}; };
using FunctionDetectTonality = FunctionStringOrArrayToT<TonalityClassificationImpl, NameDetectTonality, Float32>;
using FunctionGetTonality = FunctionsTextClassification<TonalityClassificationImpl, NameGetTonality>;
void registerFunctionsTonalityClassification(FunctionFactory & factory) void registerFunctionsTonalityClassification(FunctionFactory & factory)
{ {
factory.registerFunction<FunctionGetTonality>(); factory.registerFunction<FunctionDetectTonality>();
} }
} }

View File

@ -2,9 +2,9 @@ if (ENABLE_TESTS)
add_subdirectory(tests) add_subdirectory(tests)
endif() endif()
# if (ENABLE_EXAMPLES) if (ENABLE_EXAMPLES)
add_subdirectory(examples) add_subdirectory(examples)
# endif() endif()
if (ENABLE_FUZZING) if (ENABLE_FUZZING)
add_subdirectory(fuzzers) add_subdirectory(fuzzers)

View File

@ -3,10 +3,10 @@ en
fr fr
ja ja
zh zh
{'ja':62,'fr':36,'un':0} {'ja':0.62,'fr':0.36,'un':0}
ISO-8859-1 ISO-8859-1
English en
POS 0.465
NEG -0.57647777
POS 0.050505556
C++ C++

View File

@ -11,8 +11,8 @@ SELECT detectLanguageMixed('二兎を追う者は一兎をも得ず二兎を追
SELECT detectCharset('Plain English'); SELECT detectCharset('Plain English');
SELECT detectLanguageUnknown('Plain English'); SELECT detectLanguageUnknown('Plain English');
SELECT detectTonality('Милая кошка'); SELECT detectTonality('милая кошка');
SELECT detectTonality('Злой человек'); SELECT detectTonality('ненависть к людям');
SELECT detectTonality('Обычная прогулка по ближайшему парку'); SELECT detectTonality('обычная прогулка по ближайшему парку');
SELECT detectProgrammingLanguage('#include <iostream>'); SELECT detectProgrammingLanguage('#include <iostream>');