Merge pull request #33314 from evillique/classification

Merge functions for text classification
This commit is contained in:
Nikolay Degterinsky 2022-01-27 17:15:08 +03:00 committed by GitHub
commit c5ca5b608e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 1087 additions and 0 deletions

6
.gitmodules vendored
View File

@ -217,6 +217,9 @@
[submodule "contrib/yaml-cpp"]
path = contrib/yaml-cpp
url = https://github.com/ClickHouse-Extras/yaml-cpp.git
[submodule "contrib/cld2"]
path = contrib/cld2
url = https://github.com/ClickHouse-Extras/cld2.git
[submodule "contrib/libstemmer_c"]
path = contrib/libstemmer_c
url = https://github.com/ClickHouse-Extras/libstemmer_c.git
@ -247,6 +250,9 @@
[submodule "contrib/sysroot"]
path = contrib/sysroot
url = https://github.com/ClickHouse-Extras/sysroot.git
[submodule "contrib/nlp-data"]
path = contrib/nlp-data
url = https://github.com/ClickHouse-Extras/nlp-data.git
[submodule "contrib/hive-metastore"]
path = contrib/hive-metastore
url = https://github.com/ClickHouse-Extras/hive-metastore

View File

@ -140,6 +140,8 @@ if (ENABLE_NLP)
add_contrib (libstemmer-c-cmake libstemmer_c)
add_contrib (wordnet-blast-cmake wordnet-blast)
add_contrib (lemmagen-c-cmake lemmagen-c)
add_contrib (nlp-data-cmake nlp-data)
add_contrib (cld2-cmake cld2)
endif()
add_contrib (sqlite-cmake sqlite-amalgamation)

1
contrib/cld2 vendored Submodule

@ -0,0 +1 @@
Subproject commit bc6d493a2f64ed1fc1c4c4b4294a542a04e04217

View File

@ -0,0 +1,33 @@
# Build script for the vendored CLD2 language detector.
# Produces the library target `_cld2`, exported to the rest of the build as `ch_contrib::cld2`.
set (LIBRARY_DIR "${ClickHouse_SOURCE_DIR}/contrib/cld2")

# File names are relative to "${LIBRARY_DIR}/internal"; the order is the order they are passed to add_library.
set (CLD2_SOURCE_NAMES
    cldutil.cc
    compact_lang_det.cc
    cldutil_shared.cc
    compact_lang_det_hint_code.cc
    compact_lang_det_impl.cc
    debug.cc
    fixunicodevalue.cc
    generated_entities.cc
    generated_language.cc
    generated_ulscript.cc
    getonescriptspan.cc
    lang_script.cc
    offsetmap.cc
    scoreonescriptspan.cc
    tote.cc
    utf8statetable.cc
    cld_generated_cjk_uni_prop_80.cc
    cld2_generated_cjk_compatible.cc
    cld_generated_cjk_delta_bi_4.cc
    generated_distinct_bi_0.cc
    cld2_generated_quadchrome_2.cc
    cld2_generated_deltaoctachrome.cc
    cld2_generated_distinctoctachrome.cc
    cld_generated_score_quad_octa_2.cc
)

# Expand the names above into full paths.
set (SRCS)
foreach (source_name IN LISTS CLD2_SOURCE_NAMES)
    list (APPEND SRCS "${LIBRARY_DIR}/internal/${source_name}")
endforeach ()

add_library (_cld2 ${SRCS})
add_library (ch_contrib::cld2 ALIAS _cld2)

set_target_properties (_cld2 PROPERTIES POSITION_INDEPENDENT_CODE ON)

# Consumers include the public headers (e.g. compact_lang_det.h) as system headers.
target_include_directories (_cld2 SYSTEM BEFORE PUBLIC "${LIBRARY_DIR}/public")

# Silence two warning classes for the third-party sources only.
target_compile_options (_cld2 PRIVATE -Wno-reserved-id-macro -Wno-c++11-narrowing)

1
contrib/nlp-data vendored Submodule

@ -0,0 +1 @@
Subproject commit 5591f91f5e748cba8fb9ef81564176feae774853

View File

@ -0,0 +1,15 @@
# Build script for the NLP dictionaries used by the text classification functions.
# Defines the interface target `_nlp_data` (alias `ch_contrib::nlp_data`); linking against it
# pulls the `nlp_dictionaries` binary target into the consumer.

# Provides clickhouse_embed_binaries().
include(${ClickHouse_SOURCE_DIR}/cmake/embed_binary.cmake)
set(LIBRARY_DIR "${ClickHouse_SOURCE_DIR}/contrib/nlp-data")
# INTERFACE library: it has no sources of its own and only carries link requirements.
add_library (_nlp_data INTERFACE)
# Embed the three compressed dictionaries found in ${LIBRARY_DIR} into the `nlp_dictionaries` target.
# NOTE(review): the exact kind of target created is decided by clickhouse_embed_binaries(),
# which is not visible here; the whole-archive wrapping below suggests a static archive - confirm.
clickhouse_embed_binaries(
TARGET nlp_dictionaries
RESOURCE_DIR "${LIBRARY_DIR}"
RESOURCES charset.zst tonality_ru.zst programming.zst
)
# The link line below names the built file rather than the target,
# so the build-order dependency is stated explicitly.
add_dependencies(_nlp_data nlp_dictionaries)
# Wrap the file in ${WHOLE_ARCHIVE} / ${NO_WHOLE_ARCHIVE} (both defined outside this file).
# NOTE(review): presumably needed because the embedded resources are looked up by name at runtime
# and would otherwise be dropped by the linker - confirm against getResource().
target_link_libraries(_nlp_data INTERFACE "-Wl,${WHOLE_ARCHIVE} $<TARGET_FILE:nlp_dictionaries> -Wl,${NO_WHOLE_ARCHIVE}")
add_library(ch_contrib::nlp_data ALIAS _nlp_data)

