diff --git a/dbms/CMakeLists.txt b/dbms/CMakeLists.txt index e7cc084237f..63a454d0ea6 100644 --- a/dbms/CMakeLists.txt +++ b/dbms/CMakeLists.txt @@ -410,6 +410,6 @@ if (ENABLE_TESTS AND USE_GTEST) # gtest framework has substandard code target_compile_options(unit_tests_dbms PRIVATE -Wno-zero-as-null-pointer-constant -Wno-undef -Wno-sign-compare -Wno-used-but-marked-unused -Wno-missing-noreturn) - target_link_libraries(unit_tests_dbms PRIVATE ${GTEST_BOTH_LIBRARIES} clickhouse_functions clickhouse_parsers dbms clickhouse_common_zookeeper) + target_link_libraries(unit_tests_dbms PRIVATE ${GTEST_BOTH_LIBRARIES} clickhouse_functions clickhouse_parsers dbms clickhouse_common_zookeeper string_utils) add_check(unit_tests_dbms) endif () diff --git a/dbms/programs/client/Client.cpp b/dbms/programs/client/Client.cpp index 0bb6cf62f90..df5e8568d21 100644 --- a/dbms/programs/client/Client.cpp +++ b/dbms/programs/client/Client.cpp @@ -608,6 +608,7 @@ private: if (!ends_with_backslash && (ends_with_semicolon || has_vertical_output_suffix || (!config().has("multiline") && !hasDataInSTDIN()))) { + // TODO: should we do sensitive data masking on client too? History file can be source of secret leaks. if (input != prev_input) { /// Replace line breaks with spaces to prevent the following problem. @@ -1027,13 +1028,17 @@ private: while (true) { Block block = async_block_input->read(); - connection->sendData(block); - processed_rows += block.rows(); /// Check if server send Log packet + receiveLogs(); + + /// Check if server send Exception packet auto packet_type = connection->checkPacket(); - if (packet_type && *packet_type == Protocol::Server::Log) - receiveAndProcessPacket(); + if (packet_type && *packet_type == Protocol::Server::Exception) + return; + + connection->sendData(block); + processed_rows += block.rows(); if (!block) break; @@ -1250,6 +1255,17 @@ private: } } + /// Process Log packets, used when inserting data by blocks + void receiveLogs() + { + auto packet_type = connection->checkPacket(); + + while (packet_type && *packet_type == Protocol::Server::Log) + { + receiveAndProcessPacket(); + packet_type = connection->checkPacket(); + } + } void initBlockOutputStream(const Block & block) { diff --git a/dbms/programs/local/LocalServer.cpp b/dbms/programs/local/LocalServer.cpp index 1844c037784..f4eac1baec2 100644 --- a/dbms/programs/local/LocalServer.cpp +++ b/dbms/programs/local/LocalServer.cpp @@ -74,6 +74,7 @@ void LocalServer::initialize(Poco::Util::Application & self) if (config().has("logger") || config().has("logger.level") || config().has("logger.log")) { + // sensitive data rules are not used here buildLoggers(config(), logger()); } else diff --git a/dbms/programs/odbc-bridge/ODBCBridge.cpp b/dbms/programs/odbc-bridge/ODBCBridge.cpp index cf265eb6abb..214d9f75328 100644 --- a/dbms/programs/odbc-bridge/ODBCBridge.cpp +++ b/dbms/programs/odbc-bridge/ODBCBridge.cpp @@ -124,6 +124,7 @@ void ODBCBridge::initialize(Application & self) config().setString("logger", "ODBCBridge"); buildLoggers(config(), logger()); + log = &logger(); hostname = config().getString("listen-host", "localhost"); port = config().getUInt("http-port"); @@ -162,6 +163,12 @@ int ODBCBridge::main(const std::vector & /*args*/) context = std::make_shared(Context::createGlobal()); context->makeGlobalContext(); + if (config().has("query_masking_rules")) + { + context->setSensitiveDataMasker(std::make_unique(config(), "query_masking_rules")); + setLoggerSensitiveDataMasker(logger(), context->getSensitiveDataMasker()); + } + auto server = Poco::Net::HTTPServer( new HandlerFactory("ODBCRequestHandlerFactory-factory", keep_alive_timeout, context), server_pool, socket, http_params); server.start(); diff --git a/dbms/programs/server/Server.cpp b/dbms/programs/server/Server.cpp index f10dc07ab56..82f50f26569 100644 --- a/dbms/programs/server/Server.cpp +++ b/dbms/programs/server/Server.cpp @@ -278,7 +278,11 @@ int Server::main(const std::vector & /*args*/) * table engines could use Context on destroy. */ LOG_INFO(log, "Shutting down storages."); + + // global_context is the owner of sensitive_data_masker, which will be destoyed after global_context->shutdown() call + setLoggerSensitiveDataMasker(logger(), nullptr); global_context->shutdown(); + LOG_DEBUG(log, "Shutted down storages."); /** Explicitly destroy Context. It is more convenient than in destructor of Server, because logger is still available. @@ -407,6 +411,12 @@ int Server::main(const std::vector & /*args*/) /// Initialize main config reloader. std::string include_from_path = config().getString("include_from", "/etc/metrika.xml"); + + if (config().has("query_masking_rules")) + { + global_context->setSensitiveDataMasker(std::make_unique(config(), "query_masking_rules")); + } + auto main_config_reloader = std::make_unique(config_path, include_from_path, config().getString("path", ""), @@ -416,6 +426,10 @@ int Server::main(const std::vector & /*args*/) { setTextLog(global_context->getTextLog()); buildLoggers(*config, logger()); + if (auto masker = global_context->getSensitiveDataMasker()) + { + setLoggerSensitiveDataMasker(logger(), masker); + } global_context->setClustersConfig(config); global_context->setMacros(std::make_unique(*config, "macros")); }, diff --git a/dbms/programs/server/TCPHandler.cpp b/dbms/programs/server/TCPHandler.cpp index 77b359f2763..e9cd7f04c8a 100644 --- a/dbms/programs/server/TCPHandler.cpp +++ b/dbms/programs/server/TCPHandler.cpp @@ -633,6 +633,13 @@ void TCPHandler::processTablesStatusRequest() response.write(*out, client_revision); } +void TCPHandler::receiveUnexpectedTablesStatusRequest() +{ + TablesStatusRequest skip_request; + skip_request.read(*in, client_revision); + + throw NetException("Unexpected packet TablesStatusRequest received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); +} void TCPHandler::sendProfileInfo(const BlockStreamProfileInfo & info) { @@ -722,6 +729,23 @@ void TCPHandler::receiveHello() } +void TCPHandler::receiveUnexpectedHello() +{ + UInt64 skip_uint_64; + String skip_string; + + readStringBinary(skip_string, *in); + readVarUInt(skip_uint_64, *in); + readVarUInt(skip_uint_64, *in); + readVarUInt(skip_uint_64, *in); + readStringBinary(skip_string, *in); + readStringBinary(skip_string, *in); + readStringBinary(skip_string, *in); + + throw NetException("Unexpected packet Hello received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); +} + + void TCPHandler::sendHello() { writeVarUInt(Protocol::Server::Hello, *out); @@ -744,19 +768,19 @@ bool TCPHandler::receivePacket() UInt64 packet_type = 0; readVarUInt(packet_type, *in); -// std::cerr << "Packet: " << packet_type << std::endl; +// std::cerr << "Server got packet: " << Protocol::Client::toString(packet_type) << "\n"; switch (packet_type) { case Protocol::Client::Query: if (!state.empty()) - throw NetException("Unexpected packet Query received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); + receiveUnexpectedQuery(); receiveQuery(); return true; case Protocol::Client::Data: if (state.empty()) - throw NetException("Unexpected packet Data received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); + receiveUnexpectedData(); return receiveData(); case Protocol::Client::Ping: @@ -768,12 +792,11 @@ bool TCPHandler::receivePacket() return false; case Protocol::Client::Hello: - throw Exception("Unexpected packet " + String(Protocol::Client::toString(packet_type)) + " received from client", - ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); + receiveUnexpectedHello(); case Protocol::Client::TablesStatusRequest: if (!state.empty()) - throw NetException("Unexpected packet TablesStatusRequest received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); + receiveUnexpectedTablesStatusRequest(); processTablesStatusRequest(); out->next(); return false; @@ -842,6 +865,26 @@ void TCPHandler::receiveQuery() readStringBinary(state.query, *in); } +void TCPHandler::receiveUnexpectedQuery() +{ + UInt64 skip_uint_64; + String skip_string; + + readStringBinary(skip_string, *in); + + ClientInfo & skip_client_info = query_context->getClientInfo(); + if (client_revision >= DBMS_MIN_REVISION_WITH_CLIENT_INFO) + skip_client_info.read(*in, client_revision); + + Settings & skip_settings = query_context->getSettingsRef(); + skip_settings.deserialize(*in); + + readVarUInt(skip_uint_64, *in); + readVarUInt(skip_uint_64, *in); + readStringBinary(skip_string, *in); + + throw NetException("Unexpected packet Query received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); +} bool TCPHandler::receiveData() { @@ -880,6 +923,27 @@ bool TCPHandler::receiveData() return false; } +void TCPHandler::receiveUnexpectedData() +{ + String skip_external_table_name; + readStringBinary(skip_external_table_name, *in); + + std::shared_ptr maybe_compressed_in; + + if (last_block_in.compression == Protocol::Compression::Enable) + maybe_compressed_in = std::make_shared(*in); + else + maybe_compressed_in = in; + + auto skip_block_in = std::make_shared( + *maybe_compressed_in, + last_block_in.header, + client_revision, + !connection_context.getSettingsRef().low_cardinality_allow_in_native_format); + + Block skip_block = skip_block_in->read(); + throw NetException("Unexpected packet Data received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); +} void TCPHandler::initBlockInput() { @@ -894,6 +958,9 @@ void TCPHandler::initBlockInput() if (state.io.out) header = state.io.out->getHeader(); + last_block_in.header = header; + last_block_in.compression = state.compression; + state.block_in = std::make_shared( *state.maybe_compressed_in, header, diff --git a/dbms/programs/server/TCPHandler.h b/dbms/programs/server/TCPHandler.h index 3cacd5fae95..fca75dfd832 100644 --- a/dbms/programs/server/TCPHandler.h +++ b/dbms/programs/server/TCPHandler.h @@ -82,6 +82,13 @@ struct QueryState }; +struct LastBlockInputParameters +{ + Protocol::Compression compression = Protocol::Compression::Disable; + Block header; +}; + + class TCPHandler : public Poco::Net::TCPServerConnection { public: @@ -126,6 +133,9 @@ private: /// At the moment, only one ongoing query in the connection is supported at a time. QueryState state; + /// Last block input parameters are saved to be able to receive unexpected data packet sent after exception. + LastBlockInputParameters last_block_in; + CurrentMetrics::Increment metric_increment{CurrentMetrics::TCPConnection}; /// It is the name of the server that will be sent to the client. @@ -139,6 +149,11 @@ private: bool receiveData(); void readData(const Settings & global_settings); + [[noreturn]] void receiveUnexpectedData(); + [[noreturn]] void receiveUnexpectedQuery(); + [[noreturn]] void receiveUnexpectedHello(); + [[noreturn]] void receiveUnexpectedTablesStatusRequest(); + /// Process INSERT query void processInsertQuery(const Settings & global_settings); diff --git a/dbms/programs/server/config.d/query_masking_rules.xml b/dbms/programs/server/config.d/query_masking_rules.xml new file mode 100644 index 00000000000..f919523472c --- /dev/null +++ b/dbms/programs/server/config.d/query_masking_rules.xml @@ -0,0 +1,19 @@ + + + + + + + profanity + (?i:shit) + substance + + + + TOPSECRET.TOPSECRET + [hidden] + + + diff --git a/dbms/programs/server/config.xml b/dbms/programs/server/config.xml index 814b7dded3c..d8fcd9b0c9e 100644 --- a/dbms/programs/server/config.xml +++ b/dbms/programs/server/config.xml @@ -439,6 +439,20 @@ --> /var/lib/clickhouse/format_schemas/ + + + diff --git a/dbms/src/Client/Connection.cpp b/dbms/src/Client/Connection.cpp index 9cdda9fdf0d..a6e533d8dd2 100644 --- a/dbms/src/Client/Connection.cpp +++ b/dbms/src/Client/Connection.cpp @@ -589,7 +589,7 @@ Connection::Packet Connection::receivePacket() } //LOG_TRACE(log_wrapper.get(), "Receiving packet " << res.type << " " << Protocol::Server::toString(res.type)); - + //std::cerr << "Client got packet: " << Protocol::Server::toString(res.type) << "\n"; switch (res.type) { case Protocol::Server::Data: [[fallthrough]]; diff --git a/dbms/src/Common/Allocator.h b/dbms/src/Common/Allocator.h index dc1d6ff5df9..ad5b0318c91 100644 --- a/dbms/src/Common/Allocator.h +++ b/dbms/src/Common/Allocator.h @@ -3,8 +3,6 @@ #include #ifdef NDEBUG - /// If set to 1 - randomize memory mappings manually (address space layout randomization) to reproduce more memory stomping bugs. - /// Note that Linux doesn't do it by default. This may lead to worse TLB performance. #define ALLOCATOR_ASLR 0 #else #define ALLOCATOR_ASLR 1 @@ -38,23 +36,27 @@ #define MAP_ANONYMOUS MAP_ANON #endif - -/** Many modern allocators (for example, tcmalloc) do not do a mremap for realloc, - * even in case of large enough chunks of memory. - * Although this allows you to increase performance and reduce memory consumption during realloc. +/** + * Many modern allocators (for example, tcmalloc) do not do a mremap for + * realloc, even in case of large enough chunks of memory. Although this allows + * you to increase performance and reduce memory consumption during realloc. * To fix this, we do mremap manually if the chunk of memory is large enough. - * The threshold (64 MB) is chosen quite large, since changing the address space is - * very slow, especially in the case of a large number of threads. - * We expect that the set of operations mmap/something to do/mremap can only be performed about 1000 times per second. + * The threshold (64 MB) is chosen quite large, since changing the address + * space is very slow, especially in the case of a large number of threads. We + * expect that the set of operations mmap/something to do/mremap can only be + * performed about 1000 times per second. * - * PS. This is also required, because tcmalloc can not allocate a chunk of memory greater than 16 GB. + * P.S. This is also required, because tcmalloc can not allocate a chunk of + * memory greater than 16 GB. */ #ifdef NDEBUG static constexpr size_t MMAP_THRESHOLD = 64 * (1ULL << 20); #else - /// In debug build, use small mmap threshold to reproduce more memory stomping bugs. - /// Along with ASLR it will hopefully detect more issues than ASan. - /// The program may fail due to the limit on number of memory mappings. + /** + * In debug build, use small mmap threshold to reproduce more memory + * stomping bugs. Along with ASLR it will hopefully detect more issues than + * ASan. The program may fail due to the limit on number of memory mappings. + */ static constexpr size_t MMAP_THRESHOLD = 4096; #endif @@ -72,25 +74,6 @@ namespace ErrorCodes } } -namespace AllocatorHints -{ -struct DefaultHint -{ - void * mmap_hint() - { - return nullptr; - } -}; - -struct RandomHint -{ - void * mmap_hint() - { - return reinterpret_cast(std::uniform_int_distribution(0x100000000000UL, 0x700000000000UL)(thread_local_rng)); - } -}; -} - /** Responsible for allocating / freeing memory. Used, for example, in PODArray, Arena. * Also used in hash tables. * The interface is different from std::allocator @@ -98,16 +81,12 @@ struct RandomHint * - passing the size into the `free` method; * - by the presence of the `alignment` argument; * - the possibility of zeroing memory (used in hash tables); - * - hint class for mmap + * - random hint address for mmap * - mmap_threshold for using mmap less or more */ -template -class AllocatorWithHint : Hint +template +class Allocator { -protected: - static constexpr bool clear_memory = clear_memory_; - static constexpr size_t small_memory_threshold = mmap_threshold; - public: /// Allocate memory range. void * alloc(size_t size, size_t alignment = 0) @@ -134,7 +113,8 @@ public: /// nothing to do. /// BTW, it's not possible to change alignment while doing realloc. } - else if (old_size < mmap_threshold && new_size < mmap_threshold && alignment <= MALLOC_MIN_ALIGNMENT) + else if (old_size < MMAP_THRESHOLD && new_size < MMAP_THRESHOLD + && alignment <= MALLOC_MIN_ALIGNMENT) { /// Resize malloc'd memory region with no special alignment requirement. CurrentMemoryTracker::realloc(old_size, new_size); @@ -148,19 +128,20 @@ public: if (new_size > old_size) memset(reinterpret_cast(buf) + old_size, 0, new_size - old_size); } - else if (old_size >= mmap_threshold && new_size >= mmap_threshold) + else if (old_size >= MMAP_THRESHOLD && new_size >= MMAP_THRESHOLD) { /// Resize mmap'd memory region. CurrentMemoryTracker::realloc(old_size, new_size); // On apple and freebsd self-implemented mremap used (common/mremap.h) - buf = clickhouse_mremap(buf, old_size, new_size, MREMAP_MAYMOVE, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + buf = clickhouse_mremap(buf, old_size, new_size, MREMAP_MAYMOVE, + PROT_READ | PROT_WRITE, mmap_flags, -1, 0); if (MAP_FAILED == buf) DB::throwFromErrno("Allocator: Cannot mremap memory chunk from " + formatReadableSizeWithBinarySuffix(old_size) + " to " + formatReadableSizeWithBinarySuffix(new_size) + ".", DB::ErrorCodes::CANNOT_MREMAP); /// No need for zero-fill, because mmap guarantees it. } - else if (new_size < small_memory_threshold) + else if (new_size < MMAP_THRESHOLD) { /// Small allocs that requires a copy. Assume there's enough memory in system. Call CurrentMemoryTracker once. CurrentMemoryTracker::realloc(old_size, new_size); @@ -189,18 +170,30 @@ protected: return 0; } + static constexpr bool clear_memory = clear_memory_; + + // Freshly mmapped pages are copy-on-write references to a global zero page. + // On the first write, a page fault occurs, and an actual writable page is + // allocated. If we are going to use this memory soon, such as when resizing + // hash tables, it makes sense to pre-fault the pages by passing + // MAP_POPULATE to mmap(). This takes some time, but should be faster + // overall than having a hot loop interrupted by page faults. + static constexpr int mmap_flags = MAP_PRIVATE | MAP_ANONYMOUS + | (mmap_populate ? MAP_POPULATE : 0); + private: void * allocNoTrack(size_t size, size_t alignment) { void * buf; - if (size >= mmap_threshold) + if (size >= MMAP_THRESHOLD) { if (alignment > MMAP_MIN_ALIGNMENT) throw DB::Exception("Too large alignment " + formatReadableSizeWithBinarySuffix(alignment) + ": more than page size when allocating " + formatReadableSizeWithBinarySuffix(size) + ".", DB::ErrorCodes::BAD_ARGUMENTS); - buf = mmap(Hint::mmap_hint(), size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + buf = mmap(getMmapHint(), size, PROT_READ | PROT_WRITE, + mmap_flags, -1, 0); if (MAP_FAILED == buf) DB::throwFromErrno("Allocator: Cannot mmap " + formatReadableSizeWithBinarySuffix(size) + ".", DB::ErrorCodes::CANNOT_ALLOCATE_MEMORY); @@ -235,7 +228,7 @@ private: void freeNoTrack(void * buf, size_t size) { - if (size >= mmap_threshold) + if (size >= MMAP_THRESHOLD) { if (0 != munmap(buf, size)) DB::throwFromErrno("Allocator: Cannot munmap " + formatReadableSizeWithBinarySuffix(size) + ".", DB::ErrorCodes::CANNOT_MUNMAP); @@ -245,15 +238,22 @@ private: ::free(buf); } } -}; -#if ALLOCATOR_ASLR -template -using Allocator = AllocatorWithHint; +#ifndef NDEBUG + /// In debug builds, request mmap() at random addresses (a kind of ASLR), to + /// reproduce more memory stomping bugs. Note that Linux doesn't do it by + /// default. This may lead to worse TLB performance. + void * getMmapHint() + { + return reinterpret_cast(std::uniform_int_distribution(0x100000000000UL, 0x700000000000UL)(thread_local_rng)); + } #else -template -using Allocator = AllocatorWithHint; + void * getMmapHint() + { + return nullptr; + } #endif +}; /** When using AllocatorWithStackMemory, located on the stack, * GCC 4.9 mistakenly assumes that we can call `free` from a pointer to the stack. diff --git a/dbms/src/Common/FieldVisitors.cpp b/dbms/src/Common/FieldVisitors.cpp index 9a437d5ffe6..c5ce10c0db4 100644 --- a/dbms/src/Common/FieldVisitors.cpp +++ b/dbms/src/Common/FieldVisitors.cpp @@ -130,7 +130,7 @@ String FieldVisitorToString::operator() (const DecimalField & x) con String FieldVisitorToString::operator() (const UInt128 & x) const { return formatQuoted(UUID(x)); } String FieldVisitorToString::operator() (const AggregateFunctionStateData & x) const { - return "(" + formatQuoted(x.name) + ")" + formatQuoted(x.data); + return formatQuoted(x.data); } String FieldVisitorToString::operator() (const Array & x) const diff --git a/dbms/src/Common/HashTable/HashTableAllocator.h b/dbms/src/Common/HashTable/HashTableAllocator.h index eccf29c8b42..99f9c979685 100644 --- a/dbms/src/Common/HashTable/HashTableAllocator.h +++ b/dbms/src/Common/HashTable/HashTableAllocator.h @@ -3,7 +3,12 @@ #include -using HashTableAllocator = Allocator; +/** + * We are going to use the entire memory we allocated when resizing a hash + * table, so it makes sense to pre-fault the pages so that page faults don't + * interrupt the resize loop. Set the allocator parameter accordingly. + */ +using HashTableAllocator = Allocator; template using HashTableAllocatorWithStackMemory = AllocatorWithStackMemory; diff --git a/dbms/src/Common/ProfileEvents.cpp b/dbms/src/Common/ProfileEvents.cpp index 67303b085f4..947e3890078 100644 --- a/dbms/src/Common/ProfileEvents.cpp +++ b/dbms/src/Common/ProfileEvents.cpp @@ -46,6 +46,8 @@ M(NetworkSendElapsedMicroseconds, "") \ M(ThrottlerSleepMicroseconds, "Total time a query was sleeping to conform the 'max_network_bandwidth' setting.") \ \ + M(QueryMaskingRulesMatch, "Number of times query masking rules was successfully matched.") \ + \ M(ReplicatedPartFetches, "Number of times a data part was downloaded from replica of a ReplicatedMergeTree table.") \ M(ReplicatedPartFailedFetches, "") \ M(ObsoleteReplicatedParts, "") \ diff --git a/dbms/src/Common/SensitiveDataMasker.cpp b/dbms/src/Common/SensitiveDataMasker.cpp new file mode 100644 index 00000000000..41e14aabb3c --- /dev/null +++ b/dbms/src/Common/SensitiveDataMasker.cpp @@ -0,0 +1,164 @@ +#include "SensitiveDataMasker.h" + +#include +#include +#include + +#include +#include + +#include + +#include + +#include +#include + +#ifndef NDEBUG +# include +#endif + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int CANNOT_COMPILE_REGEXP; + extern const int NO_ELEMENTS_IN_CONFIG; + extern const int INVALID_CONFIG_PARAMETER; +} + +class SensitiveDataMasker::MaskingRule +{ +private: + const std::string name; + const std::string replacement_string; + const std::string regexp_string; + + const RE2 regexp; + const re2::StringPiece replacement; + + mutable std::atomic matches_count = 0; + +public: + //* TODO: option with hyperscan? https://software.intel.com/en-us/articles/why-and-how-to-replace-pcre-with-hyperscan + // re2::set should also work quite fast, but it doesn't return the match position, only which regexp was matched + + MaskingRule(const std::string & name_, const std::string & regexp_string_, const std::string & replacement_string_) + : name(name_) + , replacement_string(replacement_string_) + , regexp_string(regexp_string_) + , regexp(regexp_string, RE2::Quiet) + , replacement(replacement_string) + { + if (!regexp.ok()) + throw DB::Exception( + "SensitiveDataMasker: cannot compile re2: " + regexp_string_ + ", error: " + regexp.error() + + ". Look at https://github.com/google/re2/wiki/Syntax for reference.", + DB::ErrorCodes::CANNOT_COMPILE_REGEXP); + } + + uint64_t apply(std::string & data) const + { + auto m = RE2::GlobalReplace(&data, regexp, replacement); + matches_count += m; + return m; + } + + const std::string & getName() const { return name; } + const std::string & getReplacementString() const { return replacement_string; } + uint64_t getMatchesCount() const { return matches_count; } +}; + + +SensitiveDataMasker::SensitiveDataMasker(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix) +{ + Poco::Util::AbstractConfiguration::Keys keys; + config.keys(config_prefix, keys); + Logger * logger = &Logger::get("SensitiveDataMaskerConfigRead"); + + std::set used_names; + + for (const auto & rule : keys) + { + if (startsWith(rule, "rule")) + { + auto rule_config_prefix = config_prefix + "." + rule; + + auto rule_name = config.getString(rule_config_prefix + ".name", rule_config_prefix); + + if (!used_names.insert(rule_name).second) + { + throw Exception( + "query_masking_rules configuration contains more than one rule named '" + rule_name + "'.", + ErrorCodes::INVALID_CONFIG_PARAMETER); + } + + auto regexp = config.getString(rule_config_prefix + ".regexp", ""); + + if (regexp.empty()) + { + throw Exception( + "query_masking_rules configuration, rule '" + rule_name + "' has no node or is empty.", + ErrorCodes::NO_ELEMENTS_IN_CONFIG); + } + + auto replace = config.getString(rule_config_prefix + ".replace", "******"); + + try + { + addMaskingRule(rule_name, regexp, replace); + } + catch (DB::Exception & e) + { + e.addMessage("while adding query masking rule '" + rule_name + "'."); + throw; + } + } + else + { + LOG_WARNING(logger, "Unused param " << config_prefix << '.' << rule); + } + } + + auto rules_count = rulesCount(); + if (rules_count > 0) + { + LOG_INFO(logger, rules_count << " query masking rules loaded."); + } +} + +SensitiveDataMasker::~SensitiveDataMasker() {} + +void SensitiveDataMasker::addMaskingRule( + const std::string & name, const std::string & regexp_string, const std::string & replacement_string) +{ + all_masking_rules.push_back(std::make_unique(name, regexp_string, replacement_string)); +} + + +size_t SensitiveDataMasker::wipeSensitiveData(std::string & data) const +{ + size_t matches = 0; + for (auto & rule : all_masking_rules) + matches += rule->apply(data); + return matches; +} + +#ifndef NDEBUG +void SensitiveDataMasker::printStats() +{ + for (auto & rule : all_masking_rules) + { + std::cout << rule->getName() << " (replacement to " << rule->getReplacementString() << ") matched " << rule->getMatchesCount() + << " times" << std::endl; + } +} +#endif + +size_t SensitiveDataMasker::rulesCount() const +{ + return all_masking_rules.size(); +} + +} diff --git a/dbms/src/Common/SensitiveDataMasker.h b/dbms/src/Common/SensitiveDataMasker.h new file mode 100644 index 00000000000..b7a7b12ee93 --- /dev/null +++ b/dbms/src/Common/SensitiveDataMasker.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include + + +namespace Poco +{ +namespace Util +{ + class AbstractConfiguration; +} +} + +namespace DB +{ +class SensitiveDataMasker +{ +private: + class MaskingRule; + std::vector> all_masking_rules; + +public: + SensitiveDataMasker(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix); + ~SensitiveDataMasker(); + + /// Returns the number of matched rules. + size_t wipeSensitiveData(std::string & data) const; + + /// Used in tests. + void addMaskingRule(const std::string & name, const std::string & regexp_string, const std::string & replacement_string); + +#ifndef NDEBUG + void printStats(); +#endif + + size_t rulesCount() const; +}; + +}; diff --git a/dbms/src/Common/tests/gtest_sensitive_data_masker.cpp b/dbms/src/Common/tests/gtest_sensitive_data_masker.cpp new file mode 100644 index 00000000000..004237aa57f --- /dev/null +++ b/dbms/src/Common/tests/gtest_sensitive_data_masker.cpp @@ -0,0 +1,225 @@ +#include +#include +#include +#include +#include + +#pragma GCC diagnostic ignored "-Wsign-compare" +#ifdef __clang__ +# pragma clang diagnostic ignored "-Wzero-as-null-pointer-constant" +# pragma clang diagnostic ignored "-Wundef" +#endif + +#include + + +namespace DB +{ +namespace ErrorCodes +{ +extern const int CANNOT_COMPILE_REGEXP; +extern const int NO_ELEMENTS_IN_CONFIG; +extern const int INVALID_CONFIG_PARAMETER; +} +}; + + +TEST(Common, SensitiveDataMasker) +{ + + Poco::AutoPtr empty_xml_config = new Poco::Util::XMLConfiguration(); + DB::SensitiveDataMasker masker(*empty_xml_config , ""); + masker.addMaskingRule("all a letters", "a+", "--a--"); + masker.addMaskingRule("all b letters", "b+", "--b--"); + masker.addMaskingRule("all d letters", "d+", "--d--"); + masker.addMaskingRule("all x letters", "x+", "--x--"); + masker.addMaskingRule("rule \"d\" result", "--d--", "*****"); // RE2 regexps are applied one-by-one in order + std::string x = "aaaaaaaaaaaaa bbbbbbbbbb cccc aaaaaaaaaaaa d "; + EXPECT_EQ(masker.wipeSensitiveData(x), 5); + EXPECT_EQ(x, "--a-- --b-- cccc --a-- ***** "); +#ifndef NDEBUG + masker.printStats(); +#endif + EXPECT_EQ(masker.wipeSensitiveData(x), 3); + EXPECT_EQ(x, "----a---- ----b---- cccc ----a---- ***** "); +#ifndef NDEBUG + masker.printStats(); +#endif + + DB::SensitiveDataMasker masker2(*empty_xml_config , ""); + masker2.addMaskingRule("hide root password", "qwerty123", "******"); + masker2.addMaskingRule("hide SSN", "[0-9]{3}-[0-9]{2}-[0-9]{4}", "000-00-0000"); + masker2.addMaskingRule("hide email", "[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,4}", "hidden@hidden.test"); + + std::string query = "SELECT id FROM mysql('localhost:3308', 'database', 'table', 'root', 'qwerty123') WHERE ssn='123-45-6789' or " + "email='JonhSmith@secret.domain.test'"; + EXPECT_EQ(masker2.wipeSensitiveData(query), 3); + EXPECT_EQ( + query, + "SELECT id FROM mysql('localhost:3308', 'database', 'table', 'root', '******') WHERE " + "ssn='000-00-0000' or email='hidden@hidden.test'"); + +#ifndef NDEBUG + // simple benchmark + auto start = std::chrono::high_resolution_clock::now(); + constexpr unsigned long int iterations = 200000; + for (int i = 0; i < iterations; ++i) + { + std::string query2 = "SELECT id FROM mysql('localhost:3308', 'database', 'table', 'root', 'qwerty123') WHERE ssn='123-45-6789' or " + "email='JonhSmith@secret.domain.test'"; + masker2.wipeSensitiveData(query2); + } + auto finish = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = finish - start; + std::cout << "Elapsed time: " << elapsed.count() << "s per " << iterations <<" calls (" << elapsed.count() * 1000000 / iterations << "µs per call)" + << std::endl; + // I have: "Elapsed time: 3.44022s per 200000 calls (17.2011µs per call)" + masker2.printStats(); +#endif + + DB::SensitiveDataMasker maskerbad(*empty_xml_config , ""); + + // gtest has not good way to check exception content, so just do it manually (see https://github.com/google/googletest/issues/952 ) + try + { + maskerbad.addMaskingRule("bad regexp", "**", ""); + ADD_FAILURE() << "addMaskingRule() should throw an error" << std::endl; + } + catch (const DB::Exception & e) + { + EXPECT_EQ( + std::string(e.what()), + "SensitiveDataMasker: cannot compile re2: **, error: no argument for repetition operator: *. Look at " + "https://github.com/google/re2/wiki/Syntax for reference."); + EXPECT_EQ(e.code(), DB::ErrorCodes::CANNOT_COMPILE_REGEXP); + } + /* catch (...) { // not needed, gtest will react unhandled exception + FAIL() << "ERROR: Unexpected exception thrown: " << std::current_exception << std::endl; // std::current_exception is part of C++11x + } */ + + EXPECT_EQ(maskerbad.rulesCount(), 0); + EXPECT_EQ(maskerbad.wipeSensitiveData(x), 0); + + { + std::istringstream xml_isteam(R"END( + + + + hide SSN + [0-9]{3}-[0-9]{2}-[0-9]{4} + 000-00-0000 + + + hide root password + qwerty123 + + + (?i)Ivan + John + + + (?i)Petrov + Doe + + + hide email + (?i)[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,4} + hidden@hidden.test + + + remove selects to bad_words table + ^.*bad_words.*$ + [QUERY IS CENSORED] + + +)END"); + + Poco::AutoPtr xml_config = new Poco::Util::XMLConfiguration(xml_isteam); + DB::SensitiveDataMasker masker_xml_based(*xml_config, "query_masking_rules"); + std::string top_secret = "The e-mail of IVAN PETROV is kotik1902@sdsdf.test, and the password is qwerty123"; + EXPECT_EQ(masker_xml_based.wipeSensitiveData(top_secret), 4); + EXPECT_EQ(top_secret, "The e-mail of John Doe is hidden@hidden.test, and the password is ******"); + + top_secret = "SELECT * FROM bad_words"; + EXPECT_EQ(masker_xml_based.wipeSensitiveData(top_secret), 1); + EXPECT_EQ(top_secret, "[QUERY IS CENSORED]"); + +#ifndef NDEBUG + masker_xml_based.printStats(); +#endif + } + + try + { + std::istringstream xml_isteam_bad(R"END( + + + + test + abc + + + test + abc + + +)END"); + Poco::AutoPtr xml_config = new Poco::Util::XMLConfiguration(xml_isteam_bad); + DB::SensitiveDataMasker masker_xml_based_exception_check(*xml_config, "query_masking_rules"); + + ADD_FAILURE() << "XML should throw an error on bad XML" << std::endl; + } + catch (const DB::Exception & e) + { + EXPECT_EQ( + std::string(e.what()), + "query_masking_rules configuration contains more than one rule named 'test'."); + EXPECT_EQ(e.code(), DB::ErrorCodes::INVALID_CONFIG_PARAMETER); + } + + try + { + std::istringstream xml_isteam_bad(R"END( + + + test + +)END"); + + Poco::AutoPtr xml_config = new Poco::Util::XMLConfiguration(xml_isteam_bad); + DB::SensitiveDataMasker masker_xml_based_exception_check(*xml_config, "query_masking_rules"); + + ADD_FAILURE() << "XML should throw an error on bad XML" << std::endl; + } + catch (const DB::Exception & e) + { + EXPECT_EQ( + std::string(e.what()), + "query_masking_rules configuration, rule 'test' has no node or is empty."); + EXPECT_EQ(e.code(), DB::ErrorCodes::NO_ELEMENTS_IN_CONFIG); + } + + try + { + std::istringstream xml_isteam_bad(R"END( + + + test())( + +)END"); + + Poco::AutoPtr xml_config = new Poco::Util::XMLConfiguration(xml_isteam_bad); + DB::SensitiveDataMasker masker_xml_based_exception_check(*xml_config, "query_masking_rules"); + + ADD_FAILURE() << "XML should throw an error on bad XML" << std::endl; + } + catch (const DB::Exception & e) + { + EXPECT_EQ( + std::string(e.message()), + "SensitiveDataMasker: cannot compile re2: ())(, error: missing ): ())(. Look at https://github.com/google/re2/wiki/Syntax for reference.: while adding query masking rule 'test'." + ); + EXPECT_EQ(e.code(), DB::ErrorCodes::CANNOT_COMPILE_REGEXP); + } + +} diff --git a/dbms/src/Core/Settings.h b/dbms/src/Core/Settings.h index 5cb92038977..5de34ea5efe 100644 --- a/dbms/src/Core/Settings.h +++ b/dbms/src/Core/Settings.h @@ -59,7 +59,7 @@ struct Settings : public SettingsCollection M(SettingMilliseconds, connect_timeout_with_failover_ms, DBMS_DEFAULT_CONNECT_TIMEOUT_WITH_FAILOVER_MS, "Connection timeout for selecting first healthy replica.") \ M(SettingSeconds, receive_timeout, DBMS_DEFAULT_RECEIVE_TIMEOUT_SEC, "") \ M(SettingSeconds, send_timeout, DBMS_DEFAULT_SEND_TIMEOUT_SEC, "") \ - M(SettingSeconds, tcp_keep_alive_timeout, 0, "") \ + M(SettingSeconds, tcp_keep_alive_timeout, 0, "The time in seconds the connection needs to remain idle before TCP starts sending keepalive probes") \ M(SettingMilliseconds, queue_max_wait_ms, 0, "The wait time in the request queue, if the number of concurrent requests exceeds the maximum.") \ M(SettingMilliseconds, connection_pool_max_wait_ms, 0, "The wait time when connection pool is full.") \ M(SettingMilliseconds, replace_running_query_max_wait_ms, 5000, "The wait time for running query with the same query_id to finish when setting 'replace_running_query' is active.") \ @@ -224,6 +224,7 @@ struct Settings : public SettingsCollection M(SettingSeconds, http_receive_timeout, DEFAULT_HTTP_READ_BUFFER_TIMEOUT, "HTTP receive timeout") \ M(SettingBool, optimize_throw_if_noop, false, "If setting is enabled and OPTIMIZE query didn't actually assign a merge then an explanatory exception is thrown") \ M(SettingBool, use_index_for_in_with_subqueries, true, "Try using an index if there is a subquery or a table expression on the right side of the IN operator.") \ + M(SettingBool, joined_subquery_requires_alias, false, "Force joined subqueries to have aliases for correct name qualification.") \ M(SettingBool, empty_result_for_aggregation_by_empty_set, false, "Return empty result when aggregating without keys on empty set.") \ M(SettingBool, allow_distributed_ddl, true, "If it is set to true, then a user is allowed to executed distributed DDL queries.") \ M(SettingUInt64, odbc_max_field_size, 1024, "Max size of filed can be read from ODBC dictionary. Long strings are truncated.") \ diff --git a/dbms/src/DataStreams/TTLBlockInputStream.cpp b/dbms/src/DataStreams/TTLBlockInputStream.cpp index e2a3a7ca03b..e98ce4eb1b7 100644 --- a/dbms/src/DataStreams/TTLBlockInputStream.cpp +++ b/dbms/src/DataStreams/TTLBlockInputStream.cpp @@ -34,7 +34,7 @@ TTLBlockInputStream::TTLBlockInputStream( ASTPtr default_expr_list = std::make_shared(); for (const auto & [name, ttl_info] : old_ttl_infos.columns_ttl) { - if (ttl_info.min <= current_time) + if (force || isTTLExpired(ttl_info.min)) { new_ttl_infos.columns_ttl.emplace(name, MergeTreeDataPart::TTLInfo{}); empty_columns.emplace(name); @@ -51,7 +51,7 @@ TTLBlockInputStream::TTLBlockInputStream( new_ttl_infos.columns_ttl.emplace(name, ttl_info); } - if (old_ttl_infos.table_ttl.min > current_time) + if (!force && !isTTLExpired(old_ttl_infos.table_ttl.min)) new_ttl_infos.table_ttl = old_ttl_infos.table_ttl; if (!default_expr_list->children.empty()) diff --git a/dbms/src/DataTypes/DataTypeAggregateFunction.cpp b/dbms/src/DataTypes/DataTypeAggregateFunction.cpp index 1855522c713..2c5516e100a 100644 --- a/dbms/src/DataTypes/DataTypeAggregateFunction.cpp +++ b/dbms/src/DataTypes/DataTypeAggregateFunction.cpp @@ -218,9 +218,11 @@ void DataTypeAggregateFunction::deserializeTextQuoted(IColumn & column, ReadBuff } -void DataTypeAggregateFunction::deserializeWholeText(IColumn &, ReadBuffer &, const FormatSettings &) const +void DataTypeAggregateFunction::deserializeWholeText(IColumn & column, ReadBuffer & istr, const FormatSettings &) const { - throw Exception("AggregateFunction data type cannot be read from text", ErrorCodes::NOT_IMPLEMENTED); + String s; + readStringUntilEOF(s, istr); + deserializeFromString(function, column, s); } diff --git a/dbms/src/Functions/FunctionsBitmap.h b/dbms/src/Functions/FunctionsBitmap.h index e87adae8064..86aa022cdb5 100644 --- a/dbms/src/Functions/FunctionsBitmap.h +++ b/dbms/src/Functions/FunctionsBitmap.h @@ -348,10 +348,10 @@ private: const UInt32 range_start = is_column_const[1] ? (*container1)[0] : (*container1)[i]; const UInt32 range_end = is_column_const[2] ? (*container2)[0] : (*container2)[i]; - auto bd2 = new AggregateFunctionGroupBitmapData(); - bd0.rbs.rb_range(range_start, range_end, bd2->rbs); - - col_to->insertFrom(reinterpret_cast(bd2)); + col_to->insertDefault(); + AggregateFunctionGroupBitmapData & bd2 + = *reinterpret_cast *>(col_to->getData()[i]); + bd0.rbs.rb_range(range_start, range_end, bd2.rbs); } block.getByPosition(result).column = std::move(col_to); } diff --git a/dbms/src/Functions/FunctionsConversion.h b/dbms/src/Functions/FunctionsConversion.h index 3baca412497..29aee42d001 100644 --- a/dbms/src/Functions/FunctionsConversion.h +++ b/dbms/src/Functions/FunctionsConversion.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -636,7 +637,7 @@ struct ConvertImplGenericFromString { ReadBufferFromMemory read_buffer(&chars[current_offset], offsets[i] - current_offset - 1); - data_type_to.deserializeAsTextEscaped(column_to, read_buffer, format_settings); + data_type_to.deserializeAsWholeText(column_to, read_buffer, format_settings); if (!read_buffer.eof()) throwExceptionForIncompletelyParsedValue(read_buffer, block, result); @@ -1669,6 +1670,21 @@ private: }; } + WrapperType createAggregateFunctionWrapper(const DataTypePtr & from_type_untyped, const DataTypeAggregateFunction * to_type) const + { + /// Conversion from String through parsing. + if (checkAndGetDataType(from_type_untyped.get())) + { + return [] (Block & block, const ColumnNumbers & arguments, const size_t result, size_t /*input_rows_count*/) + { + ConvertImplGenericFromString::execute(block, arguments, result); + }; + } + else + throw Exception{"Conversion from " + from_type_untyped->getName() + " to " + to_type->getName() + + " is not supported", ErrorCodes::CANNOT_CONVERT_TYPE}; + } + WrapperType createArrayWrapper(const DataTypePtr & from_type_untyped, const DataTypeArray * to_type) const { /// Conversion from String through parsing. @@ -2145,13 +2161,12 @@ private: case TypeIndex::Tuple: return createTupleWrapper(from_type, checkAndGetDataType(to_type.get())); + case TypeIndex::AggregateFunction: + return createAggregateFunctionWrapper(from_type, checkAndGetDataType(to_type.get())); default: break; } - /// It's possible to use ConvertImplGenericFromString to convert from String to AggregateFunction, - /// but it is disabled because deserializing aggregate functions state might be unsafe. - throw Exception{"Conversion from " + from_type->getName() + " to " + to_type->getName() + " is not supported", ErrorCodes::CANNOT_CONVERT_TYPE}; } diff --git a/dbms/src/Functions/FunctionsLogical.cpp b/dbms/src/Functions/FunctionsLogical.cpp index 29035366909..75c602df088 100644 --- a/dbms/src/Functions/FunctionsLogical.cpp +++ b/dbms/src/Functions/FunctionsLogical.cpp @@ -1,6 +1,20 @@ #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + + namespace DB { @@ -12,4 +26,568 @@ void registerFunctionsLogical(FunctionFactory & factory) factory.registerFunction(); } +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; + extern const int ILLEGAL_COLUMN; +} + +namespace +{ +using namespace FunctionsLogicalDetail; + +using UInt8Container = ColumnUInt8::Container; +using UInt8ColumnPtrs = std::vector; + + +MutableColumnPtr convertFromTernaryData(const UInt8Container & ternary_data, const bool make_nullable) +{ + const size_t rows_count = ternary_data.size(); + + auto new_column = ColumnUInt8::create(rows_count); + std::transform( + ternary_data.cbegin(), ternary_data.cend(), new_column->getData().begin(), + [](const auto x) { return x == Ternary::True; }); + + if (!make_nullable) + return new_column; + + auto null_column = ColumnUInt8::create(rows_count); + std::transform( + ternary_data.cbegin(), ternary_data.cend(), null_column->getData().begin(), + [](const auto x) { return x == Ternary::Null; }); + + return ColumnNullable::create(std::move(new_column), std::move(null_column)); +} + +template +bool tryConvertColumnToUInt8(const IColumn * column, UInt8Container & res) +{ + const auto col = checkAndGetColumn>(column); + if (!col) + return false; + + std::transform( + col->getData().cbegin(), col->getData().cend(), res.begin(), + [](const auto x) { return x != 0; }); + + return true; +} + +void convertColumnToUInt8(const IColumn * column, UInt8Container & res) +{ + if (!tryConvertColumnToUInt8(column, res) && + !tryConvertColumnToUInt8(column, res) && + !tryConvertColumnToUInt8(column, res) && + !tryConvertColumnToUInt8(column, res) && + !tryConvertColumnToUInt8(column, res) && + !tryConvertColumnToUInt8(column, res) && + !tryConvertColumnToUInt8(column, res) && + !tryConvertColumnToUInt8(column, res) && + !tryConvertColumnToUInt8(column, res)) + throw Exception("Unexpected type of column: " + column->getName(), ErrorCodes::ILLEGAL_COLUMN); +} + + +template +static bool extractConstColumns(ColumnRawPtrs & in, UInt8 & res, Func && func) +{ + bool has_res = false; + + for (int i = static_cast(in.size()) - 1; i >= 0; --i) + { + if (!isColumnConst(*in[i])) + continue; + + UInt8 x = func((*in[i])[0]); + if (has_res) + { + res = Op::apply(res, x); + } + else + { + res = x; + has_res = true; + } + + in.erase(in.begin() + i); + } + + return has_res; +} + +template +inline bool extractConstColumns(ColumnRawPtrs & in, UInt8 & res) +{ + return extractConstColumns( + in, res, + [](const Field & value) + { + return !value.isNull() && applyVisitor(FieldVisitorConvertToNumber(), value); + } + ); +} + +template +inline bool extractConstColumnsTernary(ColumnRawPtrs & in, UInt8 & res_3v) +{ + return extractConstColumns( + in, res_3v, + [](const Field & value) + { + return value.isNull() + ? Ternary::makeValue(false, true) + : Ternary::makeValue(applyVisitor(FieldVisitorConvertToNumber(), value)); + } + ); +} + + +template +class AssociativeApplierImpl +{ + using ResultValueType = typename Op::ResultType; + +public: + /// Remembers the last N columns from `in`. + AssociativeApplierImpl(const UInt8ColumnPtrs & in) + : vec(in[in.size() - N]->getData()), next(in) {} + + /// Returns a combination of values in the i-th row of all columns stored in the constructor. + inline ResultValueType apply(const size_t i) const + { + const auto & a = vec[i]; + if constexpr (Op::isSaturable()) + return Op::isSaturatedValue(a) ? a : Op::apply(a, next.apply(i)); + else + return Op::apply(a, next.apply(i)); + } + +private: + const UInt8Container & vec; + const AssociativeApplierImpl next; +}; + +template +class AssociativeApplierImpl +{ + using ResultValueType = typename Op::ResultType; + +public: + AssociativeApplierImpl(const UInt8ColumnPtrs & in) + : vec(in[in.size() - 1]->getData()) {} + + inline ResultValueType apply(const size_t i) const { return vec[i]; } + +private: + const UInt8Container & vec; +}; + + +/// A helper class used by AssociativeGenericApplierImpl +/// Allows for on-the-fly conversion of any data type into intermediate ternary representation +using ValueGetter = std::function; + +template +struct ValueGetterBuilderImpl; + +template +struct ValueGetterBuilderImpl +{ + static ValueGetter build(const IColumn * x) + { + if (const auto nullable_column = typeid_cast(x)) + { + if (const auto nested_column = typeid_cast *>(nullable_column->getNestedColumnPtr().get())) + { + return [&null_data = nullable_column->getNullMapData(), &column_data = nested_column->getData()](size_t i) + { return Ternary::makeValue(column_data[i], null_data[i]); }; + } + else + return ValueGetterBuilderImpl::build(x); + } + else if (const auto column = typeid_cast *>(x)) + return [&column_data = column->getData()](size_t i) { return Ternary::makeValue(column_data[i]); }; + else + return ValueGetterBuilderImpl::build(x); + } +}; + +template <> +struct ValueGetterBuilderImpl<> +{ + static ValueGetter build(const IColumn * x) + { + throw Exception( + std::string("Unknown numeric column of type: ") + demangle(typeid(x).name()), + ErrorCodes::LOGICAL_ERROR); + } +}; + +using ValueGetterBuilder = + ValueGetterBuilderImpl; + +/// This class together with helper class ValueGetterBuilder can be used with columns of arbitrary data type +/// Allows for on-the-fly conversion of any type of data into intermediate ternary representation +/// and eliminates the need to materialize data columns in intermediate representation +template +class AssociativeGenericApplierImpl +{ + using ResultValueType = typename Op::ResultType; + +public: + /// Remembers the last N columns from `in`. + AssociativeGenericApplierImpl(const ColumnRawPtrs & in) + : val_getter{ValueGetterBuilder::build(in[in.size() - N])}, next{in} {} + + /// Returns a combination of values in the i-th row of all columns stored in the constructor. + inline ResultValueType apply(const size_t i) const + { + const auto a = val_getter(i); + if constexpr (Op::isSaturable()) + return Op::isSaturatedValue(a) ? a : Op::apply(a, next.apply(i)); + else + return Op::apply(a, next.apply(i)); + } + +private: + const ValueGetter val_getter; + const AssociativeGenericApplierImpl next; +}; + + +template +class AssociativeGenericApplierImpl +{ + using ResultValueType = typename Op::ResultType; + +public: + /// Remembers the last N columns from `in`. + AssociativeGenericApplierImpl(const ColumnRawPtrs & in) + : val_getter{ValueGetterBuilder::build(in[in.size() - 1])} {} + + inline ResultValueType apply(const size_t i) const { return val_getter(i); } + +private: + const ValueGetter val_getter; +}; + + +/// Apply target function by feeding it "batches" of N columns +/// Combining 10 columns per pass is the fastest for large block sizes. +/// For small block sizes - more columns is faster. +template < + typename Op, template typename OperationApplierImpl, size_t N = 10> +struct OperationApplier +{ + template + static void apply(Columns & in, ResultColumn & result) + { + while (in.size() > 1) + { + doBatchedApply(in, result->getData()); + in.push_back(result.get()); + } + } + + template + static void NO_INLINE doBatchedApply(Columns & in, ResultData & result_data) + { + if (N > in.size()) + { + OperationApplier::doBatchedApply(in, result_data); + return; + } + + const OperationApplierImpl operationApplierImpl(in); + size_t i = 0; + for (auto & res : result_data) + res = operationApplierImpl.apply(i++); + + in.erase(in.end() - N, in.end()); + } +}; + +template < + typename Op, template typename OperationApplierImpl> +struct OperationApplier +{ + template + static void NO_INLINE doBatchedApply(Columns &, Result &) + { + throw Exception( + "OperationApplier<...>::apply(...): not enough arguments to run this method", + ErrorCodes::LOGICAL_ERROR); + } +}; + + +template +static void executeForTernaryLogicImpl(ColumnRawPtrs arguments, ColumnWithTypeAndName & result_info, size_t input_rows_count) +{ + /// Combine all constant columns into a single constant value. + UInt8 const_3v_value = 0; + const bool has_consts = extractConstColumnsTernary(arguments, const_3v_value); + + /// If the constant value uniquely determines the result, return it. + if (has_consts && (arguments.empty() || (Op::isSaturable() && Op::isSaturatedValue(const_3v_value)))) + { + result_info.column = ColumnConst::create( + convertFromTernaryData(UInt8Container({const_3v_value}), result_info.type->isNullable()), + input_rows_count + ); + return; + } + + const auto result_column = ColumnUInt8::create(input_rows_count); + MutableColumnPtr const_column_holder; + if (has_consts) + { + const_column_holder = + convertFromTernaryData(UInt8Container(input_rows_count, const_3v_value), const_3v_value == Ternary::Null); + arguments.push_back(const_column_holder.get()); + } + + OperationApplier::apply(arguments, result_column); + + result_info.column = convertFromTernaryData(result_column->getData(), result_info.type->isNullable()); +} + + +template +struct TypedExecutorInvoker; + +template +using FastApplierImpl = + TypedExecutorInvoker; + +template +struct TypedExecutorInvoker +{ + template + static void apply(const ColumnVector & x, const IColumn & y, Result & result) + { + if (const auto column = typeid_cast *>(&y)) + std::transform( + x.getData().cbegin(), x.getData().cend(), + column->getData().cbegin(), result.begin(), + [](const auto a, const auto b) { return Op::apply(!!a, !!b); }); + else + TypedExecutorInvoker::template apply(x, y, result); + } + + template + static void apply(const IColumn & x, const IColumn & y, Result & result) + { + if (const auto column = typeid_cast *>(&x)) + FastApplierImpl::template apply(*column, y, result); + else + TypedExecutorInvoker::apply(x, y, result); + } +}; + +template +struct TypedExecutorInvoker +{ + template + static void apply(const ColumnVector &, const IColumn & y, Result &) + { + throw Exception(std::string("Unknown numeric column y of type: ") + demangle(typeid(y).name()), ErrorCodes::LOGICAL_ERROR); + } + + template + static void apply(const IColumn & x, const IColumn &, Result &) + { + throw Exception(std::string("Unknown numeric column x of type: ") + demangle(typeid(x).name()), ErrorCodes::LOGICAL_ERROR); + } +}; + + +template +static void basicExecuteImpl(ColumnRawPtrs arguments, ColumnWithTypeAndName & result_info, size_t input_rows_count) +{ + /// Combine all constant columns into a single constant value. + UInt8 const_val = 0; + bool has_consts = extractConstColumns(arguments, const_val); + + /// If the constant value uniquely determines the result, return it. + if (has_consts && (arguments.empty() || Op::apply(const_val, 0) == Op::apply(const_val, 1))) + { + if (!arguments.empty()) + const_val = Op::apply(const_val, 0); + result_info.column = DataTypeUInt8().createColumnConst(input_rows_count, toField(const_val)); + return; + } + + /// If the constant value is a neutral element, let's forget about it. + if (has_consts && Op::apply(const_val, 0) == 0 && Op::apply(const_val, 1) == 1) + has_consts = false; + + UInt8ColumnPtrs uint8_args; + + auto col_res = ColumnUInt8::create(); + UInt8Container & vec_res = col_res->getData(); + if (has_consts) + { + vec_res.assign(input_rows_count, const_val); + uint8_args.push_back(col_res.get()); + } + else + { + vec_res.resize(input_rows_count); + } + + /// FastPath detection goes in here + if (arguments.size() == (has_consts ? 1 : 2)) + { + if (has_consts) + FastApplierImpl::apply(*arguments[0], *col_res, col_res->getData()); + else + FastApplierImpl::apply(*arguments[0], *arguments[1], col_res->getData()); + + result_info.column = std::move(col_res); + return; + } + + /// Convert all columns to UInt8 + Columns converted_columns; + for (const IColumn * column : arguments) + { + if (auto uint8_column = checkAndGetColumn(column)) + uint8_args.push_back(uint8_column); + else + { + auto converted_column = ColumnUInt8::create(input_rows_count); + convertColumnToUInt8(column, converted_column->getData()); + uint8_args.push_back(converted_column.get()); + converted_columns.emplace_back(std::move(converted_column)); + } + } + + OperationApplier::apply(uint8_args, col_res); + + /// This is possible if there is exactly one non-constant among the arguments, and it is of type UInt8. + if (uint8_args[0] != col_res.get()) + vec_res.assign(uint8_args[0]->getData()); + + result_info.column = std::move(col_res); +} + +} + +template +DataTypePtr FunctionAnyArityLogical::getReturnTypeImpl(const DataTypes & arguments) const +{ + if (arguments.size() < 2) + throw Exception("Number of arguments for function \"" + getName() + "\" should be at least 2: passed " + + toString(arguments.size()), + ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION); + + bool has_nullable_arguments = false; + for (size_t i = 0; i < arguments.size(); ++i) + { + const auto & arg_type = arguments[i]; + + if (!has_nullable_arguments) + { + has_nullable_arguments = arg_type->isNullable(); + if (has_nullable_arguments && !Impl::specialImplementationForNulls()) + throw Exception("Logical error: Unexpected type of argument for function \"" + getName() + "\": " + " argument " + toString(i + 1) + " is of type " + arg_type->getName(), ErrorCodes::LOGICAL_ERROR); + } + + if (!(isNativeNumber(arg_type) + || (Impl::specialImplementationForNulls() && (arg_type->onlyNull() || isNativeNumber(removeNullable(arg_type)))))) + throw Exception("Illegal type (" + + arg_type->getName() + + ") of " + toString(i + 1) + " argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + + auto result_type = std::make_shared(); + return has_nullable_arguments + ? makeNullable(result_type) + : result_type; +} + +template +void FunctionAnyArityLogical::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result_index, size_t input_rows_count) +{ + ColumnRawPtrs args_in; + for (const auto arg_index : arguments) + args_in.push_back(block.getByPosition(arg_index).column.get()); + + auto & result_info = block.getByPosition(result_index); + if (result_info.type->isNullable()) + executeForTernaryLogicImpl(std::move(args_in), result_info, input_rows_count); + else + basicExecuteImpl(std::move(args_in), result_info, input_rows_count); +} + + +template +struct UnaryOperationImpl +{ + using ResultType = typename Op::ResultType; + using ArrayA = typename ColumnVector::Container; + using ArrayC = typename ColumnVector::Container; + + static void NO_INLINE vector(const ArrayA & a, ArrayC & c) + { + std::transform( + a.cbegin(), a.cend(), c.begin(), + [](const auto x) { return Op::apply(x); }); + } +}; + +template