This commit is contained in:
Nikolay Degterinsky 2022-01-10 15:36:32 +00:00
parent fce10091a9
commit 85b8985df2
12 changed files with 241 additions and 507 deletions

View File

@ -36,4 +36,4 @@ endif ()
set (USE_NLP 1)
message (STATUS "Using Libraries for NLP functions: contrib/wordnet-blast, contrib/libstemmer_c, contrib/lemmagen-c")
message (STATUS "Using Libraries for NLP functions: contrib/wordnet-blast, contrib/libstemmer_c, contrib/lemmagen-c, contrib/cld2")

2
contrib/nlp-data vendored

@ -1 +1 @@
Subproject commit 3bc8aef8440b66823186f47a74996cbdad66c04f
Subproject commit 5591f91f5e748cba8fb9ef81564176feae774853

View File

@ -0,0 +1,14 @@
include(${ClickHouse_SOURCE_DIR}/cmake/embed_binary.cmake)
set(LIBRARY_DIR "${ClickHouse_SOURCE_DIR}/contrib/nlp-data")
add_library (nlp_data INTERFACE)
clickhouse_embed_binaries(
TARGET nlp_dicts
RESOURCE_DIR "${LIBRARY_DIR}"
RESOURCES charset.zst tonality_ru.zst programming.zst
)
add_dependencies(nlp_data nlp_dicts)
target_link_libraries(nlp_data INTERFACE "-Wl,${WHOLE_ARCHIVE} $<TARGET_FILE:nlp_dicts> -Wl,${NO_WHOLE_ARCHIVE}")

View File