View File

@ -506,6 +506,7 @@ if (ENABLE_NLP)
dbms_target_link_libraries (PUBLIC ch_contrib::stemmer)
dbms_target_link_libraries (PUBLIC ch_contrib::wnb)
dbms_target_link_libraries (PUBLIC ch_contrib::lemmagen)
dbms_target_link_libraries (PUBLIC ch_contrib::nlp_data)
endif()
if (TARGET ch_contrib::bzip2)
@ -558,3 +559,4 @@ if (ENABLE_TESTS)
add_check(unit_tests_dbms)
endif ()

View File

@ -0,0 +1,252 @@
#pragma once
#include <Common/Arena.h>
#include <Common/getResource.h>
#include <Common/HashTable/HashMap.h>
#include <Common/StringUtils/StringUtils.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <IO/readFloatText.h>
#include <IO/ZstdInflatingReadBuffer.h>
#include <base/StringRef.h>
#include <base/logger_useful.h>
#include <string_view>
#include <unordered_map>
namespace DB
{
namespace ErrorCodes
{
extern const int FILE_DOESNT_EXIST;
}
/// FrequencyHolder class is responsible for storing and loading dictionaries
/// needed for text classification functions:
///
/// 1. detectLanguageUnknown
/// 2. detectCharset
/// 3. detectTonality
/// 4. detectProgrammingLanguage
class FrequencyHolder
{
public:
struct Language
{
String name;
HashMap<StringRef, Float64> map;
};
struct Encoding
{
String name;
String lang;
HashMap<UInt16, Float64> map;
};
public:
using Map = HashMap<StringRef, Float64>;
using Container = std::vector<Language>;
using EncodingMap = HashMap<UInt16, Float64>;
using EncodingContainer = std::vector<Encoding>;
static FrequencyHolder & getInstance()
{
static FrequencyHolder instance;
return instance;
}
void loadEncodingsFrequency()
{
Poco::Logger * log = &Poco::Logger::get("EncodingsFrequency");
LOG_TRACE(log, "Loading embedded charset frequencies");
auto resource = getResource("charset.zst");
if (resource.empty())
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded charset frequencies");
String line;
UInt16 bigram;
Float64 frequency;
String charset_name;
auto buf = std::make_unique<ReadBufferFromMemory>(resource.data(), resource.size());
ZstdInflatingReadBuffer in(std::move(buf));
while (!in.eof())
{
readString(line, in);
in.ignore();
if (line.empty())
continue;
ReadBufferFromString buf_line(line);
// Start loading a new charset
if (line.starts_with("// "))
{
// Skip "// "
buf_line.ignore(3);
readString(charset_name, buf_line);
/* In our dictionary we have lines with form: <Language>_<Charset>
* If we need to find language of data, we return <Language>
* If we need to find charset of data, we return <Charset>.
*/
size_t sep = charset_name.find('_');
Encoding enc;
enc.lang = charset_name.substr(0, sep);
enc.name = charset_name.substr(sep + 1);
encodings_freq.push_back(std::move(enc));
}
else
{
readIntText(bigram, buf_line);
buf_line.ignore();
readFloatText(frequency, buf_line);
encodings_freq.back().map[bigram] = frequency;
}
}
LOG_TRACE(log, "Charset frequencies was added, charsets count: {}", encodings_freq.size());
}
void loadEmotionalDict()
{
Poco::Logger * log = &Poco::Logger::get("EmotionalDict");
LOG_TRACE(log, "Loading embedded emotional dictionary");
auto resource = getResource("tonality_ru.zst");
if (resource.empty())
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded emotional dictionary");
String line;
String word;
Float64 tonality;
size_t count = 0;
auto buf = std::make_unique<ReadBufferFromMemory>(resource.data(), resource.size());
ZstdInflatingReadBuffer in(std::move(buf));
while (!in.eof())
{
readString(line, in);
in.ignore();
if (line.empty())
continue;
ReadBufferFromString buf_line(line);
readStringUntilWhitespace(word, buf_line);
buf_line.ignore();
readFloatText(tonality, buf_line);
StringRef ref{string_pool.insert(word.data(), word.size()), word.size()};
emotional_dict[ref] = tonality;
++count;
}
LOG_TRACE(log, "Emotional dictionary was added. Word count: {}", std::to_string(count));
}
void loadProgrammingFrequency()
{
Poco::Logger * log = &Poco::Logger::get("ProgrammingFrequency");
LOG_TRACE(log, "Loading embedded programming languages frequencies loading");
auto resource = getResource("programming.zst");
if (resource.empty())
throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "There is no embedded programming languages frequencies");
String line;
String bigram;
Float64 frequency;
String programming_language;
auto buf = std::make_unique<ReadBufferFromMemory>(resource.data(), resource.size());
ZstdInflatingReadBuffer in(std::move(buf));
while (!in.eof())
{
readString(line, in);
in.ignore();
if (line.empty())
continue;
ReadBufferFromString buf_line(line);
// Start loading a new language
if (line.starts_with("// "))
{
// Skip "// "
buf_line.ignore(3);
readString(programming_language, buf_line);
Language lang;
lang.name = programming_language;
programming_freq.push_back(std::move(lang));
}
else
{
readStringUntilWhitespace(bigram, buf_line);
buf_line.ignore();
readFloatText(frequency, buf_line);
StringRef ref{string_pool.insert(bigram.data(), bigram.size()), bigram.size()};
programming_freq.back().map[ref] = frequency;
}
}
LOG_TRACE(log, "Programming languages frequencies was added");
}
const Map & getEmotionalDict()
{
std::lock_guard lock(mutex);
if (emotional_dict.empty())
loadEmotionalDict();
return emotional_dict;
}
const EncodingContainer & getEncodingsFrequency()
{
std::lock_guard lock(mutex);
if (encodings_freq.empty())
loadEncodingsFrequency();
return encodings_freq;
}
const Container & getProgrammingFrequency()
{
std::lock_guard lock(mutex);
if (programming_freq.empty())
loadProgrammingFrequency();
return programming_freq;
}
private:
Arena string_pool;
Map emotional_dict;
Container programming_freq;
EncodingContainer encodings_freq;
std::mutex mutex;
};
}

