From 85b8985df20c5fcb6a39b19582e58d3449406e0b Mon Sep 17 00:00:00 2001 From: Nikolay Degterinsky Date: Mon, 10 Jan 2022 15:36:32 +0000 Subject: [PATCH] Better --- cmake/find/nlp.cmake | 2 +- contrib/nlp-data | 2 +- contrib/nlp-data-cmake/CMakeLists.txt | 14 ++ src/Common/FrequencyHolder.h | 69 ++++--- .../FunctionsCharsetClassification.cpp | 193 +++++------------- .../FunctionsLanguageClassification.cpp | 71 +++---- .../FunctionsProgrammingClassification.cpp | 135 ++++-------- src/Functions/FunctionsTextClassification.h | 86 -------- .../FunctionsTonalityClassification.cpp | 156 +++++--------- src/Interpreters/CMakeLists.txt | 4 +- .../02133_classification.reference | 10 +- .../0_stateless/02133_classification.sql | 6 +- 12 files changed, 241 insertions(+), 507 deletions(-) create mode 100644 contrib/nlp-data-cmake/CMakeLists.txt delete mode 100644 src/Functions/FunctionsTextClassification.h diff --git a/cmake/find/nlp.cmake b/cmake/find/nlp.cmake index 29f97039e06..4b9311c6685 100644 --- a/cmake/find/nlp.cmake +++ b/cmake/find/nlp.cmake @@ -36,4 +36,4 @@ endif () set (USE_NLP 1) -message (STATUS "Using Libraries for NLP functions: contrib/wordnet-blast, contrib/libstemmer_c, contrib/lemmagen-c") +message (STATUS "Using Libraries for NLP functions: contrib/wordnet-blast, contrib/libstemmer_c, contrib/lemmagen-c, contrib/cld2") diff --git a/contrib/nlp-data b/contrib/nlp-data index 3bc8aef8440..5591f91f5e7 160000 --- a/contrib/nlp-data +++ b/contrib/nlp-data @@ -1 +1 @@ -Subproject commit 3bc8aef8440b66823186f47a74996cbdad66c04f +Subproject commit 5591f91f5e748cba8fb9ef81564176feae774853 diff --git a/contrib/nlp-data-cmake/CMakeLists.txt b/contrib/nlp-data-cmake/CMakeLists.txt new file mode 100644 index 00000000000..d13258725d5 --- /dev/null +++ b/contrib/nlp-data-cmake/CMakeLists.txt @@ -0,0 +1,14 @@ +include(${ClickHouse_SOURCE_DIR}/cmake/embed_binary.cmake) + +set(LIBRARY_DIR "${ClickHouse_SOURCE_DIR}/contrib/nlp-data") + +add_library (nlp_data INTERFACE) + +clickhouse_embed_binaries( + TARGET nlp_dicts + RESOURCE_DIR "${LIBRARY_DIR}" + RESOURCES charset.zst tonality_ru.zst programming.zst +) + +add_dependencies(nlp_data nlp_dicts) +target_link_libraries(nlp_data INTERFACE "-Wl,${WHOLE_ARCHIVE} $ -Wl,${NO_WHOLE_ARCHIVE}") diff --git a/src/Common/FrequencyHolder.h b/src/Common/FrequencyHolder.h index 2389c910ed0..a98ae0452d3 100644 --- a/src/Common/FrequencyHolder.h +++ b/src/Common/FrequencyHolder.h @@ -1,22 +1,20 @@ #pragma once + +#include +#include +#include #include #include #include #include #include -#include #include -#include #include -#include +#include #include -#include -#include #include -#include -#include namespace DB { @@ -26,6 +24,14 @@ namespace ErrorCodes extern const int FILE_DOESNT_EXIST; } +/// FrequencyHolder class is responsible for storing and loading dictionaries +/// needed for text classification functions: +/// +/// 1. detectLanguageUnknown +/// 2. detectCharset +/// 3. detectTonality +/// 4. 
detectProgrammingLanguage + class FrequencyHolder { @@ -39,6 +45,7 @@ public: struct Encoding { String name; + String lang; HashMap map; }; @@ -54,14 +61,13 @@ public: return instance; } - void loadEncodingsFrequency() { Poco::Logger * log = &Poco::Logger::get("EncodingsFrequency"); LOG_TRACE(log, "Loading embedded charset frequencies"); - auto resource = getResource("charset_freq.txt.zst"); + auto resource = getResource("charset.zst"); if (resource.empty()) throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded charset frequencies"); @@ -71,12 +77,12 @@ public: String charset_name; auto buf = std::make_unique(resource.data(), resource.size()); - std::unique_ptr in = std::make_unique(std::move(buf)); + ZstdInflatingReadBuffer in(std::move(buf)); - while (!in->eof()) + while (!in.eof()) { - readString(line, *in); - ++in->position(); + readString(line, in); + in.ignore(); if (line.empty()) continue; @@ -84,13 +90,21 @@ public: ReadBufferFromString buf_line(line); // Start loading a new charset - if (line.starts_with("//")) + if (line.starts_with("// ")) { + // Skip "// " buf_line.ignore(3); readString(charset_name, buf_line); + /* In our dictionary we have lines with form: _ + * If we need to find language of data, we return + * If we need to find charset of data, we return . + */ + size_t sep = charset_name.find('_'); + Encoding enc; - enc.name = charset_name; + enc.lang = charset_name.substr(0, sep); + enc.name = charset_name.substr(sep + 1); encodings_freq.push_back(std::move(enc)); } else @@ -109,9 +123,9 @@ public: void loadEmotionalDict() { Poco::Logger * log = &Poco::Logger::get("EmotionalDict"); - LOG_TRACE(log, "Loading embedded emotional dictionary (RU)"); + LOG_TRACE(log, "Loading embedded emotional dictionary"); - auto resource = getResource("emotional_dictionary_rus.txt.zst"); + auto resource = getResource("tonality_ru.zst"); if (resource.empty()) throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded emotional dictionary"); @@ -121,12 +135,12 @@ public: size_t count = 0; auto buf = std::make_unique(resource.data(), resource.size()); - std::unique_ptr in = std::make_unique(std::move(buf)); + ZstdInflatingReadBuffer in(std::move(buf)); - while (!in->eof()) + while (!in.eof()) { - readString(line, *in); - ++in->position(); + readString(line, in); + in.ignore(); if (line.empty()) continue; @@ -151,7 +165,7 @@ public: LOG_TRACE(log, "Loading embedded programming languages frequencies loading"); - auto resource = getResource("prog_freq.txt.zst"); + auto resource = getResource("programming.zst"); if (resource.empty()) throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded programming languages frequencies"); @@ -161,12 +175,12 @@ public: String programming_language; auto buf = std::make_unique(resource.data(), resource.size()); - std::unique_ptr in = std::make_unique(std::move(buf)); + ZstdInflatingReadBuffer in(std::move(buf)); - while (!in->eof()) + while (!in.eof()) { - readString(line, *in); - ++in->position(); + readString(line, in); + in.ignore(); if (line.empty()) continue; @@ -174,8 +188,9 @@ public: ReadBufferFromString buf_line(line); // Start loading a new language - if (line.starts_with("//")) + if (line.starts_with("// ")) { + // Skip "// " buf_line.ignore(3); readString(programming_language, buf_line); diff --git a/src/Functions/FunctionsCharsetClassification.cpp b/src/Functions/FunctionsCharsetClassification.cpp index c6f54f2c5d0..2058770ed58 100644 --- a/src/Functions/FunctionsCharsetClassification.cpp +++ 
b/src/Functions/FunctionsCharsetClassification.cpp @@ -1,18 +1,18 @@ -#include #include #include -#include -#include -#include +#include -#include -#include -#include #include -#include +#include namespace DB { + +namespace ErrorCodes +{ + extern const int ILLEGAL_COLUMN; +} + /* Determine language and charset of text data. For each text, we build the distribution of bigrams bytes. * Then we use marked-up dictionaries with distributions of bigram bytes of various languages ​​and charsets. * Using a naive Bayesian classifier, find the most likely charset and language and return it @@ -21,10 +21,6 @@ namespace DB template struct CharsetClassificationImpl { - - using ResultType = String; - using CodePoint = UInt8; - /* We need to solve zero-frequency problem for Naive Bayes Classifier * If the bigram is not found in the text, we assume that the probability of its meeting is 1e-06. * 1e-06 is minimal value in our marked-up dictionary. @@ -34,28 +30,22 @@ struct CharsetClassificationImpl /// If the data size is bigger than this, behaviour is unspecified for this function. static constexpr size_t max_string_size = 1u << 15; - /// Default padding to read safely. - static constexpr size_t default_padding = 16; - - /// Max codepoints to store at once. 16 is for batching usage and PODArray has this padding. - static constexpr size_t simultaneously_codepoints_num = default_padding + N - 1; - - - static ALWAYS_INLINE inline Float64 naiveBayes(const FrequencyHolder::EncodingMap & standard, - std::unordered_map & model, + static ALWAYS_INLINE inline Float64 naiveBayes( + const FrequencyHolder::EncodingMap & standard, + const HashMap & model, Float64 max_result) { Float64 res = 0; - for (auto & el : model) + for (const auto & el : model) { /// Try to find bigram in the dictionary. - auto it = standard.find(el.first); + const auto * it = standard.find(el.getKey()); if (it != standard.end()) { - res += el.second * log(it->getMapped()); + res += el.getMapped() * log(it->getMapped()); } else { - res += el.second * log(zero_frequency); + res += el.getMapped() * log(zero_frequency); } /// If at some step the result has become less than the current maximum, then it makes no sense to count it fully. if (res < max_result) @@ -66,95 +56,21 @@ struct CharsetClassificationImpl return res; } - - static ALWAYS_INLINE size_t readCodePoints(CodePoint * code_points, const char *& pos, const char * end) - { - constexpr size_t padding_offset = default_padding - N + 1; - memcpy(code_points, code_points + padding_offset, roundUpToPowerOfTwoOrZero(N - 1) * sizeof(CodePoint)); - memcpy(code_points + (N - 1), pos, default_padding * sizeof(CodePoint)); - pos += padding_offset; - if (pos > end) - return default_padding - (pos - end); - return default_padding; - } - /// Сount how many times each bigram occurs in the text. - static ALWAYS_INLINE inline size_t calculateStats( - const char * data, + static ALWAYS_INLINE inline void calculateStats( + const UInt8 * data, const size_t size, - size_t (*read_code_points)(CodePoint *, const char *&, const char *), - std::unordered_map& model) + HashMap & model) { - - const char * start = data; - const char * end = data + size; - CodePoint cp[simultaneously_codepoints_num] = {}; - /// read_code_points returns the position of cp where it stopped reading codepoints. - size_t found = read_code_points(cp, start, end); - /// We need to start for the first time here, because first N - 1 codepoints mean nothing. 
- size_t i = N - 1; - size_t len = 0; - do + UInt16 hash = 0; + for (size_t i = 0; i < size; ++i) { - for (; i + N <= found; ++i) - { - UInt32 hash = 0; - for (size_t j = 0; j < N; ++j) - { - hash <<= 8; - hash += *(cp + i + j); - } - if (model[hash] == 0) - { - model[hash] = 1; - ++len; - } - ++model[hash]; - } - i = 0; - } while (start < end && (found = read_code_points(cp, start, end))); - - return len; - } - - - static void constant(String data, String & res) - { - const auto & encodings_freq = FrequencyHolder::getInstance().getEncodingsFrequency(); - - std::unordered_map model; - calculateStats(data.data(), data.size(), readCodePoints, model); - - Float64 max_result = log(zero_frequency) * (max_string_size); - String poss_ans; - /// Go through the dictionary and find the charset with the highest weight - for (auto& item : encodings_freq) - { - Float64 score = naiveBayes(item.map, model, max_result); - if (max_result < score) - { - poss_ans = item.name; - max_result = score; - } - } - - /* In our dictionary we have lines with form: _ - * If we need to find language of data, we return - * If we need to find charset of data, we return . - */ - - size_t sep = poss_ans.find('_'); - if (detect_language) - { - res = poss_ans.erase(0, sep + 1); - } - else - { - res = poss_ans.erase(sep, poss_ans.size() - sep); + hash <<= 8; + hash += *(data + i); + ++model[hash]; } } - static void vector( const ColumnString::Chars & data, const ColumnString::Offsets & offsets, @@ -163,64 +79,53 @@ struct CharsetClassificationImpl { const auto & encodings_freq = FrequencyHolder::getInstance().getEncodingsFrequency(); - res_data.reserve(1024); + if (detect_language) + /// 2 chars for ISO code + 1 zero byte + res_data.reserve(offsets.size() * 3); + else + /// Mean charset length is 8 + res_data.reserve(offsets.size() * 8); + res_offsets.resize(offsets.size()); - size_t prev_offset = 0; size_t res_offset = 0; for (size_t i = 0; i < offsets.size(); ++i) { - const char * haystack = reinterpret_cast(&data[prev_offset]); - String str = haystack; + const UInt8 * str = data.data() + offsets[i - 1]; + const size_t str_len = offsets[i] - offsets[i - 1] - 1; - String poss_ans; + std::string_view res; - std::unordered_map model; - calculateStats(str.data(), str.size(), readCodePoints, model); + HashMap model; + calculateStats(str, str_len, model); - Float64 max_result = log(zero_frequency) * (max_string_size); - for (auto& item : encodings_freq) + /// Go through the dictionary and find the charset with the highest weight + Float64 max_result = log(zero_frequency) * (max_string_size); + for (const auto & item : encodings_freq) { Float64 score = naiveBayes(item.map, model, max_result); if (max_result < score) { max_result = score; - poss_ans = item.name; + res = detect_language ? 
item.lang : item.name; } } - size_t sep = poss_ans.find('_'); - String ans_str; + res_data.resize(res_offset + res.size() + 1); + memcpy(&res_data[res_offset], res.data(), res.size()); - if (detect_language) - { - ans_str = poss_ans.erase(0, sep + 1); - } - else - { - ans_str = poss_ans.erase(sep, poss_ans.size() - sep); - } - - ans_str = poss_ans; - - const auto res = ans_str.c_str(); - size_t cur_offset = offsets[i]; - - size_t ans_size = strlen(res); - res_data.resize(res_offset + ans_size + 1); - memcpy(&res_data[res_offset], res, ans_size); - res_offset += ans_size; - - res_data[res_offset] = 0; - ++res_offset; + res_data[res_offset + res.size()] = 0; + res_offset += res.size() + 1; res_offsets[i] = res_offset; - prev_offset = cur_offset; } } - + [[noreturn]] static void vectorFixed(const ColumnString::Chars &, size_t, ColumnString::Chars &) + { + throw Exception("Cannot apply function detectProgrammingLanguage to fixed string.", ErrorCodes::ILLEGAL_COLUMN); + } }; @@ -235,8 +140,8 @@ struct NameLanguageDetect }; -using FunctionCharsetDetect = FunctionsTextClassification, NameCharsetDetect>; -using FunctionLanguageDetect = FunctionsTextClassification, NameLanguageDetect>; +using FunctionCharsetDetect = FunctionStringToString, NameCharsetDetect, false>; +using FunctionLanguageDetect = FunctionStringToString, NameLanguageDetect, false>; void registerFunctionsCharsetClassification(FunctionFactory & factory) { diff --git a/src/Functions/FunctionsLanguageClassification.cpp b/src/Functions/FunctionsLanguageClassification.cpp index 23e289ec014..befed0311bf 100644 --- a/src/Functions/FunctionsLanguageClassification.cpp +++ b/src/Functions/FunctionsLanguageClassification.cpp @@ -4,7 +4,7 @@ #if USE_NLP -#include +#include #include #include @@ -34,9 +34,7 @@ extern const int ILLEGAL_COLUMN; struct LanguageClassificationImpl { - using ResultType = String; - - static String codeISO(std::string_view code_string) + static std::string_view codeISO(std::string_view code_string) { if (code_string.ends_with("-Latn")) code_string.remove_suffix(code_string.size() - 5); @@ -61,54 +59,44 @@ struct LanguageClassificationImpl if (code_string.size() != 2) return "other"; - return String(code_string); + return code_string; } - static void constant(const String & data, String & res) - { - bool is_reliable = true; - const char * str = data.c_str(); - auto lang = CLD2::DetectLanguage(str, strlen(str), true, &is_reliable); - res = codeISO(LanguageCode(lang)); - } - - static void vector( const ColumnString::Chars & data, const ColumnString::Offsets & offsets, ColumnString::Chars & res_data, ColumnString::Offsets & res_offsets) { - res_data.reserve(1024); + /// Constant 3 is based on the fact that in general we need 2 characters for ISO code + 1 zero byte + res_data.reserve(offsets.size() * 3); res_offsets.resize(offsets.size()); - size_t prev_offset = 0; + bool is_reliable = true; size_t res_offset = 0; for (size_t i = 0; i < offsets.size(); ++i) { - const char * str = reinterpret_cast(&data[prev_offset]); - String res; - bool is_reliable = true; + const char * str = reinterpret_cast(data.data() + offsets[i - 1]); + const size_t str_len = offsets[i] - offsets[i - 1] - 1; - auto lang = CLD2::DetectLanguage(str, strlen(str), true, &is_reliable); - res = codeISO(LanguageCode(lang)); - - size_t cur_offset = offsets[i]; + auto lang = CLD2::DetectLanguage(str, str_len, true, &is_reliable); + auto res = codeISO(LanguageCode(lang)); res_data.resize(res_offset + res.size() + 1); memcpy(&res_data[res_offset], res.data(), 
res.size()); - res_offset += res.size(); - res_data[res_offset] = 0; - ++res_offset; + res_data[res_offset + res.size()] = 0; + res_offset += res.size() + 1; res_offsets[i] = res_offset; - prev_offset = cur_offset; } } - + [[noreturn]] static void vectorFixed(const ColumnString::Chars &, size_t, ColumnString::Chars &) + { + throw Exception("Cannot apply function detectProgrammingLanguage to fixed string.", ErrorCodes::ILLEGAL_COLUMN); + } }; class LanguageClassificationMixedDetect : public IFunction @@ -116,6 +104,9 @@ class LanguageClassificationMixedDetect : public IFunction public: static constexpr auto name = "detectLanguageMixed"; + /// Number of top results + static constexpr auto top_N = 3; + static FunctionPtr create(ContextPtr) { return std::make_shared(); } String getName() const override { return name; } @@ -132,7 +123,7 @@ public: throw Exception( "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - return std::make_shared(std::make_shared(), std::make_shared()); + return std::make_shared(std::make_shared(), std::make_shared()); } ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override @@ -145,8 +136,8 @@ public: "Illegal columns " + arguments[0].column->getName() + " of arguments of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); - auto & input_data = col->getChars(); - auto & input_offsets = col->getOffsets(); + const auto & input_data = col->getChars(); + const auto & input_offsets = col->getOffsets(); /// Create and fill the result map. @@ -158,15 +149,15 @@ public: MutableColumnPtr values_data = value_type->createColumn(); MutableColumnPtr offsets = DataTypeNumber().createColumn(); - size_t total_elements = input_rows_count * 3; + size_t total_elements = input_rows_count * top_N; keys_data->reserve(total_elements); values_data->reserve(total_elements); offsets->reserve(input_rows_count); bool is_reliable = true; - CLD2::Language result_lang_top3[3]; - int32_t pc[3]; - int bytes[3]; + CLD2::Language result_lang_top3[top_N]; + int32_t pc[top_N]; + int bytes[top_N]; IColumn::Offset current_offset = 0; for (size_t i = 0; i < input_rows_count; ++i) @@ -176,16 +167,16 @@ public: CLD2::DetectLanguageSummary(str, str_len, true, result_lang_top3, pc, bytes, &is_reliable); - for (size_t j = 0; j < 3; ++j) + for (size_t j = 0; j < top_N; ++j) { auto res_str = LanguageClassificationImpl::codeISO(LanguageCode(result_lang_top3[j])); - int32_t res_int = static_cast(pc[j]); + Float32 res_float = static_cast(pc[j]) / 100; keys_data->insertData(res_str.data(), res_str.size()); - values_data->insertData(reinterpret_cast(&res_int), sizeof(res_int)); + values_data->insertData(reinterpret_cast(&res_float), sizeof(res_float)); } - current_offset += 3; + current_offset += top_N; offsets->insert(current_offset); } @@ -203,7 +194,7 @@ struct NameLanguageUTF8Detect }; -using FunctionLanguageUTF8Detect = FunctionsTextClassification; +using FunctionLanguageUTF8Detect = FunctionStringToString; void registerFunctionLanguageDetectUTF8(FunctionFactory & factory) { diff --git a/src/Functions/FunctionsProgrammingClassification.cpp b/src/Functions/FunctionsProgrammingClassification.cpp index 435d629dafc..5f114a35d9c 100644 --- a/src/Functions/FunctionsProgrammingClassification.cpp +++ b/src/Functions/FunctionsProgrammingClassification.cpp @@ -1,12 +1,18 @@ -#include #include #include -#include +#include #include +#include namespace DB { + +namespace 
ErrorCodes +{ + extern const int ILLEGAL_COLUMN; +} + /** * Determine the programming language from the source code. * We calculate all the unigrams and bigrams of commands in the source code. @@ -15,16 +21,16 @@ namespace DB */ struct ProgrammingClassificationImpl { - - using ResultType = String; /// Calculate total weight - static ALWAYS_INLINE inline Float64 stateMachine(const FrequencyHolder::Map & standard, std::unordered_map & model) + static ALWAYS_INLINE inline Float64 stateMachine( + const FrequencyHolder::Map & standard, + const std::unordered_map & model) { Float64 res = 0; - for (auto & el : model) + for (const auto & el : model) { /// Try to find each n-gram in dictionary - auto it = standard.find(el.first); + const auto * it = standard.find(el.first); if (it != standard.end()) { res += el.second * it->getMapped(); @@ -33,104 +39,44 @@ struct ProgrammingClassificationImpl return res; } - - static void constant(String data, String & res) - { - auto & programming_freq = FrequencyHolder::getInstance().getProgrammingFrequency(); - std::unordered_map data_freq; - - String prev_command; - String command; - /// Select all commands from the string - for (size_t i = 0; i < data.size();) - { - /// Assume that all commands are split by spaces - if (!isspace(data[i])) - { - command.push_back(data[i]); - ++i; - - while ((i < data.size()) && (!isspace(data[i]))) - { - command.push_back(data[i]); - ++i; - } - if (prev_command == "") - { - prev_command = command; - } - else - { - data_freq[prev_command + command] += 1; - data_freq[prev_command] += 1; - prev_command = command; - } - command = ""; - } - else - { - ++i; - } - } - - String most_liked; - Float64 max_result = 0; - /// Iterate over all programming languages ​​and find the language with the highest weight - for (auto& item : programming_freq) - { - Float64 result = stateMachine(item.map, data_freq); - if (result > max_result) - { - max_result = result; - most_liked = item.name; - } - } - /// If all weights are zero, then we assume that the language is undefined - if (most_liked == "") - { - most_liked = "Undefined"; - } - res = most_liked; - } - - static void vector( const ColumnString::Chars & data, const ColumnString::Offsets & offsets, ColumnString::Chars & res_data, ColumnString::Offsets & res_offsets) { - auto & programming_freq = FrequencyHolder::getInstance().getProgrammingFrequency(); + const auto & programming_freq = FrequencyHolder::getInstance().getProgrammingFrequency(); - res_data.reserve(1024); + /// Constant 5 is arbitrary + res_data.reserve(offsets.size() * 5); res_offsets.resize(offsets.size()); - size_t prev_offset = 0; size_t res_offset = 0; for (size_t i = 0; i < offsets.size(); ++i) { - const char * haystack = reinterpret_cast(&data[prev_offset]); + const UInt8 * str = data.data() + offsets[i - 1]; + const size_t str_len = offsets[i] - offsets[i - 1] - 1; + std::unordered_map data_freq; - String str_data = haystack; String prev_command; String command; /// Select all commands from the string - for (size_t ind = 0; ind < str_data.size();) + for (size_t ind = 0; ind < str_len;) { /// Assume that all commands are split by spaces - if (!isspace(str_data[ind])) + if (!isspace(str[ind])) { - command.push_back(str_data[ind]); + command.push_back(str[ind]); ++ind; - while ((ind < str_data.size()) && (!isspace(str_data[ind]))) + while ((ind < str_len) && (!isspace(str[ind]))) { - command.push_back(str_data[ind]); + command.push_back(str[ind]); ++ind; } - if (prev_command == "") + if (prev_command.empty()) { prev_command = 
command; } @@ -148,39 +94,36 @@ struct ProgrammingClassificationImpl } } - String most_liked; + String res; Float64 max_result = 0; /// Iterate over all programming languages ​​and find the language with the highest weight - for (auto& item : programming_freq) + for (const auto & item : programming_freq) { Float64 result = stateMachine(item.map, data_freq); if (result > max_result) { max_result = result; - most_liked = item.name; + res = item.name; } } /// If all weights are zero, then we assume that the language is undefined - if (most_liked == "") - { - most_liked = "Undefined"; - } + if (res.empty()) + res = "Undefined"; - const auto res = most_liked.c_str(); - size_t cur_offset = offsets[i]; - size_t ans_size = strlen(res); - res_data.resize(res_offset + ans_size + 1); - memcpy(&res_data[res_offset], res, ans_size); - res_offset += ans_size; + res_data.resize(res_offset + res.size() + 1); + memcpy(&res_data[res_offset], res.data(), res.size()); - res_data[res_offset] = 0; - ++res_offset; + res_data[res_offset + res.size()] = 0; + res_offset += res.size() + 1; res_offsets[i] = res_offset; - prev_offset = cur_offset; } } + [[noreturn]] static void vectorFixed(const ColumnString::Chars &, size_t, ColumnString::Chars &) + { + throw Exception("Cannot apply function detectProgrammingLanguage to fixed string.", ErrorCodes::ILLEGAL_COLUMN); + } }; struct NameGetProgramming @@ -189,7 +132,7 @@ struct NameGetProgramming }; -using FunctionGetProgramming = FunctionsTextClassification; +using FunctionGetProgramming = FunctionStringToString; void registerFunctionsProgrammingClassification(FunctionFactory & factory) { diff --git a/src/Functions/FunctionsTextClassification.h b/src/Functions/FunctionsTextClassification.h deleted file mode 100644 index 3d7b9903be1..00000000000 --- a/src/Functions/FunctionsTextClassification.h +++ /dev/null @@ -1,86 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace DB -{ -/** Functions for text classification: - * - * detectCharset(string data) - detect charset of data. - * Returns string name of most likely charset. - * - * detectLanguage(string data) - detect language of data in various encodings (not UTF-8) - * - * getTonality(string data) - defines the emotional coloring of the text. - * Returns NEG if text is negative, POS if text is positive or NEUT if text is neutral. 
- * - * getProgrammingLanguage(string data) - detect programming language - */ -namespace ErrorCodes -{ -extern const int ILLEGAL_TYPE_OF_ARGUMENT; -extern const int ILLEGAL_COLUMN; -} - -template -class FunctionsTextClassification : public IFunction -{ -public: - static constexpr auto name = Name::name; - - static FunctionPtr create(ContextPtr) { return std::make_shared(); } - - String getName() const override { return name; } - - size_t getNumberOfArguments() const override { return 1; } - - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } - - DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override - { - if (!isString(arguments[0])) - throw Exception( - "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - return arguments[0]; - } - - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /*input_rows_count*/) const override - { - using ResultType = typename Impl::ResultType; - - const ColumnPtr & column = arguments[0].column; - - const ColumnConst * col_const = typeid_cast(&*column); - - if (col_const) - { - ResultType res; - Impl::constant(col_const->getValue(), res); - return result_type->createColumnConst(col_const->size(), toField(res)); - } - - - if (const ColumnString * col = checkAndGetColumn(column.get())) - { - auto col_res = ColumnString::create(); - ColumnString::Chars & vec_res = col_res->getChars(); - ColumnString::Offsets & offsets_res = col_res->getOffsets(); - Impl::vector(col->getChars(), col->getOffsets(), vec_res, offsets_res); - return col_res; - } - else - { - throw Exception( - "Illegal columns " + arguments[0].column->getName() + " of arguments of function " + getName(), - ErrorCodes::ILLEGAL_COLUMN); - } - } -}; - -} diff --git a/src/Functions/FunctionsTonalityClassification.cpp b/src/Functions/FunctionsTonalityClassification.cpp index c6baf1062f3..22c9bdc9137 100644 --- a/src/Functions/FunctionsTonalityClassification.cpp +++ b/src/Functions/FunctionsTonalityClassification.cpp @@ -1,163 +1,115 @@ -#include -#include #include +#include #include -#include -#include -#include +#include #include namespace DB { + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; +} + /** * Determines the sentiment of text data. - * Uses a marked-up sentiment dictionary, each word has a tonality ranging from -3 to 3. + * Uses a marked-up sentiment dictionary, each word has a tonality ranging from -12 to 6. 
* For each text, calculate the average sentiment value of its words and return NEG, POS or NEUT */ struct TonalityClassificationImpl { - - using ResultType = String; - - - static String get_tonality(const Float64 & tonality_level) + static Float32 detectTonality(const UInt8 * str, const size_t str_len, const FrequencyHolder::Map & emotional_dict) { - if (tonality_level < 0.15) { return "NEG"; } - if (tonality_level > 0.45) { return "POS"; } - return "NEUT"; - } - - static void constant(String data, String & res) - { - const auto & emotional_dict = FrequencyHolder::getInstance().getEmotionalDict(); - Float64 weight = 0; - Float64 count_words = 0; + UInt64 count_words = 0; - String answer; String word; /// Select all Russian words from the string - for (size_t i = 0; i < data.size();) + for (size_t ind = 0; ind < str_len;) { - /// Assume that all non-Ascii characters are Russian letters - if (!isASCII(data[i])) + /// Assume that all non-ASCII characters are Russian letters + if (!isASCII(str[ind])) { - word.push_back(data[i]); - ++i; + word.push_back(str[ind]); + ++ind; - while ((i < data.size()) && (!isASCII(data[i]))) + while ((ind < str_len) && (!isASCII(str[ind]))) { - word.push_back(data[i]); - ++i; + word.push_back(str[ind]); + ++ind; } /// Try to find a russian word in the tonality dictionary - auto it = emotional_dict.find(word); + const auto * it = emotional_dict.find(word); if (it != emotional_dict.end()) { count_words += 1; weight += it->getMapped(); } - word = ""; + word.clear(); } else { - ++i; + ++ind; } } - /// Calculate average value of tonality - Float64 total_tonality = weight / count_words; - res += get_tonality(total_tonality); + /// Calculate average value of tonality. + /// Convert values -12..6 to -1..1 + return std::max(weight / count_words / 6, -1.0); } + /// If the function will return constant value for FixedString data type. 
+ static constexpr auto is_fixed_to_constant = false; static void vector( const ColumnString::Chars & data, const ColumnString::Offsets & offsets, - ColumnString::Chars & res_data, - ColumnString::Offsets & res_offsets) + PaddedPODArray & res) { const auto & emotional_dict = FrequencyHolder::getInstance().getEmotionalDict(); - res_data.reserve(1024); - res_offsets.resize(offsets.size()); - + size_t size = offsets.size(); size_t prev_offset = 0; - size_t res_offset = 0; - - for (size_t i = 0; i < offsets.size(); ++i) + for (size_t i = 0; i < size; ++i) { - const char * haystack = reinterpret_cast(&data[prev_offset]); - String str = haystack; - - String buf; - - Float64 weight = 0; - Float64 count_words = 0; - - - String answer; - String word; - /// Select all Russian words from the string - for (size_t ind = 0; ind < str.size();) - { - if (!isASCII(str[ind])) - { - word.push_back(str[ind]); - ++ind; - - while ((ind < str.size()) && (!isASCII(str[ind]))) - { - word.push_back(str[ind]); - ++ind; - } - /// Try to find a russian word in the tonality dictionary - auto it = emotional_dict.find(word); - if (it != emotional_dict.end()) - { - count_words += 1; - weight += it->getMapped(); - } - word = ""; - } - else - { - ++ind; - } - } - /// Calculate average value of tonality - Float64 total_tonality = weight / count_words; - buf = get_tonality(total_tonality); - - const auto res = buf.c_str(); - size_t cur_offset = offsets[i]; - size_t ans_size = strlen(res); - res_data.resize(res_offset + ans_size + 1); - memcpy(&res_data[res_offset], res, ans_size); - res_offset += ans_size; - - res_data[res_offset] = 0; - ++res_offset; - - res_offsets[i] = res_offset; - prev_offset = cur_offset; + res[i] = detectTonality(data.data() + prev_offset, offsets[i] - 1 - prev_offset, emotional_dict); + prev_offset = offsets[i]; } } + static void vectorFixedToConstant(const ColumnString::Chars & /*data*/, size_t /*n*/, Float32 & /*res*/) {} + static void vectorFixedToVector(const ColumnString::Chars & data, size_t n, PaddedPODArray & res) + { + const auto & emotional_dict = FrequencyHolder::getInstance().getEmotionalDict(); + + size_t size = data.size() / n; + for (size_t i = 0; i < size; ++i) + res[i] = detectTonality(data.data() + i * n, n, emotional_dict); + } + + [[noreturn]] static void array(const ColumnString::Offsets &, PaddedPODArray &) + { + throw Exception("Cannot apply function detectTonality to Array argument", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + + [[noreturn]] static void uuid(const ColumnUUID::Container &, size_t &, PaddedPODArray &) + { + throw Exception("Cannot apply function detectTonality to UUID argument", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } }; -struct NameGetTonality +struct NameDetectTonality { static constexpr auto name = "detectTonality"; }; - -using FunctionGetTonality = FunctionsTextClassification; +using FunctionDetectTonality = FunctionStringOrArrayToT; void registerFunctionsTonalityClassification(FunctionFactory & factory) { - factory.registerFunction(); + factory.registerFunction(); } } diff --git a/src/Interpreters/CMakeLists.txt b/src/Interpreters/CMakeLists.txt index 4bc1c680ed3..3f3c58ef63d 100644 --- a/src/Interpreters/CMakeLists.txt +++ b/src/Interpreters/CMakeLists.txt @@ -2,9 +2,9 @@ if (ENABLE_TESTS) add_subdirectory(tests) endif() -# if (ENABLE_EXAMPLES) +if (ENABLE_EXAMPLES) add_subdirectory(examples) -# endif() +endif() if (ENABLE_FUZZING) add_subdirectory(fuzzers) diff --git a/tests/queries/0_stateless/02133_classification.reference 
b/tests/queries/0_stateless/02133_classification.reference
index 7332a2a423e..5607f6829f9 100644
--- a/tests/queries/0_stateless/02133_classification.reference
+++ b/tests/queries/0_stateless/02133_classification.reference
@@ -3,10 +3,10 @@ en
 fr
 ja
 zh
-{'ja':62,'fr':36,'un':0}
+{'ja':0.62,'fr':0.36,'un':0}
 ISO-8859-1
-English
-POS
-NEG
-POS
+en
+0.465
+-0.57647777
+0.050505556
 C++
diff --git a/tests/queries/0_stateless/02133_classification.sql b/tests/queries/0_stateless/02133_classification.sql
index acd33fe46ad..8c4a4f11360 100644
--- a/tests/queries/0_stateless/02133_classification.sql
+++ b/tests/queries/0_stateless/02133_classification.sql
@@ -11,8 +11,8 @@ SELECT detectLanguageMixed('二兎を追う者は一兎をも得ず二兎を追
 SELECT detectCharset('Plain English');
 SELECT detectLanguageUnknown('Plain English');
 
-SELECT detectTonality('Милая кошка');
-SELECT detectTonality('Злой человек');
-SELECT detectTonality('Обычная прогулка по ближайшему парку');
+SELECT detectTonality('милая кошка');
+SELECT detectTonality('ненависть к людям');
+SELECT detectTonality('обычная прогулка по ближайшему парку');
 
 SELECT detectProgrammingLanguage('#include ');
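
For quick reference, a minimal SQL usage sketch of the reworked functions, assuming the behaviour implied by this patch and its updated tests; the sample inputs are illustrative and the commented outputs are not exact, since actual scores depend on the embedded nlp-data dictionaries.

SELECT detectCharset('Plain English');           -- charset name, e.g. 'ISO-8859-1'
SELECT detectLanguageUnknown('Plain English');   -- two-letter ISO code such as 'en', or 'other' for unrecognized codes
SELECT detectLanguageMixed('A short sample');    -- now a Map(String, Float32) with shares of the top 3 languages, e.g. {'ja':0.62,'fr':0.36,'un':0}
SELECT detectTonality('милая кошка');            -- now a Float32 in [-1, 1]: average word tonality (dictionary range -12..6) divided by 6 and clamped at -1
SELECT detectProgrammingLanguage('#include ');   -- language name such as 'C++', or 'Undefined' when no n-gram matches the dictionary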