@ -1,22 +1,20 @@
#pragma once
#include <Common/Arena.h>
#include <Common/getResource.h>
#include <Common/HashTable/HashMap.h>
#include <Common/StringUtils/StringUtils.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <IO/readFloatText.h>
#include <IO/Operators.h>
#include <IO/ZstdInflatingReadBuffer.h>
#include <Common/Arena.h>
#include <base/StringRef.h>
#include <Common/HashTable/HashMap.h>
#include <base/logger_useful.h>
#include <string_view>
#include <string>
#include <cstring>
#include <unordered_map>
#include <base/logger_useful.h>
#include <Common/getResource.h>
namespace DB
{
@ -26,6 +24,14 @@ namespace ErrorCodes
extern const int FILE_DOESNT_EXIST;
}
/// FrequencyHolder class is responsible for storing and loading dictionaries
/// needed for text classification functions:
///
/// 1. detectLanguageUnknown
/// 2. detectCharset
/// 3. detectTonality
/// 4. detectProgrammingLanguage
class FrequencyHolder
{
@ -39,6 +45,7 @@ public:
struct Encoding
{
String name;
String lang;
HashMap<UInt16, Float64> map;
};
@ -54,14 +61,13 @@ public:
return instance;
}
void loadEncodingsFrequency()
{
Poco::Logger * log = &Poco::Logger::get("EncodingsFrequency");
LOG_TRACE(log, "Loading embedded charset frequencies");
auto resource = getResource("charset_freq.txt.zst");
auto resource = getResource("charset.zst");
if (resource.empty())
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded charset frequencies");
@ -71,12 +77,12 @@ public:
String charset_name;
auto buf = std::make_unique<ReadBufferFromMemory>(resource.data(), resource.size());
std::unique_ptr<ReadBuffer> in = std::make_unique<ZstdInflatingReadBuffer>(std::move(buf));
ZstdInflatingReadBuffer in(std::move(buf));
while (!in->eof())
while (!in.eof())
{
readString(line, *in);
++in->position();
readString(line, in);
in.ignore();
if (line.empty())
continue;
@ -84,13 +90,21 @@ public:
ReadBufferFromString buf_line(line);
// Start loading a new charset
if (line.starts_with("//"))
if (line.starts_with("// "))
{
// Skip "// "
buf_line.ignore(3);
readString(charset_name, buf_line);
/* In our dictionary we have lines with form: <Language>_<Charset>
* If we need to find language of data, we return <Language>
* If we need to find charset of data, we return <Charset>.
*/
size_t sep = charset_name.find('_');
Encoding enc;
enc.name = charset_name;
enc.lang = charset_name.substr(0, sep);
enc.name = charset_name.substr(sep + 1);
encodings_freq.push_back(std::move(enc));
}
else
@ -109,9 +123,9 @@ public:
void loadEmotionalDict()
{
Poco::Logger * log = &Poco::Logger::get("EmotionalDict");
LOG_TRACE(log, "Loading embedded emotional dictionary (RU)");
LOG_TRACE(log, "Loading embedded emotional dictionary");
auto resource = getResource("emotional_dictionary_rus.txt.zst");
auto resource = getResource("tonality_ru.zst");
if (resource.empty())
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded emotional dictionary");
@ -121,12 +135,12 @@ public:
size_t count = 0;
auto buf = std::make_unique<ReadBufferFromMemory>(resource.data(), resource.size());
std::unique_ptr<ReadBuffer> in = std::make_unique<ZstdInflatingReadBuffer>(std::move(buf));
ZstdInflatingReadBuffer in(std::move(buf));
while (!in->eof())
while (!in.eof())
{
readString(line, *in);
++in->position();
readString(line, in);
in.ignore();
if (line.empty())
continue;
@ -151,7 +165,7 @@ public:
LOG_TRACE(log, "Loading embedded programming languages frequencies loading");
auto resource = getResource("prog_freq.txt.zst");
auto resource = getResource("programming.zst");
if (resource.empty())
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded programming languages frequencies");
@ -161,12 +175,12 @@ public:
String programming_language;
auto buf = std::make_unique<ReadBufferFromMemory>(resource.data(), resource.size());
std::unique_ptr<ReadBuffer> in = std::make_unique<ZstdInflatingReadBuffer>(std::move(buf));
ZstdInflatingReadBuffer in(std::move(buf));
while (!in->eof())
while (!in.eof())
{
readString(line, *in);
++in->position();
readString(line, in);
in.ignore();
if (line.empty())
continue;
@ -174,8 +188,9 @@ public:
ReadBufferFromString buf_line(line);
// Start loading a new language
if (line.starts_with("//"))
if (line.starts_with("// "))
{
// Skip "// "
buf_line.ignore(3);
readString(programming_language, buf_line);

View File

@ -1,18 +1,18 @@
#include <Functions/FunctionsTextClassification.h>
#include <Common/FrequencyHolder.h>
#include <Functions/FunctionFactory.h>
#include <Common/UTF8Helpers.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <Functions/FunctionStringToString.h>
#include <cstring>
#include <cmath>
#include <unordered_map>
#include <memory>
#include <utility>
#include <unordered_map>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
}
/* Determine language and charset of text data. For each text, we build the distribution of bigrams bytes.
* Then we use marked-up dictionaries with distributions of bigram bytes of various languages and charsets.
* Using a naive Bayesian classifier, find the most likely charset and language and return it
@ -21,10 +21,6 @@ namespace DB
template <size_t N, bool detect_language>
struct CharsetClassificationImpl
{
using ResultType = String;
using CodePoint = UInt8;
/* We need to solve zero-frequency problem for Naive Bayes Classifier
* If the bigram is not found in the text, we assume that the probability of its meeting is 1e-06.
* 1e-06 is minimal value in our marked-up dictionary.
@ -34,28 +30,22 @@ struct CharsetClassificationImpl
/// If the data size is bigger than this, behaviour is unspecified for this function.
static constexpr size_t max_string_size = 1u << 15;
/// Default padding to read safely.
static constexpr size_t default_padding = 16;
/// Max codepoints to store at once. 16 is for batching usage and PODArray has this padding.
static constexpr size_t simultaneously_codepoints_num = default_padding + N - 1;
static ALWAYS_INLINE inline Float64 naiveBayes(const FrequencyHolder::EncodingMap & standard,
std::unordered_map<UInt16, Float64> & model,
static ALWAYS_INLINE inline Float64 naiveBayes(
const FrequencyHolder::EncodingMap & standard,
const HashMap<UInt16, Float64> & model,
Float64 max_result)
{
Float64 res = 0;
for (auto & el : model)
for (const auto & el : model)
{
/// Try to find bigram in the dictionary.
auto it = standard.find(el.first);
const auto * it = standard.find(el.getKey());
if (it != standard.end())
{
res += el.second * log(it->getMapped());
res += el.getMapped() * log(it->getMapped());
} else
{
res += el.second * log(zero_frequency);
res += el.getMapped() * log(zero_frequency);
}
/// If at some step the result has become less than the current maximum, then it makes no sense to count it fully.
if (res < max_result)
@ -66,95 +56,21 @@ struct CharsetClassificationImpl
return res;
}
static ALWAYS_INLINE size_t readCodePoints(CodePoint * code_points, const char *& pos, const char * end)
{
constexpr size_t padding_offset = default_padding - N + 1;
memcpy(code_points, code_points + padding_offset, roundUpToPowerOfTwoOrZero(N - 1) * sizeof(CodePoint));
memcpy(code_points + (N - 1), pos, default_padding * sizeof(CodePoint));
pos += padding_offset;
if (pos > end)
return default_padding - (pos - end);
return default_padding;
}
/// Сount how many times each bigram occurs in the text.
static ALWAYS_INLINE inline size_t calculateStats(
const char * data,
static ALWAYS_INLINE inline void calculateStats(
const UInt8 * data,
const size_t size,
size_t (*read_code_points)(CodePoint *, const char *&, const char *),
std::unordered_map<UInt16, Float64>& model)
HashMap<UInt16, Float64> & model)
{
const char * start = data;
const char * end = data + size;
CodePoint cp[simultaneously_codepoints_num] = {};
/// read_code_points returns the position of cp where it stopped reading codepoints.
size_t found = read_code_points(cp, start, end);
/// We need to start for the first time here, because first N - 1 codepoints mean nothing.
size_t i = N - 1;
size_t len = 0;
do
{
for (; i + N <= found; ++i)
{
UInt32 hash = 0;
for (size_t j = 0; j < N; ++j)
UInt16 hash = 0;
for (size_t i = 0; i < size; ++i)
{
hash <<= 8;
hash += *(cp + i + j);
}
if (model[hash] == 0)
{
model[hash] = 1;
++len;
}
hash += *(data + i);
++model[hash];
}
i = 0;
} while (start < end && (found = read_code_points(cp, start, end)));
return len;
}
static void constant(String data, String & res)
{
const auto & encodings_freq = FrequencyHolder::getInstance().getEncodingsFrequency();
std::unordered_map<UInt16, Float64> model;
calculateStats(data.data(), data.size(), readCodePoints, model);
Float64 max_result = log(zero_frequency) * (max_string_size);
String poss_ans;
/// Go through the dictionary and find the charset with the highest weight
for (auto& item : encodings_freq)
{
Float64 score = naiveBayes(item.map, model, max_result);
if (max_result < score)
{
poss_ans = item.name;
max_result = score;
}
}
/* In our dictionary we have lines with form: <Language>_<Charset>
* If we need to find language of data, we return <Language>
* If we need to find charset of data, we return <Charset>.
*/
size_t sep = poss_ans.find('_');
if (detect_language)
{
res = poss_ans.erase(0, sep + 1);
}
else
{
res = poss_ans.erase(sep, poss_ans.size() - sep);
}
}
static void vector(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
@ -163,64 +79,53 @@ struct CharsetClassificationImpl
{
const auto & encodings_freq = FrequencyHolder::getInstance().getEncodingsFrequency();
res_data.reserve(1024);
if (detect_language)
/// 2 chars for ISO code + 1 zero byte
res_data.reserve(offsets.size() * 3);
else
/// Mean charset length is 8
res_data.reserve(offsets.size() * 8);
res_offsets.resize(offsets.size());
size_t prev_offset = 0;
size_t res_offset = 0;
for (size_t i = 0; i < offsets.size(); ++i)
{
const char * haystack = reinterpret_cast<const char *>(&data[prev_offset]);
String str = haystack;
const UInt8 * str = data.data() + offsets[i - 1];
const size_t str_len = offsets[i] - offsets[i - 1] - 1;
String poss_ans;
std::string_view res;
std::unordered_map<UInt16, Float64> model;
calculateStats(str.data(), str.size(), readCodePoints, model);
HashMap<UInt16, Float64> model;
calculateStats(str, str_len, model);
/// Go through the dictionary and find the charset with the highest weight
Float64 max_result = log(zero_frequency) * (max_string_size);
for (auto& item : encodings_freq)
for (const auto & item : encodings_freq)
{
Float64 score = naiveBayes(item.map, model, max_result);
if (max_result < score)
{
max_result = score;
poss_ans = item.name;
res = detect_language ? item.lang : item.name;
}
}
size_t sep = poss_ans.find('_');
String ans_str;
res_data.resize(res_offset + res.size() + 1);
memcpy(&res_data[res_offset], res.data(), res.size());
if (detect_language)
{
ans_str = poss_ans.erase(0, sep + 1);
}
else
{
ans_str = poss_ans.erase(sep, poss_ans.size() - sep);
}
ans_str = poss_ans;
const auto res = ans_str.c_str();
size_t cur_offset = offsets[i];
size_t ans_size = strlen(res);
res_data.resize(res_offset + ans_size + 1);
memcpy(&res_data[res_offset], res, ans_size);
res_offset += ans_size;
res_data[res_offset] = 0;
++res_offset;
res_data[res_offset + res.size()] = 0;
res_offset += res.size() + 1;
res_offsets[i] = res_offset;
prev_offset = cur_offset;
}
}
[[noreturn]] static void vectorFixed(const ColumnString::Chars &, size_t, ColumnString::Chars &)
{
throw Exception("Cannot apply function detectProgrammingLanguage to fixed string.", ErrorCodes::ILLEGAL_COLUMN);
}
};
@ -235,8 +140,8 @@ struct NameLanguageDetect
};
using FunctionCharsetDetect = FunctionsTextClassification<CharsetClassificationImpl<2, true>, NameCharsetDetect>;
using FunctionLanguageDetect = FunctionsTextClassification<CharsetClassificationImpl<2, false>, NameLanguageDetect>;
using FunctionCharsetDetect = FunctionStringToString<CharsetClassificationImpl<2, false>, NameCharsetDetect, false>;
using FunctionLanguageDetect = FunctionStringToString<CharsetClassificationImpl<2, true>, NameLanguageDetect, false>;
void registerFunctionsCharsetClassification(FunctionFactory & factory)
{

View File

@ -4,7 +4,7 @@
#if USE_NLP
#include <Functions/FunctionsTextClassification.h>
#include <Functions/FunctionStringToString.h>
#include <Functions/FunctionFactory.h>
#include <DataTypes/DataTypeMap.h>
@ -34,9 +34,7 @@ extern const int ILLEGAL_COLUMN;
struct LanguageClassificationImpl
{
using ResultType = String;
static String codeISO(std::string_view code_string)
static std::string_view codeISO(std::string_view code_string)
{
if (code_string.ends_with("-Latn"))
code_string.remove_suffix(code_string.size() - 5);
@ -61,54 +59,44 @@ struct LanguageClassificationImpl
if (code_string.size() != 2)
return "other";
return String(code_string);
return code_string;
}
static void constant(const String & data, String & res)
{
bool is_reliable = true;
const char * str = data.c_str();
auto lang = CLD2::DetectLanguage(str, strlen(str), true, &is_reliable);
res = codeISO(LanguageCode(lang));
}
static void vector(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets)
{
res_data.reserve(1024);
/// Constant 3 is based on the fact that in general we need 2 characters for ISO code + 1 zero byte
res_data.reserve(offsets.size() * 3);
res_offsets.resize(offsets.size());
size_t prev_offset = 0;
bool is_reliable = true;
size_t res_offset = 0;
for (size_t i = 0; i < offsets.size(); ++i)
{
const char * str = reinterpret_cast<const char *>(&data[prev_offset]);
String res;
bool is_reliable = true;
const char * str = reinterpret_cast<const char *>(data.data() + offsets[i - 1]);
const size_t str_len = offsets[i] - offsets[i - 1] - 1;
auto lang = CLD2::DetectLanguage(str, strlen(str), true, &is_reliable);
res = codeISO(LanguageCode(lang));
size_t cur_offset = offsets[i];
auto lang = CLD2::DetectLanguage(str, str_len, true, &is_reliable);
auto res = codeISO(LanguageCode(lang));
res_data.resize(res_offset + res.size() + 1);
memcpy(&res_data[res_offset], res.data(), res.size());
res_offset += res.size();
res_data[res_offset] = 0;
++res_offset;
res_data[res_offset + res.size()] = 0;
res_offset += res.size() + 1;
res_offsets[i] = res_offset;
prev_offset = cur_offset;
}
}
[[noreturn]] static void vectorFixed(const ColumnString::Chars &, size_t, ColumnString::Chars &)
{
throw Exception("Cannot apply function detectProgrammingLanguage to fixed string.", ErrorCodes::ILLEGAL_COLUMN);
}
};
class LanguageClassificationMixedDetect : public IFunction
@ -116,6 +104,9 @@ class LanguageClassificationMixedDetect : public IFunction
public:
static constexpr auto name = "detectLanguageMixed";
/// Number of top results
static constexpr auto top_N = 3;
static FunctionPtr create(ContextPtr) { return std::make_shared<LanguageClassificationMixedDetect>(); }
String getName() const override { return name; }
@ -132,7 +123,7 @@ public:
throw Exception(
"Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeMap>(std::make_shared<DataTypeString>(), std::make_shared<DataTypeInt32>());
return std::make_shared<DataTypeMap>(std::make_shared<DataTypeString>(), std::make_shared<DataTypeFloat32>());
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
@ -145,8 +136,8 @@ public:
"Illegal columns " + arguments[0].column->getName() + " of arguments of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
auto & input_data = col->getChars();
auto & input_offsets = col->getOffsets();
const auto & input_data = col->getChars();
const auto & input_offsets = col->getOffsets();
/// Create and fill the result map.
@ -158,15 +149,15 @@ public:
MutableColumnPtr values_data = value_type->createColumn();
MutableColumnPtr offsets = DataTypeNumber<IColumn::Offset>().createColumn();
size_t total_elements = input_rows_count * 3;
size_t total_elements = input_rows_count * top_N;
keys_data->reserve(total_elements);
values_data->reserve(total_elements);
offsets->reserve(input_rows_count);
bool is_reliable = true;
CLD2::Language result_lang_top3[3];
int32_t pc[3];
int bytes[3];
CLD2::Language result_lang_top3[top_N];
int32_t pc[top_N];
int bytes[top_N];
IColumn::Offset current_offset = 0;
for (size_t i = 0; i < input_rows_count; ++i)
@ -176,16 +167,16 @@ public:
CLD2::DetectLanguageSummary(str, str_len, true, result_lang_top3, pc, bytes, &is_reliable);
for (size_t j = 0; j < 3; ++j)
for (size_t j = 0; j < top_N; ++j)
{
auto res_str = LanguageClassificationImpl::codeISO(LanguageCode(result_lang_top3[j]));
int32_t res_int = static_cast<int>(pc[j]);
Float32 res_float = static_cast<Float32>(pc[j]) / 100;
keys_data->insertData(res_str.data(), res_str.size());
values_data->insertData(reinterpret_cast<const char *>(&res_int), sizeof(res_int));
values_data->insertData(reinterpret_cast<const char *>(&res_float), sizeof(res_float));
}
current_offset += 3;
current_offset += top_N;
offsets->insert(current_offset);
}
@ -203,7 +194,7 @@ struct NameLanguageUTF8Detect
};
using FunctionLanguageUTF8Detect = FunctionsTextClassification<LanguageClassificationImpl, NameLanguageUTF8Detect>;
using FunctionLanguageUTF8Detect = FunctionStringToString<LanguageClassificationImpl, NameLanguageUTF8Detect, false>;
void registerFunctionLanguageDetectUTF8(FunctionFactory & factory)
{

View File

@ -1,12 +1,18 @@
#include <Functions/FunctionsTextClassification.h>
#include <Common/FrequencyHolder.h>
#include <Functions/FunctionFactory.h>
#include <IO/ReadHelpers.h>
#include <Functions/FunctionStringToString.h>
#include <unordered_map>
#include <string_view>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
}
/**
* Determine the programming language from the source code.
* We calculate all the unigrams and bigrams of commands in the source code.
@ -15,16 +21,16 @@ namespace DB
*/
struct ProgrammingClassificationImpl
{
using ResultType = String;
/// Calculate total weight
static ALWAYS_INLINE inline Float64 stateMachine(const FrequencyHolder::Map & standard, std::unordered_map<String, Float64> & model)
static ALWAYS_INLINE inline Float64 stateMachine(
const FrequencyHolder::Map & standard,
const std::unordered_map<String, Float64> & model)
{
Float64 res = 0;
for (auto & el : model)
for (const auto & el : model)
{
/// Try to find each n-gram in dictionary
auto it = standard.find(el.first);
const auto * it = standard.find(el.first);
if (it != standard.end())
{
res += el.second * it->getMapped();
@ -33,104 +39,44 @@ struct ProgrammingClassificationImpl
return res;
}
static void constant(String data, String & res)
{
auto & programming_freq = FrequencyHolder::getInstance().getProgrammingFrequency();
std::unordered_map<String, Float64> data_freq;
String prev_command;
String command;
/// Select all commands from the string
for (size_t i = 0; i < data.size();)
{
/// Assume that all commands are split by spaces
if (!isspace(data[i]))
{
command.push_back(data[i]);
++i;
while ((i < data.size()) && (!isspace(data[i])))
{
command.push_back(data[i]);
++i;
}
if (prev_command == "")
{
prev_command = command;
}
else
{
data_freq[prev_command + command] += 1;
data_freq[prev_command] += 1;
prev_command = command;
}
command = "";
}
else
{
++i;
}
}
String most_liked;
Float64 max_result = 0;
/// Iterate over all programming languages and find the language with the highest weight
for (auto& item : programming_freq)
{
Float64 result = stateMachine(item.map, data_freq);
if (result > max_result)
{
max_result = result;
most_liked = item.name;
}
}
/// If all weights are zero, then we assume that the language is undefined
if (most_liked == "")
{
most_liked = "Undefined";
}
res = most_liked;
}
static void vector(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets)
{
auto & programming_freq = FrequencyHolder::getInstance().getProgrammingFrequency();
const auto & programming_freq = FrequencyHolder::getInstance().getProgrammingFrequency();
res_data.reserve(1024);
/// Constant 5 is arbitrary
res_data.reserve(offsets.size() * 5);
res_offsets.resize(offsets.size());
size_t prev_offset = 0;
size_t res_offset = 0;
for (size_t i = 0; i < offsets.size(); ++i)
{
const char * haystack = reinterpret_cast<const char *>(&data[prev_offset]);
const UInt8 * str = data.data() + offsets[i - 1];
const size_t str_len = offsets[i] - offsets[i - 1] - 1;
std::unordered_map<String, Float64> data_freq;
String str_data = haystack;
String prev_command;
String command;
/// Select all commands from the string
for (size_t ind = 0; ind < str_data.size();)
for (size_t ind = 0; ind < str_len;)
{
/// Assume that all commands are split by spaces
if (!isspace(str_data[ind]))
if (!isspace(str[ind]))
{
command.push_back(str_data[ind]);
command.push_back(str[ind]);
++ind;
while ((ind < str_data.size()) && (!isspace(str_data[ind])))
while ((ind < str_len) && (!isspace(str[ind])))
{
command.push_back(str_data[ind]);
command.push_back(str[ind]);
++ind;
}
if (prev_command == "")
if (prev_command.empty())
{
prev_command = command;
}
@ -148,39 +94,36 @@ struct ProgrammingClassificationImpl
}
}
String most_liked;
String res;
Float64 max_result = 0;
/// Iterate over all programming languages and find the language with the highest weight
for (auto& item : programming_freq)
for (const auto & item : programming_freq)
{
Float64 result = stateMachine(item.map, data_freq);
if (result > max_result)
{
max_result = result;
most_liked = item.name;
res = item.name;
}
}
/// If all weights are zero, then we assume that the language is undefined
if (most_liked == "")
{
most_liked = "Undefined";
}
if (res.empty())
res = "Undefined";
const auto res = most_liked.c_str();
size_t cur_offset = offsets[i];
size_t ans_size = strlen(res);
res_data.resize(res_offset + ans_size + 1);
memcpy(&res_data[res_offset], res, ans_size);
res_offset += ans_size;
res_data.resize(res_offset + res.size() + 1);
memcpy(&res_data[res_offset], res.data(), res.size());
res_data[res_offset] = 0;
++res_offset;
res_data[res_offset + res.size()] = 0;
res_offset += res.size() + 1;
res_offsets[i] = res_offset;
prev_offset = cur_offset;
}
}
[[noreturn]] static void vectorFixed(const ColumnString::Chars &, size_t, ColumnString::Chars &)
{
throw Exception("Cannot apply function detectProgrammingLanguage to fixed string.", ErrorCodes::ILLEGAL_COLUMN);
}
};
struct NameGetProgramming
@ -189,7 +132,7 @@ struct NameGetProgramming
};
using FunctionGetProgramming = FunctionsTextClassification<ProgrammingClassificationImpl, NameGetProgramming>;
using FunctionGetProgramming = FunctionStringToString<ProgrammingClassificationImpl, NameGetProgramming, false>;
void registerFunctionsProgrammingClassification(FunctionFactory & factory)
{

View File

@ -1,86 +0,0 @@
#pragma once
#include <Columns/ColumnConst.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Interpreters/Context_fwd.h>
namespace DB
{
/** Functions for text classification:
*
* detectCharset(string data) - detect charset of data.
* Returns string name of most likely charset.
*
* detectLanguage(string data) - detect language of data in various encodings (not UTF-8)
*
* getTonality(string data) - defines the emotional coloring of the text.
* Returns NEG if text is negative, POS if text is positive or NEUT if text is neutral.
*
* getProgrammingLanguage(string data) - detect programming language
*/
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
}
template <typename Impl, typename Name>
class FunctionsTextClassification : public IFunction
{
public:
static constexpr auto name = Name::name;
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionsTextClassification>(); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 1; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isString(arguments[0]))
throw Exception(
"Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return arguments[0];
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /*input_rows_count*/) const override
{
using ResultType = typename Impl::ResultType;
const ColumnPtr & column = arguments[0].column;
const ColumnConst * col_const = typeid_cast<const ColumnConst *>(&*column);
if (col_const)
{
ResultType res;
Impl::constant(col_const->getValue<String>(), res);
return result_type->createColumnConst(col_const->size(), toField(res));
}
if (const ColumnString * col = checkAndGetColumn<ColumnString>(column.get()))
{
auto col_res = ColumnString::create();
ColumnString::Chars & vec_res = col_res->getChars();
ColumnString::Offsets & offsets_res = col_res->getOffsets();
Impl::vector(col->getChars(), col->getOffsets(), vec_res, offsets_res);
return col_res;
}
else
{
throw Exception(
"Illegal columns " + arguments[0].column->getName() + " of arguments of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
}
}
};
}

View File

@ -1,163 +1,115 @@
#include <Functions/FunctionsTextClassification.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/FrequencyHolder.h>
#include <Common/StringUtils/StringUtils.h>
#include <Functions/FunctionFactory.h>
#include <Common/UTF8Helpers.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <Functions/FunctionStringOrArrayToT.h>
#include <unordered_map>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/**
* Determines the sentiment of text data.
* Uses a marked-up sentiment dictionary, each word has a tonality ranging from -3 to 3.
* Uses a marked-up sentiment dictionary, each word has a tonality ranging from -12 to 6.
* For each text, calculate the average sentiment value of its words and return NEG, POS or NEUT
*/
struct TonalityClassificationImpl
{
using ResultType = String;
static String get_tonality(const Float64 & tonality_level)
static Float32 detectTonality(const UInt8 * str, const size_t str_len, const FrequencyHolder::Map & emotional_dict)
{
if (tonality_level < 0.15) { return "NEG"; }
if (tonality_level > 0.45) { return "POS"; }
return "NEUT";
}
static void constant(String data, String & res)
{
const auto & emotional_dict = FrequencyHolder::getInstance().getEmotionalDict();
Float64 weight = 0;
Float64 count_words = 0;
UInt64 count_words = 0;
String answer;
String word;
/// Select all Russian words from the string
for (size_t i = 0; i < data.size();)
{
/// Assume that all non-Ascii characters are Russian letters
if (!isASCII(data[i]))
{
word.push_back(data[i]);
++i;
while ((i < data.size()) && (!isASCII(data[i])))
{
word.push_back(data[i]);
++i;
}
/// Try to find a russian word in the tonality dictionary
auto it = emotional_dict.find(word);
if (it != emotional_dict.end())
{
count_words += 1;
weight += it->getMapped();
}
word = "";
}
else
{
++i;
}
}
/// Calculate average value of tonality
Float64 total_tonality = weight / count_words;
res += get_tonality(total_tonality);
}
static void vector(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets)
{
const auto & emotional_dict = FrequencyHolder::getInstance().getEmotionalDict();
res_data.reserve(1024);
res_offsets.resize(offsets.size());
size_t prev_offset = 0;
size_t res_offset = 0;
for (size_t i = 0; i < offsets.size(); ++i)
{
const char * haystack = reinterpret_cast<const char *>(&data[prev_offset]);
String str = haystack;
String buf;
Float64 weight = 0;
Float64 count_words = 0;
String answer;
String word;
/// Select all Russian words from the string
for (size_t ind = 0; ind < str.size();)
for (size_t ind = 0; ind < str_len;)
{
/// Assume that all non-ASCII characters are Russian letters
if (!isASCII(str[ind]))
{
word.push_back(str[ind]);
++ind;
while ((ind < str.size()) && (!isASCII(str[ind])))
while ((ind < str_len) && (!isASCII(str[ind])))
{
word.push_back(str[ind]);
++ind;
}
/// Try to find a russian word in the tonality dictionary
auto it = emotional_dict.find(word);
const auto * it = emotional_dict.find(word);
if (it != emotional_dict.end())
{
count_words += 1;
weight += it->getMapped();
}
word = "";
word.clear();
}
else
{
++ind;
}
}
/// Calculate average value of tonality
Float64 total_tonality = weight / count_words;
buf = get_tonality(total_tonality);
/// Calculate average value of tonality.
/// Convert values -12..6 to -1..1
return std::max(weight / count_words / 6, -1.0);
}
const auto res = buf.c_str();
size_t cur_offset = offsets[i];
size_t ans_size = strlen(res);
res_data.resize(res_offset + ans_size + 1);
memcpy(&res_data[res_offset], res, ans_size);
res_offset += ans_size;
/// If the function will return constant value for FixedString data type.
static constexpr auto is_fixed_to_constant = false;
res_data[res_offset] = 0;
++res_offset;
static void vector(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
PaddedPODArray<Float32> & res)
{
const auto & emotional_dict = FrequencyHolder::getInstance().getEmotionalDict();
res_offsets[i] = res_offset;
prev_offset = cur_offset;
size_t size = offsets.size();
size_t prev_offset = 0;
for (size_t i = 0; i < size; ++i)
{
res[i] = detectTonality(data.data() + prev_offset, offsets[i] - 1 - prev_offset, emotional_dict);
prev_offset = offsets[i];
}
}
static void vectorFixedToConstant(const ColumnString::Chars & /*data*/, size_t /*n*/, Float32 & /*res*/) {}
static void vectorFixedToVector(const ColumnString::Chars & data, size_t n, PaddedPODArray<Float32> & res)
{
const auto & emotional_dict = FrequencyHolder::getInstance().getEmotionalDict();
size_t size = data.size() / n;
for (size_t i = 0; i < size; ++i)
res[i] = detectTonality(data.data() + i * n, n, emotional_dict);
}
[[noreturn]] static void array(const ColumnString::Offsets &, PaddedPODArray<Float32> &)
{
throw Exception("Cannot apply function detectTonality to Array argument", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
[[noreturn]] static void uuid(const ColumnUUID::Container &, size_t &, PaddedPODArray<Float32> &)
{
throw Exception("Cannot apply function detectTonality to UUID argument", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
};
struct NameGetTonality
struct NameDetectTonality
{
static constexpr auto name = "detectTonality";
};
using FunctionGetTonality = FunctionsTextClassification<TonalityClassificationImpl, NameGetTonality>;
using FunctionDetectTonality = FunctionStringOrArrayToT<TonalityClassificationImpl, NameDetectTonality, Float32>;
void registerFunctionsTonalityClassification(FunctionFactory & factory)
{
factory.registerFunction<FunctionGetTonality>();
factory.registerFunction<FunctionDetectTonality>();
}
}

View File

@ -2,9 +2,9 @@ if (ENABLE_TESTS)
add_subdirectory(tests)
endif()
# if (ENABLE_EXAMPLES)
if (ENABLE_EXAMPLES)
add_subdirectory(examples)
# endif()
endif()
if (ENABLE_FUZZING)
add_subdirectory(fuzzers)

View File

@ -3,10 +3,10 @@ en
fr
ja
zh
{'ja':62,'fr':36,'un':0}
{'ja':0.62,'fr':0.36,'un':0}
ISO-8859-1
English
POS
NEG
POS
en
0.465
-0.57647777
0.050505556
C++

View File

@ -11,8 +11,8 @@ SELECT detectLanguageMixed('二兎を追う者は一兎をも得ず二兎を追
SELECT detectCharset('Plain English');
SELECT detectLanguageUnknown('Plain English');
SELECT detectTonality('Милая кошка');
SELECT detectTonality('Злой человек');
SELECT detectTonality('Обычная прогулка по ближайшему парку');
SELECT detectTonality('милая кошка');
SELECT detectTonality('ненависть к людям');
SELECT detectTonality('обычная прогулка по ближайшему парку');
SELECT detectProgrammingLanguage('#include <iostream>');