View File

@ -76,6 +76,10 @@ endif()
target_link_libraries(clickhouse_functions PRIVATE ch_contrib::lz4)
if (ENABLE_NLP)
target_link_libraries(clickhouse_functions PRIVATE ch_contrib::cld2)
endif()
if (TARGET ch_contrib::h3)
target_link_libraries (clickhouse_functions PRIVATE ch_contrib::h3)
endif()

View File

@ -0,0 +1,142 @@
#include <Common/FrequencyHolder.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsTextClassification.h>
#include <memory>
#include <unordered_map>
namespace DB
{
/* Determine language and charset of text data. For each text, we build the distribution of bigrams bytes.
* Then we use marked-up dictionaries with distributions of bigram bytes of various languages and charsets.
* Using a naive Bayesian classifier, find the most likely charset and language and return it
*/
template <bool detect_language>
struct CharsetClassificationImpl
{
    /* Smoothing for the zero-frequency problem of the Naive Bayes classifier:
     * a bigram that is absent from a reference dictionary is given the probability 1e-06,
     * which is the minimal value in our marked-up dictionary.
     */
    static constexpr Float64 zero_frequency = 1e-06;

    /// The initial score bound in vector() is derived from this length;
    /// for bigger strings the behaviour of the function is unspecified.
    static constexpr size_t max_string_size = 1u << 15;

    /// Log-likelihood of the bigram counts in `model` under the reference frequencies in `standard`.
    /// As soon as the partial sum drops below `max_result` it is returned as is,
    /// because such a candidate cannot become the best one anymore.
    static ALWAYS_INLINE inline Float64 naiveBayes(
        const FrequencyHolder::EncodingMap & standard,
        const HashMap<UInt16, UInt64> & model,
        Float64 max_result)
    {
        Float64 log_likelihood = 0;

        for (const auto & bigram : model)
        {
            /// Use the reference frequency when the bigram is known, the smoothing value otherwise.
            const auto * reference = standard.find(bigram.getKey());
            if (reference == standard.end())
                log_likelihood += bigram.getMapped() * log(zero_frequency);
            else
                log_likelihood += bigram.getMapped() * log(reference->getMapped());

            if (log_likelihood < max_result)
                return log_likelihood;
        }

        return log_likelihood;
    }

    /// Count how many times each value of a sliding two-byte window occurs in the text.
    /// The window starts as zero, so the first counted value is made of a zero byte and the first byte of the text.
    static ALWAYS_INLINE inline void calculateStats(
        const UInt8 * data,
        const size_t size,
        HashMap<UInt16, UInt64> & model)
    {
        UInt16 window = 0;
        const UInt8 * const end = data + size;

        for (const UInt8 * pos = data; pos != end; ++pos)
        {
            /// Shift the previous byte into the high half and put the current byte into the low half.
            window = static_cast<UInt16>((window << 8) + *pos);
            ++model[window];
        }
    }

