ClickHouse/src/Functions/FunctionsTextClassification.cpp

315 lines
12 KiB
C++
Raw Normal View History

2021-02-07 18:40:55 +00:00
#include <Functions/FunctionsTextClassification.h>
2021-03-18 14:05:28 +00:00
#include "FrequencyHolder.h"
2021-02-07 18:40:55 +00:00
#include <Functions/FunctionFactory.h>
#include <Common/UTF8Helpers.h>
#include <algorithm>
#include <cstring>
2021-03-18 14:05:28 +00:00
#include <cmath>
2021-02-07 18:40:55 +00:00
#include <limits>
2021-03-18 14:05:28 +00:00
#include <unordered_map>
2021-02-07 18:40:55 +00:00
#include <memory>
#include <utility>
2021-03-18 14:05:28 +00:00
#include <sstream>
#include <set>
2021-02-07 18:40:55 +00:00
namespace DB
{
2021-03-18 14:05:28 +00:00
/*
struct TextClassificationDictionaries
{
const std::unordered_map<std::string, std::vector<double>> emotional_dict;
const std::unordered_map<std::string, std::unordered_map<UInt16, double>> encodings_frequency;
const std::string path;
TextClassificationDictionaries()
: emotional_dict(FrequencyHolder::getInstance().getEmotionalDict()),
encodings_frequency(FrequencyHolder::getInstance().getEncodingsFrequency()),
path(FrequencyHolder::getInstance().get_path())
{
}
};
*/
// static std::unordered_map<std::string, std::vector<double>> emotional_dict = classification_dictionaries.emotional_dict;
// static std::unordered_map<std::string, std::unordered_map<UInt16, double>> encodings_freq = classification_dictionaries.encodings_frequency;
template <size_t N, bool Emo>
2021-02-07 18:40:55 +00:00
struct TextClassificationImpl
{
2021-03-18 14:05:28 +00:00
using ResultType = std::string;
2021-02-07 18:40:55 +00:00
using CodePoint = UInt8;
/// map_size for ngram count.
static constexpr size_t map_size = 1u << 16;
/// If the data size is bigger than this, behaviour is unspecified for this function.
static constexpr size_t max_string_size = 1u << 15;
/// Default padding to read safely.
static constexpr size_t default_padding = 16;
/// Max codepoints to store at once. 16 is for batching usage and PODArray has this padding.
static constexpr size_t simultaneously_codepoints_num = default_padding + N - 1;
/** map_size of this fits mostly in L2 cache all the time.
* Actually use UInt16 as addings and subtractions do not UB overflow. But think of it as a signed
* integer array.
*/
using NgramCount = UInt16;
2021-03-18 14:05:28 +00:00
static double L2_distance(std::unordered_map<UInt16, double> standart, std::unordered_map<UInt16, double> model)
{
double res = 0;
for (auto& el : standart) {
if (model.find(el.first) != model.end()) {
res += ((model[el.first] - el.second) * (model[el.first] - el.second));
}
}
return res;
}
static double Naive_bayes(std::unordered_map<UInt16, double> standart, std::unordered_map<UInt16, double> model)
{
double res = 1;
for (auto & el : model) {
if (standart[el.first] != 0) {
res += el.second * log(standart[el.first]);
}
}
return res;
}
2021-02-07 18:40:55 +00:00
static ALWAYS_INLINE size_t readCodePoints(CodePoint * code_points, const char *& pos, const char * end)
{
constexpr size_t padding_offset = default_padding - N + 1;
memcpy(code_points, code_points + padding_offset, roundUpToPowerOfTwoOrZero(N - 1) * sizeof(CodePoint));
memcpy(code_points + (N - 1), pos, default_padding * sizeof(CodePoint));
pos += padding_offset;
if (pos > end)
return default_padding - (pos - end);
return default_padding;
}
2021-03-18 14:05:28 +00:00
2021-02-07 18:40:55 +00:00
static ALWAYS_INLINE inline size_t calculateStats(
const char * data,
const size_t size,
NgramCount * ngram_stats,
size_t (*read_code_points)(CodePoint *, const char *&, const char *),
NgramCount * ngram_storage)
{
const char * start = data;
const char * end = data + size;
CodePoint cp[simultaneously_codepoints_num] = {};
/// read_code_points returns the position of cp where it stopped reading codepoints.
size_t found = read_code_points(cp, start, end);
/// We need to start for the first time here, because first N - 1 codepoints mean nothing.
size_t i = N - 1;
size_t len = 0;
do
{
for (; i + N <= found; ++i)
{
2021-02-08 12:23:51 +00:00
UInt32 hash = 0;
2021-02-07 18:40:55 +00:00
for (size_t j = 0; j < N; ++j) {
hash <<= 8;
hash += *(cp + i + j);
}
if (ngram_stats[hash] == 0) {
ngram_storage[len] = hash;
++len;
}
++ngram_stats[hash];
}
i = 0;
} while (start < end && (found = read_code_points(cp, start, end)));
return len;
}
2021-03-18 14:05:28 +00:00
static void word_processing(std::string & word)
2021-02-07 18:40:55 +00:00
{
2021-03-18 14:05:28 +00:00
std::set<char> to_skip {',', '.', '!', '?', ')', '(', '\"', '\'', '[', ']', '{', '}', ':', ';'};
while (to_skip.find(word.back()) != to_skip.end())
{
word.pop_back();
}
while (to_skip.find(word.front()) != to_skip.end())
{
word.erase(0, 1);
}
}
static void constant(std::string data, std::string & res)
{
static std::unordered_map<std::string, std::vector<double>> emotional_dict = FrequencyHolder::getInstance().getEmotionalDict();
static std::unordered_map<std::string, std::unordered_map<UInt16, double>> encodings_freq = FrequencyHolder::getInstance().getEncodingsFrequency();
/*
static TextClassificationDictionaries classification_dictionaries;
static std::unordered_map<std::string, std::vector<double>> emotional_dict = classification_dictionaries.emotional_dict;
static std::unordered_map<std::string, std::unordered_map<UInt16, double>> encodings_freq = classification_dictionaries.encodings_frequency;
*/
if (!Emo)
{
std::unique_ptr<NgramCount[]> common_stats{new NgramCount[map_size]{}}; // frequency of N-grams
std::unique_ptr<NgramCount[]> ngram_storage{new NgramCount[map_size]{}}; // list of N-grams
size_t len = calculateStats(data.data(), data.size(), common_stats.get(), readCodePoints, ngram_storage.get()); // count of N-grams
std::string ans;
double count_bigram = data.size() - 1;
std::unordered_map<UInt16, double> model;
for (size_t i = 0; i < len; ++i) {
ans += std::to_string(ngram_storage.get()[i]) + " " + std::to_string(static_cast<double>(common_stats.get()[ngram_storage.get()[i]]) / count_bigram) + "\n";
model[ngram_storage.get()[i]] = static_cast<double>(common_stats.get()[ngram_storage.get()[i]]) / count_bigram;
}
double res1 = L2_distance(encodings_freq["freq_CP866"], model);
double res2 = L2_distance(encodings_freq["freq_ISO"], model);
double res3 = L2_distance(encodings_freq["freq_WINDOWS-1251"], model);
double res4 = L2_distance(encodings_freq["freq_UTF-8"], model);
ans += std::to_string(res1) + " " + std::to_string(res2) + " " + std::to_string(res3) + " " + std::to_string(res4) + "\n";
res = ans;
}
else
{
double freq = 0;
double count_words = 0;
std::string ans;
std::stringstream ss;
ss << data;
std::string to_check;
while (ss >> to_check)
{
word_processing(to_check);
if (emotional_dict.find(to_check) != emotional_dict.cend())
{
count_words += 1;
ans += to_check + " " + std::to_string(emotional_dict[to_check][0]) + "\n";
freq += emotional_dict[to_check][0];
}
}
double total_tonality = freq / count_words;
if (total_tonality < 0.5)
{
ans += "NEG";
}
else if (total_tonality > 1)
{
ans += "POS";
}
else
{
ans += "NEUT";
}
ans += " " + std::to_string(total_tonality) + "\n";
res = ans;
}
2021-02-07 18:40:55 +00:00
}
static void vector(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
2021-03-18 14:05:28 +00:00
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets)
2021-02-07 18:40:55 +00:00
{
2021-03-18 14:05:28 +00:00
static std::unordered_map<std::string, std::unordered_map<UInt16, double>> encodings_freq = FrequencyHolder::getInstance().getEncodingsFrequency();
/*
static TextClassificationDictionaries classification_dictionaries;
static std::unordered_map<std::string, std::vector<double>> emotional_dict = classification_dictionaries.emotional_dict;
static std::unordered_map<std::string, std::unordered_map<UInt16, double>> encodings_freq = classification_dictionaries.encodings_frequency;
*/
res_data.reserve(1024);
res_offsets.resize(offsets.size());
2021-02-07 18:40:55 +00:00
size_t prev_offset = 0;
2021-03-18 14:05:28 +00:00
size_t res_offset = 0;
2021-02-07 18:40:55 +00:00
2021-03-18 14:05:28 +00:00
for (size_t i = 0; i < offsets.size(); ++i)
2021-02-07 18:40:55 +00:00
{
const char * haystack = reinterpret_cast<const char *>(&data[prev_offset]);
2021-02-07 19:46:33 +00:00
std::string str = haystack;
std::unique_ptr<NgramCount[]> common_stats{new NgramCount[map_size]{}}; // frequency of N-grams
std::unique_ptr<NgramCount[]> ngram_storage{new NgramCount[map_size]{}}; // list of N-grams
2021-03-18 14:05:28 +00:00
size_t len = calculateStats(str.data(), str.size(), common_stats.get(), readCodePoints, ngram_storage.get()); // count of N-grams
std::string prom;
double count_bigram = data.size() - 1;
std::unordered_map<UInt16, double> model1;
std::unordered_map<UInt16, double> model2;
for (size_t j = 0; j < len; ++j)
{
model2[ngram_storage.get()[j]] = static_cast<double>(common_stats.get()[ngram_storage.get()[j]]);
}
for (size_t j = 0; j < len; ++j)
{
model1[ngram_storage.get()[j]] = static_cast<double>(common_stats.get()[ngram_storage.get()[j]]) / count_bigram;
}
double res1 = L2_distance(encodings_freq["freq_CP866"], model1);
double res2 = L2_distance(encodings_freq["freq_ISO"], model1);
double res3 = L2_distance(encodings_freq["freq_WINDOWS-1251"], model1);
double res4 = L2_distance(encodings_freq["freq_UTF-8"], model1);
prom += std::to_string(res1) + " " + std::to_string(res2) + " " + std::to_string(res3) + " " + std::to_string(res4) + "\n";
double res12 = Naive_bayes(encodings_freq["freq_CP866"], model2);
double res22 = Naive_bayes(encodings_freq["freq_ISO"], model2);
double res32 = Naive_bayes(encodings_freq["freq_WINDOWS-1251"], model2);
double res42 = Naive_bayes(encodings_freq["freq_UTF-8"], model2);
prom += std::to_string(res12) + " " + std::to_string(res22) + " " + std::to_string(res32) + " " + std::to_string(res42) + "\n";
const auto ans = prom.c_str();
size_t cur_offset = offsets[i];
res_data.resize(res_offset + strlen(ans) + 1);
memcpy(&res_data[res_offset], ans, strlen(ans));
res_offset += strlen(ans);
res_data[res_offset] = 0;
++res_offset;
res_offsets[i] = res_offset;
prev_offset = cur_offset;
2021-02-07 18:40:55 +00:00
}
}
2021-03-18 14:05:28 +00:00
2021-02-07 18:40:55 +00:00
};
/// SQL-visible name of the bigram statistics / encoding distance function.
struct NameBiGramcount
{
    static constexpr const char * name = "biGramcount";
};
2021-03-18 14:05:28 +00:00
/// SQL-visible name of the sentiment (emotional tonality) function.
struct NameGetEmo
{
    static constexpr const char * name = "getEmo";
};
2021-02-07 19:46:33 +00:00
2021-02-07 18:40:55 +00:00
2021-03-18 14:05:28 +00:00
using FunctionBiGramcount = FunctionsTextClassification<TextClassificationImpl<2, false>, NameBiGramcount>;
using FunctionGetEmo = FunctionsTextClassification<TextClassificationImpl<2, true>, NameGetEmo>;
2021-02-07 18:40:55 +00:00
/// Registers the text classification SQL functions (biGramcount and getEmo) in `factory`.
void registerFunctionsTextClassification(FunctionFactory & factory)
{
    factory.registerFunction<FunctionBiGramcount>();
    factory.registerFunction<FunctionGetEmo>();
}
}