    /// For every input string writes the best matching language (detect_language = true)
    /// or charset (detect_language = false) as a zero-terminated string.
    static void vector(
        const ColumnString::Chars & data,
        const ColumnString::Offsets & offsets,
        ColumnString::Chars & res_data,
        ColumnString::Offsets & res_offsets)
    {
        const auto & encodings_freq = FrequencyHolder::getInstance().getEncodingsFrequency();

        const size_t rows = offsets.size();

        /// Language: 2 chars for ISO code + 1 zero byte. Charset: mean length is 8.
        if constexpr (detect_language)
            res_data.reserve(rows * 3);
        else
            res_data.reserve(rows * 8);

        res_offsets.resize(rows);

        size_t written = 0;

        for (size_t i = 0; i < rows; ++i)
        {
            const UInt8 * str = data.data() + offsets[i - 1];
            const size_t str_len = offsets[i] - offsets[i - 1] - 1;

            HashMap<UInt16, UInt64> model;
            calculateStats(str, str_len, model);

            /// Walk over all reference dictionaries and keep the one with the highest score.
            /// The answer stays empty if no dictionary scores above the initial bound.
            std::string_view best;
            Float64 best_score = log(zero_frequency) * (max_string_size);

            for (const auto & candidate : encodings_freq)
            {
                const Float64 score = naiveBayes(candidate.map, model, best_score);
                if (best_score < score)
                {
                    best_score = score;
                    if constexpr (detect_language)
                        best = candidate.lang;
                    else
                        best = candidate.name;
                }
            }

            /// Append the answer followed by the terminating zero byte.
            const size_t next_written = written + best.size() + 1;
            res_data.resize(next_written);
            memcpy(&res_data[written], best.data(), best.size());
            res_data[next_written - 1] = 0;

            written = next_written;
            res_offsets[i] = written;
        }
    }
};
/// SQL name of the charset flavour of the classifier.
struct NameDetectCharset
{
    static constexpr auto name = "detectCharset";
};

/// SQL name of the language flavour of the classifier.
struct NameDetectLanguageUnknown
{
    static constexpr auto name = "detectLanguageUnknown";
};

/// detect_language = true reports the language part of a dictionary entry, false reports the charset part.
using FunctionDetectLanguageUnknown = FunctionTextClassificationString<CharsetClassificationImpl<true>, NameDetectLanguageUnknown>;
using FunctionDetectCharset = FunctionTextClassificationString<CharsetClassificationImpl<false>, NameDetectCharset>;

/// Registers detectCharset and then detectLanguageUnknown.
void registerFunctionDetectCharset(FunctionFactory & factory)
{
    factory.registerFunction<FunctionDetectCharset>();
    factory.registerFunction<FunctionDetectLanguageUnknown>();
}
}

View File

@ -0,0 +1,231 @@
#include "config_functions.h"
#if USE_NLP
#include <Columns/ColumnMap.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnsNumber.h>
#include <Common/isValidUTF8.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsTextClassification.h>
#include <Interpreters/Context.h>
#include <compact_lang_det.h>
namespace DB
{
/* Determine language of Unicode UTF-8 text.
* Uses the cld2 library https://github.com/CLD2Owners/cld2
*/
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int SUPPORT_IS_DISABLED;
}
/// Detection of the single most probable language of a text with the CLD2 library.
struct FunctionDetectLanguageImpl
{
    /// Converts a language code reported by CLD2 into the code returned to the user:
    ///  - a trailing script tag ("-Latn" or "-Hant") is dropped, e.g. "zh-Hant" becomes "zh";
    ///  - the deprecated codes iw, jw, in, mo are replaced by he, jv, id, ro;
    ///  - everything that is not exactly two characters long after that is reported as "other".
    /// The result refers either to the storage of the argument or to a string literal.
    static ALWAYS_INLINE inline std::string_view codeISO(std::string_view code_string)
    {
        constexpr std::string_view latin_script_tag = "-Latn";
        constexpr std::string_view traditional_han_script_tag = "-Hant";

        /// remove_suffix() takes the number of characters to drop from the end,
        /// so the length of the tag itself has to be passed.
        if (code_string.ends_with(latin_script_tag))
            code_string.remove_suffix(latin_script_tag.size());

        if (code_string.ends_with(traditional_han_script_tag))
            code_string.remove_suffix(traditional_han_script_tag.size());

        // Old deprecated codes
        if (code_string == "iw")
            return "he";

        if (code_string == "jw")
            return "jv";

        if (code_string == "in")
            return "id";

        if (code_string == "mo")
            return "ro";

        // Some languages do not have 2 letter codes, for example code for Cebuano is ceb
        if (code_string.size() != 2)
            return "other";

        return code_string;
    }

    /// For every input string writes a zero-terminated language code:
    /// the converted CLD2 answer for valid UTF-8 and "un" for anything else.
    static void vector(
        const ColumnString::Chars & data,
        const ColumnString::Offsets & offsets,
        ColumnString::Chars & res_data,
        ColumnString::Offsets & res_offsets)
    {
        /// Constant 3 is based on the fact that in general we need 2 characters for ISO code + 1 zero byte
        res_data.reserve(offsets.size() * 3);
        res_offsets.resize(offsets.size());

        /// Output argument required by CLD2, its value is not used.
        bool is_reliable;
        size_t res_offset = 0;

        for (size_t i = 0; i < offsets.size(); ++i)
        {
            const UInt8 * str = data.data() + offsets[i - 1];
            const size_t str_len = offsets[i] - offsets[i - 1] - 1;

            std::string_view res;

            if (UTF8::isValidUTF8(str, str_len))
            {
                auto lang = CLD2::DetectLanguage(reinterpret_cast<const char *>(str), str_len, true, &is_reliable);
                res = codeISO(LanguageCode(lang));
            }
            else
            {
                res = "un";
            }

            /// Append the code followed by the terminating zero byte.
            res_data.resize(res_offset + res.size() + 1);
            memcpy(&res_data[res_offset], res.data(), res.size());

            res_data[res_offset + res.size()] = 0;
            res_offset += res.size() + 1;

            res_offsets[i] = res_offset;
        }
    }
};
/// detectLanguageMixed(String) -> Map(String, Float32).
/// Reports up to top_N languages found in a text by CLD2 together with their share of the text.
class FunctionDetectLanguageMixed : public IFunction
{
public:
    static constexpr auto name = "detectLanguageMixed";

    /// Number of top results
    static constexpr auto top_N = 3;

    /// The function is created only when the `allow_experimental_nlp_functions` setting is on.
    static FunctionPtr create(ContextPtr context)
    {
        if (!context->getSettingsRef().allow_experimental_nlp_functions)
            throw Exception(ErrorCodes::SUPPORT_IS_DISABLED,
                "Natural language processing function '{}' is experimental. Set `allow_experimental_nlp_functions` setting to enable it", name);

        return std::make_shared<FunctionDetectLanguageMixed>();
    }

    String getName() const override { return name; }

    size_t getNumberOfArguments() const override { return 1; }

    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }

    bool useDefaultImplementationForConstants() const override { return true; }

    /// Accepts a single String argument, the result is Map(String, Float32).
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
    {
        if (!isString(arguments[0]))
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "Illegal type {} of argument of function {}. Must be String.",
                arguments[0]->getName(), getName());

        return std::make_shared<DataTypeMap>(std::make_shared<DataTypeString>(), std::make_shared<DataTypeFloat32>());
    }

    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
    {
        const ColumnString * strings = checkAndGetColumn<ColumnString>(arguments[0].column.get());
        if (!strings)
            throw Exception(
                "Illegal columns " + arguments[0].column->getName() + " of arguments of function " + getName(),
                ErrorCodes::ILLEGAL_COLUMN);

        const auto & input_data = strings->getChars();
        const auto & input_offsets = strings->getOffsets();

        /// The map is assembled from three flat columns: keys, values and array offsets.
        const auto & map_type = static_cast<const DataTypeMap &>(*result_type);

        MutableColumnPtr keys_data = map_type.getKeyType()->createColumn();
        MutableColumnPtr values_data = map_type.getValueType()->createColumn();
        MutableColumnPtr offsets = DataTypeNumber<IColumn::Offset>().createColumn();

        /// There are at most top_N entries per row.
        const size_t max_entries = input_rows_count * top_N;
        keys_data->reserve(max_entries);
        values_data->reserve(max_entries);
        offsets->reserve(input_rows_count);

        /// Total number of <language, share> pairs written so far, it is the array offset after each row.
        IColumn::Offset entries_written = 0;

        auto append_entry = [&](std::string_view code, Float32 share)
        {
            keys_data->insertData(code.data(), code.size());
            values_data->insertData(reinterpret_cast<const char *>(&share), sizeof(share));
            ++entries_written;
        };

        /// Output buffers of CLD2. Only the languages and the percents are used afterwards.
        CLD2::Language languages[top_N];
        int32_t percents[top_N];
        int text_bytes[top_N];
        bool is_reliable;

        for (size_t row = 0; row < input_rows_count; ++row)
        {
            const UInt8 * str = input_data.data() + input_offsets[row - 1];
            const size_t str_len = input_offsets[row] - input_offsets[row - 1] - 1;

            if (!UTF8::isValidUTF8(str, str_len))
            {
                /// Invalid UTF-8 is reported as the single entry 'un' with zero share.
                append_entry("un", 0);
            }
            else
            {
                CLD2::DetectLanguageSummary(reinterpret_cast<const char *>(str), str_len, true, languages, percents, text_bytes, &is_reliable);

                /// Entries are taken in the order returned by CLD2 until the first one with zero percent.
                for (size_t rank = 0; rank < top_N && percents[rank] != 0; ++rank)
                    append_entry(
                        FunctionDetectLanguageImpl::codeISO(LanguageCode(languages[rank])),
                        static_cast<Float32>(percents[rank]) / 100);
            }

            offsets->insert(entries_written);
        }

        auto nested_column = ColumnArray::create(
            ColumnTuple::create(Columns{std::move(keys_data), std::move(values_data)}),
            std::move(offsets));

        return ColumnMap::create(nested_column);
    }
};
/// SQL name of the single-language detector.
struct NameDetectLanguage
{
    static constexpr auto name = "detectLanguage";
};

/// detectLanguage returns one code per string, so it reuses the generic String -> String wrapper.
using FunctionDetectLanguage = FunctionTextClassificationString<FunctionDetectLanguageImpl, NameDetectLanguage>;

/// Registers detectLanguage and then detectLanguageMixed.
void registerFunctionsDetectLanguage(FunctionFactory & factory)
{
    factory.registerFunction<FunctionDetectLanguage>();
    factory.registerFunction<FunctionDetectLanguageMixed>();
}
}
#endif

View File

@ -0,0 +1,120 @@
#include <Common/FrequencyHolder.h>
#include <Common/StringUtils/StringUtils.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsTextClassification.h>
#include <unordered_map>
#include <string_view>
namespace DB
{
/**
* Determine the programming language from the source code.
* We calculate all the unigrams and bigrams of commands in the source code.
* Then using a marked-up dictionary with weights of unigrams and bigrams of commands for various programming languages
* Find the biggest weight of the programming language and return it
*/
struct FunctionDetectProgrammingLanguageImpl
{
    /// Total weight of a text for one language:
    /// the sum over all n-grams of the text of (occurrences in the text) * (weight in the dictionary).
    /// N-grams that are absent from the dictionary contribute nothing.
    static ALWAYS_INLINE inline Float64 stateMachine(
        const FrequencyHolder::Map & standard,
        const std::unordered_map<String, Float64> & model)
    {
        Float64 total = 0;

        for (const auto & [ngram, occurrences] : model)
        {
            const auto * weight = standard.find(ngram);
            if (weight == standard.end())
                continue;

            total += occurrences * weight->getMapped();
        }

        return total;
    }

    /// For every input string writes the name of the language with the highest positive weight
    /// as a zero-terminated string, or "Undefined" if there is no such language.
    static void vector(
        const ColumnString::Chars & data,
        const ColumnString::Offsets & offsets,
        ColumnString::Chars & res_data,
        ColumnString::Offsets & res_offsets)
    {
        const auto & programming_freq = FrequencyHolder::getInstance().getProgrammingFrequency();

        const size_t rows = offsets.size();

        /// Constant 5 is arbitrary
        res_data.reserve(rows * 5);
        res_offsets.resize(rows);

        size_t written = 0;

        for (size_t i = 0; i < rows; ++i)
        {
            const UInt8 * str = data.data() + offsets[i - 1];
            const size_t str_len = offsets[i] - offsets[i - 1] - 1;

            /// Occurrences of every command and of every pair of adjacent commands.
            std::unordered_map<String, Float64> data_freq;
            StringRef previous;

            /// Commands are the maximal runs of non-whitespace characters.
            size_t pos = 0;
            while (pos < str_len)
            {
                if (isWhitespaceASCII(str[pos]))
                {
                    ++pos;
                    continue;
                }

                const size_t begin = pos;
                while (pos < str_len && !isWhitespaceASCII(str[pos]))
                    ++pos;

                const StringRef current(str + begin, pos - begin);

                /// A bigram is the previous and the current command glued together without a separator.
                if (previous.data)
                    data_freq[previous.toString() + current.toString()] += 1;

                data_freq[current.toString()] += 1;

                previous = current;
            }

            /// Choose the language with the highest weight; only positive weights count.
            std::string_view best;
            Float64 best_weight = 0;

            for (const auto & language : programming_freq)
            {
                const Float64 weight = stateMachine(language.map, data_freq);
                if (weight > best_weight)
                {
                    best_weight = weight;
                    best = language.name;
                }
            }

            /// If all weights are zero, then we assume that the language is undefined
            if (best.empty())
                best = "Undefined";

            /// Append the answer followed by the terminating zero byte.
            const size_t next_written = written + best.size() + 1;
            res_data.resize(next_written);
            memcpy(&res_data[written], best.data(), best.size());
            res_data[next_written - 1] = 0;

            written = next_written;
            res_offsets[i] = written;
        }
    }
};
/// SQL name of the programming language detector.
struct NameDetectProgrammingLanguage
{
    static constexpr auto name = "detectProgrammingLanguage";
};

/// The detector returns one name per string, so it reuses the generic String -> String wrapper.
using FunctionDetectProgrammingLanguage
    = FunctionTextClassificationString<FunctionDetectProgrammingLanguageImpl, NameDetectProgrammingLanguage>;

/// Registers detectProgrammingLanguage.
void registerFunctionDetectProgrammingLanguage(FunctionFactory & factory)
{
    factory.registerFunction<FunctionDetectProgrammingLanguage>();
}
}

View File

@ -0,0 +1,122 @@
#pragma once
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Interpreters/Context_fwd.h>
#include <Functions/FunctionFactory.h>
#include <Interpreters/Context.h>
namespace DB
{
/// Functions for text classification with different result types
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int SUPPORT_IS_DISABLED;
}
/// Generic String -> String text classification function.
/// `Impl::vector` receives the chars and offsets of the input column and fills the chars and offsets of the result,
/// `Name::name` is the name of the function.
template <typename Impl, typename Name>
class FunctionTextClassificationString : public IFunction
{
public:
    static constexpr auto name = Name::name;

    /// The function is created only when the `allow_experimental_nlp_functions` setting is on.
    static FunctionPtr create(ContextPtr context)
    {
        if (!context->getSettingsRef().allow_experimental_nlp_functions)
        {
            throw Exception(ErrorCodes::SUPPORT_IS_DISABLED,
                "Natural language processing function '{}' is experimental. Set `allow_experimental_nlp_functions` setting to enable it", name);
        }

        return std::make_shared<FunctionTextClassificationString>();
    }

    String getName() const override { return name; }

    size_t getNumberOfArguments() const override { return 1; }

    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }

    bool useDefaultImplementationForConstants() const override { return true; }

    /// The only argument must be a String; the result has the type of the argument.
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
    {
        const auto & argument_type = arguments[0];

        if (!isString(argument_type))
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "Illegal type {} of argument of function {}. Must be String.",
                argument_type->getName(), getName());

        return argument_type;
    }

    /// Runs `Impl::vector` over the whole input column and returns the freshly built String column.
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t /*input_rows_count*/) const override
    {
        const ColumnString * source = checkAndGetColumn<ColumnString>(arguments[0].column.get());

        if (!source)
            throw Exception(
                "Illegal column " + arguments[0].column->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);

        auto result = ColumnString::create();
        Impl::vector(source->getChars(), source->getOffsets(), result->getChars(), result->getOffsets());
        return result;
    }
};
/// Generic String -> Float32 text classification function.
/// `Impl::vector` receives the chars and offsets of the input column and a result array
/// that already has one element per row; `Name::name` is the name of the function.
template <typename Impl, typename Name>
class FunctionTextClassificationFloat : public IFunction
{
public:
    static constexpr auto name = Name::name;

    /// The function is created only when the `allow_experimental_nlp_functions` setting is on.
    static FunctionPtr create(ContextPtr context)
    {
        if (!context->getSettingsRef().allow_experimental_nlp_functions)
        {
            throw Exception(ErrorCodes::SUPPORT_IS_DISABLED,
                "Natural language processing function '{}' is experimental. Set `allow_experimental_nlp_functions` setting to enable it", name);
        }

        return std::make_shared<FunctionTextClassificationFloat>();
    }

    String getName() const override { return name; }

    size_t getNumberOfArguments() const override { return 1; }

    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }

    bool useDefaultImplementationForConstants() const override { return true; }

    /// The only argument must be a String; the result is Float32.
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
    {
        const auto & argument_type = arguments[0];

        if (!isString(argument_type))
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "Illegal type {} of argument of function {}. Must be String.",
                argument_type->getName(), getName());

        return std::make_shared<DataTypeFloat32>();
    }

    /// Allocates one Float32 per input row and lets `Impl::vector` fill the values.
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t /*input_rows_count*/) const override
    {
        const ColumnString * source = checkAndGetColumn<ColumnString>(arguments[0].column.get());

        if (!source)
            throw Exception(
                "Illegal column " + arguments[0].column->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);

        auto result = ColumnVector<Float32>::create();

        ColumnVector<Float32>::Container & values = result->getData();
        values.resize(source->size());

        Impl::vector(source->getChars(), source->getOffsets(), values);
        return result;
    }
};
}

View File

@ -0,0 +1,89 @@
#include <Common/FrequencyHolder.h>
#include <Common/StringUtils/StringUtils.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsTextClassification.h>
#include <unordered_map>
namespace DB
{
/**
* Determines the sentiment of text data.
* Uses a marked-up sentiment dictionary, each word has a tonality ranging from -12 to 6.
* For each text, calculate the average sentiment value of its words and return it in range [-1,1]
*/
struct FunctionDetectTonalityImpl
{
    /// Average tonality of the words of a text that are present in `emotional_dict`.
    /// A positive average is divided by 6 and any other average by 12.
    /// Returns 0 when no word of the text is present in the dictionary.
    static ALWAYS_INLINE inline Float32 detectTonality(
        const UInt8 * str,
        const size_t str_len,
        const FrequencyHolder::Map & emotional_dict)
    {
        /// Words are the maximal runs of bytes that are neither ASCII whitespace nor ASCII punctuation.
        auto is_separator = [](char c) { return isWhitespaceASCII(c) || isPunctuationASCII(c); };

        Float64 weight = 0;
        UInt64 count_words = 0;
        String word;

        size_t pos = 0;
        while (pos < str_len)
        {
            if (is_separator(str[pos]))
            {
                ++pos;
                continue;
            }

            const size_t begin = pos;
            while (pos < str_len && !is_separator(str[pos]))
                ++pos;

            word.assign(reinterpret_cast<const char *>(str + begin), pos - begin);

            /// Only the words known to the tonality dictionary take part in the average.
            const auto * entry = emotional_dict.find(word);
            if (entry != emotional_dict.end())
            {
                ++count_words;
                weight += entry->getMapped();
            }
        }

        if (count_words == 0)
            return 0;

        /// Convert values -12..6 to -1..1
        const Float64 average = weight / count_words;
        if (weight > 0)
            return average / 6;

        return average / 12;
    }

    /// Writes the tonality of every input string into `res`, which must already have one element per string.
    static void vector(
        const ColumnString::Chars & data,
        const ColumnString::Offsets & offsets,
        PaddedPODArray<Float32> & res)
    {
        const auto & emotional_dict = FrequencyHolder::getInstance().getEmotionalDict();

        const size_t rows = offsets.size();
        size_t row_begin = 0;

        for (size_t row = 0; row < rows; ++row)
        {
            const size_t row_end = offsets[row];

            /// The last byte of every string is the terminating zero, it is not a part of the text.
            res[row] = detectTonality(data.data() + row_begin, row_end - 1 - row_begin, emotional_dict);
            row_begin = row_end;
        }
    }
};
/// SQL name of the tonality detector.
struct NameDetectTonality
{
    static constexpr auto name = "detectTonality";
};

/// The detector returns one number per string, so it uses the generic String -> Float32 wrapper.
using FunctionDetectTonality = FunctionTextClassificationFloat<FunctionDetectTonalityImpl, NameDetectTonality>;

/// Registers detectTonality.
void registerFunctionDetectTonality(FunctionFactory & factory)
{
    factory.registerFunction<FunctionDetectTonality>();
}
}

View File

@ -8,4 +8,5 @@
#cmakedefine01 USE_H3
#cmakedefine01 USE_S2_GEOMETRY
#cmakedefine01 USE_FASTOPS
#cmakedefine01 USE_NLP
#cmakedefine01 USE_HYPERSCAN

View File

@ -39,6 +39,9 @@ void registerFunctionEncodeXMLComponent(FunctionFactory &);
void registerFunctionDecodeXMLComponent(FunctionFactory &);
void registerFunctionExtractTextFromHTML(FunctionFactory &);
void registerFunctionToStringCutToZero(FunctionFactory &);
void registerFunctionDetectCharset(FunctionFactory &);
void registerFunctionDetectTonality(FunctionFactory &);
void registerFunctionDetectProgrammingLanguage(FunctionFactory &);
#if USE_BASE64
void registerFunctionBase64Encode(FunctionFactory &);
@ -50,6 +53,7 @@ void registerFunctionTryBase64Decode(FunctionFactory &);
void registerFunctionStem(FunctionFactory &);
void registerFunctionSynonyms(FunctionFactory &);
void registerFunctionLemmatize(FunctionFactory &);
void registerFunctionsDetectLanguage(FunctionFactory &);
#endif
#if USE_ICU
@ -91,6 +95,9 @@ void registerFunctionsString(FunctionFactory & factory)
registerFunctionDecodeXMLComponent(factory);
registerFunctionExtractTextFromHTML(factory);
registerFunctionToStringCutToZero(factory);
registerFunctionDetectCharset(factory);
registerFunctionDetectTonality(factory);
registerFunctionDetectProgrammingLanguage(factory);
#if USE_BASE64
registerFunctionBase64Encode(factory);
@ -102,6 +109,7 @@ void registerFunctionsString(FunctionFactory & factory)
registerFunctionStem(factory);
registerFunctionSynonyms(factory);
registerFunctionLemmatize(factory);
registerFunctionsDetectLanguage(factory);
#endif
#if USE_ICU

View File

@ -0,0 +1,20 @@
<test>
<!-- Performance test of the text classification functions.
One query per function, each over the SearchPhrase column of hits_100m_single with the output discarded. -->
<settings>
<!-- The functions refuse to be created unless this setting is enabled. -->
<allow_experimental_nlp_functions>1</allow_experimental_nlp_functions>
</settings>
<preconditions>
<table_exists>hits_100m_single</table_exists>
</preconditions>
<query>SELECT detectLanguage(SearchPhrase) FROM hits_100m_single FORMAT Null</query>
<query>SELECT detectLanguageMixed(SearchPhrase) FROM hits_100m_single FORMAT Null</query>
<query>SELECT detectTonality(SearchPhrase) FROM hits_100m_single FORMAT Null</query>
<!-- Input is not really correct for these functions,
but at least it gives us some idea about their performance -->
<query>SELECT detectProgrammingLanguage(SearchPhrase) FROM hits_100m_single FORMAT Null</query>
<query>SELECT detectLanguageUnknown(SearchPhrase) FROM hits_100m_single FORMAT Null</query>
<query>SELECT detectCharset(SearchPhrase) FROM hits_100m_single FORMAT Null</query>
</test>

View File

@ -0,0 +1,15 @@
ru
en
fr
ja
zh
un
{'ja':0.62,'fr':0.36}
{'ko':0.98}
{}
ISO-8859-1
en
0.465
-0.28823888
0.050505556
C++

View File

@ -0,0 +1,23 @@
-- Tags: no-fasttest
-- Tag no-fasttest: depends on cld2 and nlp-data
-- The functions below cannot be created without this setting.
SET allow_experimental_nlp_functions = 1;
-- detectLanguage: one query per language (Russian, English, French, Japanese, Chinese).
SELECT detectLanguage('Они сошлись. Волна и камень, Стихи и проза, лед и пламень, Не столь различны меж собой.');
SELECT detectLanguage('Sweet are the uses of adversity which, like the toad, ugly and venomous, wears yet a precious jewel in his head.');
SELECT detectLanguage('A vaincre sans peril, on triomphe sans gloire.');
SELECT detectLanguage('二兎を追う者は一兎をも得ず');
SELECT detectLanguage('有情饮水饱,无情食饭饥。');
-- detectLanguage: input that consists of punctuation only.
SELECT detectLanguage('*****///// _____ ,,,,,,,, .....');
-- detectLanguageMixed: a text in two languages, a text in one language, punctuation only.
SELECT detectLanguageMixed('二兎を追う者は一兎をも得ず二兎を追う者は一兎をも得ず A vaincre sans peril, on triomphe sans gloire.');
SELECT detectLanguageMixed('어디든 가치가 있는 곳으로 가려면 지름길은 없다');
SELECT detectLanguageMixed('*****///// _____ ,,,,,,,, .....');
-- Bigram based detection of the charset and of the language of the same plain ASCII text.
SELECT detectCharset('Plain English');
SELECT detectLanguageUnknown('Plain English');
-- detectTonality on three Russian phrases.
SELECT detectTonality('милая кошка');
SELECT detectTonality('ненависть к людям');
SELECT detectTonality('обычная прогулка по ближайшему парку');
-- detectProgrammingLanguage on a single C++ preprocessor line.
SELECT detectProgrammingLanguage('#include <iostream>');