Merge branch 'clickhouse-4013' of https://github.com/anrodigina/ClickHouse into clickhouse-4013

This commit is contained in:
Anastasiya Rodigina 2019-05-18 15:54:50 +03:00
commit 8c2630eddc
332 changed files with 10259 additions and 1667 deletions

6
.gitmodules vendored
View File

@ -79,3 +79,9 @@
[submodule "contrib/hyperscan"]
path = contrib/hyperscan
url = https://github.com/ClickHouse-Extras/hyperscan.git
[submodule "contrib/simdjson"]
path = contrib/simdjson
url = https://github.com/lemire/simdjson.git
[submodule "contrib/rapidjson"]
path = contrib/rapidjson
url = https://github.com/Tencent/rapidjson

View File

@ -1,6 +1,15 @@
project(ClickHouse)
cmake_minimum_required(VERSION 3.3)
cmake_policy(SET CMP0023 NEW)
foreach(policy
CMP0023
CMP0074 # CMake 3.12
)
if(POLICY ${policy})
cmake_policy(SET ${policy} NEW)
endif()
endforeach()
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules/")
set(CMAKE_EXPORT_COMPILE_COMMANDS 1) # Write compile_commands.json
set(CMAKE_LINK_DEPENDS_NO_SHARED 1) # Do not relink all depended targets on .so
@ -301,6 +310,7 @@ include (cmake/find_rt.cmake)
include (cmake/find_execinfo.cmake)
include (cmake/find_readline_edit.cmake)
include (cmake/find_re2.cmake)
include (cmake/find_libgsasl.cmake)
include (cmake/find_rdkafka.cmake)
include (cmake/find_capnp.cmake)
include (cmake/find_llvm.cmake)
@ -308,7 +318,6 @@ include (cmake/find_cpuid.cmake) # Freebsd, bundled
if (NOT USE_CPUID)
include (cmake/find_cpuinfo.cmake) # Debian
endif()
include (cmake/find_libgsasl.cmake)
include (cmake/find_libxml2.cmake)
include (cmake/find_brotli.cmake)
include (cmake/find_protobuf.cmake)
@ -318,6 +327,8 @@ include (cmake/find_consistent-hashing.cmake)
include (cmake/find_base64.cmake)
include (cmake/find_hyperscan.cmake)
include (cmake/find_lfalloc.cmake)
include (cmake/find_simdjson.cmake)
include (cmake/find_rapidjson.cmake)
find_contrib_lib(cityhash)
find_contrib_lib(farmhash)
find_contrib_lib(metrohash)

View File

@ -12,8 +12,8 @@ ClickHouse is an open-source column-oriented database management system that all
* You can also [fill this form](https://forms.yandex.com/surveys/meet-yandex-clickhouse-team/) to meet Yandex ClickHouse team in person.
## Upcoming Events
* [ClickHouse Community Meetup in Limassol](https://www.facebook.com/events/386638262181785/) on May 7.
* ClickHouse at [Percona Live 2019](https://www.percona.com/live/19/other-open-source-databases-track) in Austin on May 28-30.
* [ClickHouse Community Meetup in San Francisco](https://www.meetup.com/San-Francisco-Bay-Area-ClickHouse-Meetup/events/261110652/) on June 4.
* [ClickHouse Community Meetup in Beijing](https://www.huodongxing.com/event/2483759276200) on June 8.
* [ClickHouse Community Meetup in Shenzhen](https://www.huodongxing.com/event/3483759917300) on October 20.
* [ClickHouse Community Meetup in Shanghai](https://www.huodongxing.com/event/4483760336000) on October 27.

View File

@ -1,9 +1,12 @@
option (USE_INTERNAL_BOOST_LIBRARY "Set to FALSE to use system boost library instead of bundled" ${NOT_UNBUNDLED})
# Test random file existing in all package variants
if (USE_INTERNAL_BOOST_LIBRARY AND NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/boost/libs/system/src/error_code.cpp")
if (NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/boost/libs/system/src/error_code.cpp")
if(USE_INTERNAL_BOOST_LIBRARY)
message(WARNING "submodules in contrib/boost is missing. to fix try run: \n git submodule update --init --recursive")
endif()
set (USE_INTERNAL_BOOST_LIBRARY 0)
set (MISSING_INTERNAL_BOOST_LIBRARY 1)
endif ()
if (NOT USE_INTERNAL_BOOST_LIBRARY)
@ -21,10 +24,9 @@ if (NOT USE_INTERNAL_BOOST_LIBRARY)
set (Boost_INCLUDE_DIRS "")
set (Boost_SYSTEM_LIBRARY "")
endif ()
endif ()
if (NOT Boost_SYSTEM_LIBRARY)
if (NOT Boost_SYSTEM_LIBRARY AND NOT MISSING_INTERNAL_BOOST_LIBRARY)
set (USE_INTERNAL_BOOST_LIBRARY 1)
set (Boost_SYSTEM_LIBRARY boost_system_internal)
set (Boost_PROGRAM_OPTIONS_LIBRARY boost_program_options_internal)
@ -44,7 +46,6 @@ if (NOT Boost_SYSTEM_LIBRARY)
# For packaged version:
list (APPEND Boost_INCLUDE_DIRS "${ClickHouse_SOURCE_DIR}/contrib/boost")
endif ()
message (STATUS "Using Boost: ${Boost_INCLUDE_DIRS} : ${Boost_PROGRAM_OPTIONS_LIBRARY},${Boost_SYSTEM_LIBRARY},${Boost_FILESYSTEM_LIBRARY},${Boost_REGEX_LIBRARY}")

View File

@ -1,6 +1,9 @@
option(ENABLE_ICU "Enable ICU" ON)
if(ENABLE_ICU)
if (APPLE)
set(ICU_ROOT "/usr/local/opt/icu4c" CACHE STRING "")
endif()
find_package(ICU COMPONENTS i18n uc data) # TODO: remove Modules/FindICU.cmake after cmake 3.7
#set (ICU_LIBRARIES ${ICU_I18N_LIBRARY} ${ICU_UC_LIBRARY} ${ICU_DATA_LIBRARY} CACHE STRING "")
if(ICU_FOUND)

View File

@ -1,4 +1,4 @@
if (NOT SANITIZE AND NOT ARCH_ARM AND NOT ARCH_32 AND NOT ARCH_PPC64LE AND NOT OS_FREEBSD)
if (NOT SANITIZE AND NOT ARCH_ARM AND NOT ARCH_32 AND NOT ARCH_PPC64LE AND NOT OS_FREEBSD AND NOT APPLE)
option (ENABLE_LFALLOC "Set to FALSE to use system libgsasl library instead of bundled" ${NOT_UNBUNDLED})
endif ()

View File

@ -22,4 +22,8 @@ elseif (NOT MISSING_INTERNAL_LIBGSASL_LIBRARY AND NOT APPLE AND NOT ARCH_32)
set (LIBGSASL_LIBRARY libgsasl)
endif ()
message (STATUS "Using libgsasl: ${LIBGSASL_INCLUDE_DIR} : ${LIBGSASL_LIBRARY}")
if(LIBGSASL_LIBRARY AND LIBGSASL_INCLUDE_DIR)
set (USE_LIBGSASL 1)
endif()
message (STATUS "Using libgsasl=${USE_LIBGSASL}: ${LIBGSASL_INCLUDE_DIR} : ${LIBGSASL_LIBRARY}")

View File

@ -0,0 +1,9 @@
if (NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/rapidjson/include/rapidjson/rapidjson.h")
message (WARNING "submodule contrib/rapidjson is missing. to fix try run: \n git submodule update --init --recursive")
return()
endif ()
option (USE_RAPIDJSON "Use rapidjson" ON)
set (RAPIDJSON_INCLUDE_DIR "${ClickHouse_SOURCE_DIR}/contrib/rapidjson/include")
message(STATUS "Using rapidjson=${USE_RAPIDJSON}: ${RAPIDJSON_INCLUDE_DIR}")

View File

@ -10,7 +10,7 @@ endif ()
if (ENABLE_RDKAFKA)
if (OS_LINUX AND NOT ARCH_ARM)
if (OS_LINUX AND NOT ARCH_ARM AND USE_LIBGSASL)
option (USE_INTERNAL_RDKAFKA_LIBRARY "Set to FALSE to use system librdkafka instead of the bundled" ${NOT_UNBUNDLED})
endif ()

View File

@ -1,5 +1,13 @@
option (USE_INTERNAL_RE2_LIBRARY "Set to FALSE to use system re2 library instead of bundled [slower]" ${NOT_UNBUNDLED})
if(NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/re2/CMakeLists.txt")
if(USE_INTERNAL_RE2_LIBRARY)
message(WARNING "submodule contrib/re2 is missing. to fix try run: \n git submodule update --init --recursive")
endif()
set(USE_INTERNAL_RE2_LIBRARY 0)
set(MISSING_INTERNAL_RE2_LIBRARY 1)
endif()
if (NOT USE_INTERNAL_RE2_LIBRARY)
find_library (RE2_LIBRARY re2)
find_path (RE2_INCLUDE_DIR NAMES re2/re2.h PATHS ${RE2_INCLUDE_PATHS})

14
cmake/find_simdjson.cmake Normal file
View File

@ -0,0 +1,14 @@
if (NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/simdjson/include/simdjson/jsonparser.h")
message (WARNING "submodule contrib/simdjson is missing. to fix try run: \n git submodule update --init --recursive")
return()
endif ()
if (NOT HAVE_AVX2)
message (WARNING "submodule contrib/simdjson requires AVX2 support")
return()
endif ()
option (USE_SIMDJSON "Use simdjson" ON)
set (SIMDJSON_LIBRARY "simdjson")
message(STATUS "Using simdjson=${USE_SIMDJSON}: ${SIMDJSON_LIBRARY}")

View File

@ -2,11 +2,6 @@ if (NOT OS_FREEBSD AND NOT ARCH_32)
option (USE_INTERNAL_ZLIB_LIBRARY "Set to FALSE to use system zlib library instead of bundled" ${NOT_UNBUNDLED})
endif ()
if (NOT USE_INTERNAL_ZLIB_LIBRARY)
find_package (ZLIB)
endif ()
if (NOT ZLIB_FOUND)
if (NOT MSVC)
set (INTERNAL_ZLIB_NAME "zlib-ng" CACHE INTERNAL "")
else ()
@ -16,6 +11,19 @@ if (NOT ZLIB_FOUND)
endif ()
endif ()
if(NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/${INTERNAL_ZLIB_NAME}/zlib.h")
if(USE_INTERNAL_ZLIB_LIBRARY)
message(WARNING "submodule contrib/${INTERNAL_ZLIB_NAME} is missing. to fix try run: \n git submodule update --init --recursive")
endif()
set(USE_INTERNAL_ZLIB_LIBRARY 0)
set(MISSING_INTERNAL_ZLIB_LIBRARY 1)
endif()
if (NOT USE_INTERNAL_ZLIB_LIBRARY)
find_package (ZLIB)
endif ()
if (NOT ZLIB_FOUND AND NOT MISSING_INTERNAL_ZLIB_LIBRARY)
set (USE_INTERNAL_ZLIB_LIBRARY 1)
set (ZLIB_INCLUDE_DIR "${ClickHouse_SOURCE_DIR}/contrib/${INTERNAL_ZLIB_NAME}" "${ClickHouse_BINARY_DIR}/contrib/${INTERNAL_ZLIB_NAME}" CACHE INTERNAL "") # generated zconf.h
set (ZLIB_INCLUDE_DIRS ${ZLIB_INCLUDE_DIR}) # for poco

View File

@ -1,8 +1,11 @@
option (USE_INTERNAL_ZSTD_LIBRARY "Set to FALSE to use system zstd library instead of bundled" ${NOT_UNBUNDLED})
if (USE_INTERNAL_ZSTD_LIBRARY AND NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/zstd/lib/zstd.h")
if(NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/zstd/lib/zstd.h")
if(USE_INTERNAL_ZSTD_LIBRARY)
message(WARNING "submodule contrib/zstd is missing. to fix try run: \n git submodule update --init --recursive")
endif()
set(USE_INTERNAL_ZSTD_LIBRARY 0)
set(MISSING_INTERNAL_ZSTD_LIBRARY 1)
endif()
if (NOT USE_INTERNAL_ZSTD_LIBRARY)
@ -11,7 +14,7 @@ if (NOT USE_INTERNAL_ZSTD_LIBRARY)
endif ()
if (ZSTD_LIBRARY AND ZSTD_INCLUDE_DIR)
else ()
elseif (NOT MISSING_INTERNAL_ZSTD_LIBRARY)
set (USE_INTERNAL_ZSTD_LIBRARY 1)
set (ZSTD_LIBRARY zstd)
set (ZSTD_INCLUDE_DIR ${ClickHouse_SOURCE_DIR}/contrib/zstd/lib)

View File

@ -227,7 +227,7 @@ if (USE_INTERNAL_POCO_LIBRARY)
set (ENABLE_TESTS 0)
set (POCO_ENABLE_TESTS 0)
set (CMAKE_DISABLE_FIND_PACKAGE_ZLIB 1)
if (MSVC)
if (MSVC OR NOT USE_POCO_DATAODBC)
set (ENABLE_DATA_ODBC 0 CACHE INTERNAL "") # TODO (build fail)
endif ()
add_subdirectory (poco)
@ -313,3 +313,7 @@ endif()
if (USE_INTERNAL_HYPERSCAN_LIBRARY)
add_subdirectory (hyperscan)
endif()
if (USE_SIMDJSON)
add_subdirectory (simdjson-cmake)
endif()

2
contrib/boost vendored

@ -1 +1 @@
Subproject commit 471ea208abb92a5cba7d3a08a819bb728f27e95f
Subproject commit 79bf85ea99c05ba4fb6959474d4464ab126f8973

View File

@ -33,6 +33,7 @@ set(SRCS
${RDKAFKA_SOURCE_DIR}/rdkafka_roundrobin_assignor.c
${RDKAFKA_SOURCE_DIR}/rdkafka_sasl.c
${RDKAFKA_SOURCE_DIR}/rdkafka_sasl_plain.c
${RDKAFKA_SOURCE_DIR}/rdkafka_sasl_scram.c
${RDKAFKA_SOURCE_DIR}/rdkafka_subscription.c
${RDKAFKA_SOURCE_DIR}/rdkafka_timer.c
${RDKAFKA_SOURCE_DIR}/rdkafka_topic.c
@ -58,7 +59,7 @@ add_library(rdkafka ${SRCS})
target_include_directories(rdkafka SYSTEM PUBLIC include)
target_include_directories(rdkafka SYSTEM PUBLIC ${RDKAFKA_SOURCE_DIR}) # Because weird logic with "include_next" is used.
target_include_directories(rdkafka SYSTEM PRIVATE ${ZSTD_INCLUDE_DIR}/common) # Because wrong path to "zstd_errors.h" is used.
target_link_libraries(rdkafka PUBLIC ${ZLIB_LIBRARIES} ${ZSTD_LIBRARY} ${LZ4_LIBRARY})
target_link_libraries(rdkafka PUBLIC ${ZLIB_LIBRARIES} ${ZSTD_LIBRARY} ${LZ4_LIBRARY} ${LIBGSASL_LIBRARY})
if(OPENSSL_SSL_LIBRARY AND OPENSSL_CRYPTO_LIBRARY)
target_link_libraries(rdkafka PUBLIC ${OPENSSL_SSL_LIBRARY} ${OPENSSL_CRYPTO_LIBRARY})
endif()

View File

@ -12,7 +12,7 @@
#define ENABLE_SHAREDPTR_DEBUG 0
#define ENABLE_LZ4_EXT 1
#define ENABLE_SSL 1
//#define ENABLE_SASL 1
#define ENABLE_SASL 1
#define MKL_APP_NAME "librdkafka"
#define MKL_APP_DESC_ONELINE "The Apache Kafka C/C++ library"
// distro
@ -62,7 +62,7 @@
// libssl
#define WITH_SSL 1
// WITH_SASL_SCRAM
//#define WITH_SASL_SCRAM 1
#define WITH_SASL_SCRAM 1
// crc32chw
#if !defined(__PPC__)
#define WITH_CRC32C_HW 1

1
contrib/rapidjson vendored Submodule

@ -0,0 +1 @@
Subproject commit 01950eb7acec78818d68b762efc869bba2420d82

1
contrib/simdjson vendored Submodule

@ -0,0 +1 @@
Subproject commit 14cd1f7a0b0563db78bda8053a9f6ac2ea95a441

View File

@ -0,0 +1,18 @@
if (NOT HAVE_AVX2)
message (FATAL_ERROR "No AVX2 support")
endif ()
set(SIMDJSON_INCLUDE_DIR "${ClickHouse_SOURCE_DIR}/contrib/simdjson/include")
set(SIMDJSON_SRC_DIR "${SIMDJSON_INCLUDE_DIR}/../src")
set(SIMDJSON_SRC
${SIMDJSON_SRC_DIR}/jsonioutil.cpp
${SIMDJSON_SRC_DIR}/jsonminifier.cpp
${SIMDJSON_SRC_DIR}/jsonparser.cpp
${SIMDJSON_SRC_DIR}/stage1_find_marks.cpp
${SIMDJSON_SRC_DIR}/stage2_build_tape.cpp
${SIMDJSON_SRC_DIR}/parsedjson.cpp
${SIMDJSON_SRC_DIR}/parsedjsoniterator.cpp
)
add_library(${SIMDJSON_LIBRARY} ${SIMDJSON_SRC})
target_include_directories(${SIMDJSON_LIBRARY} PUBLIC "${SIMDJSON_INCLUDE_DIR}")
target_compile_options(${SIMDJSON_LIBRARY} PRIVATE -mavx2 -mbmi -mbmi2 -mpclmul)

View File

@ -189,8 +189,17 @@ target_link_libraries (clickhouse_common_io
${Poco_Net_LIBRARY}
${Poco_Util_LIBRARY}
${Poco_Foundation_LIBRARY}
${RE2_LIBRARY}
${RE2_ST_LIBRARY}
)
if(RE2_LIBRARY)
target_link_libraries(clickhouse_common_io PUBLIC ${RE2_LIBRARY})
endif()
if(RE2_ST_LIBRARY)
target_link_libraries(clickhouse_common_io PUBLIC ${RE2_ST_LIBRARY})
endif()
target_link_libraries(clickhouse_common_io
PUBLIC
${CITYHASH_LIBRARIES}
PRIVATE
${ZLIB_LIBRARIES}
@ -208,7 +217,9 @@ target_link_libraries (clickhouse_common_io
)
if(RE2_INCLUDE_DIR)
target_include_directories(clickhouse_common_io SYSTEM BEFORE PUBLIC ${RE2_INCLUDE_DIR})
endif()
if (USE_LFALLOC)
target_include_directories (clickhouse_common_io SYSTEM BEFORE PUBLIC ${LFALLOC_INCLUDE_DIR})

View File

@ -209,6 +209,9 @@ else ()
install (FILES ${CMAKE_CURRENT_BINARY_DIR}/clickhouse-obfuscator DESTINATION ${CMAKE_INSTALL_BINDIR} COMPONENT clickhouse)
list(APPEND CLICKHOUSE_BUNDLE clickhouse-obfuscator)
endif ()
if(ENABLE_CLICKHOUSE_ODBC_BRIDGE)
list(APPEND CLICKHOUSE_BUNDLE clickhouse-odbc-bridge)
endif()
# install always because depian package want this files:
add_custom_target (clickhouse-clang ALL COMMAND ${CMAKE_COMMAND} -E create_symlink clickhouse clickhouse-clang DEPENDS clickhouse)

View File

@ -1,5 +1,5 @@
set(CLICKHOUSE_COPIER_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/ClusterCopier.cpp)
set(CLICKHOUSE_COPIER_LINK PRIVATE clickhouse_functions clickhouse_table_functions clickhouse_aggregate_functions PUBLIC daemon)
set(CLICKHOUSE_COPIER_LINK PRIVATE clickhouse_functions clickhouse_table_functions clickhouse_aggregate_functions clickhouse_dictionaries PUBLIC daemon)
set(CLICKHOUSE_COPIER_INCLUDE SYSTEM PRIVATE ${PCG_RANDOM_INCLUDE_DIR})
clickhouse_program_add(copier)

View File

@ -63,6 +63,7 @@
#include <AggregateFunctions/registerAggregateFunctions.h>
#include <Storages/registerStorages.h>
#include <Storages/StorageDistributed.h>
#include <Dictionaries/registerDictionaries.h>
#include <Databases/DatabaseMemory.h>
#include <Common/StatusFile.h>
@ -2169,6 +2170,7 @@ void ClusterCopierApp::mainImpl()
registerAggregateFunctions();
registerTableFunctions();
registerStorages();
registerDictionaries();
static const std::string default_database = "_local";
context->addDatabase(default_database, std::make_shared<DatabaseMemory>(default_database));

View File

@ -370,7 +370,7 @@ try
Poco::Logger * log = &Poco::Logger::get("PerformanceTestSuite");
if (options.count("help"))
{
std::cout << "Usage: " << argv[0] << " [options] [test_file ...] [tests_folder]\n";
std::cout << "Usage: " << argv[0] << " [options]\n";
std::cout << desc << "\n";
return 0;
}

View File

@ -1,4 +1,5 @@
#include "TestStats.h"
#include <algorithm>
namespace DB
{
@ -92,11 +93,10 @@ void TestStats::update_average_speed(
avg_speed_value /= number_of_info_batches;
if (avg_speed_first == 0)
{
avg_speed_first = avg_speed_value;
}
if (std::abs(avg_speed_value - avg_speed_first) >= precision)
auto [min, max] = std::minmax(avg_speed_value, avg_speed_first);
if (1 - min / max >= precision)
{
avg_speed_first = avg_speed_value;
avg_speed_watch.restart();

View File

@ -40,11 +40,11 @@ struct TestStats
double avg_rows_speed_value = 0;
double avg_rows_speed_first = 0;
static inline double avg_rows_speed_precision = 0.001;
static inline double avg_rows_speed_precision = 0.005;
double avg_bytes_speed_value = 0;
double avg_bytes_speed_first = 0;
static inline double avg_bytes_speed_precision = 0.001;
static inline double avg_bytes_speed_precision = 0.005;
size_t number_of_rows_speed_info_batches = 0;
size_t number_of_bytes_speed_info_batches = 0;

View File

@ -79,6 +79,7 @@ namespace ErrorCodes
extern const int SYSTEM_ERROR;
extern const int FAILED_TO_GETPWUID;
extern const int MISMATCHING_USERS_FOR_PROCESS_AND_DATA;
extern const int NETWORK_ERROR;
}
@ -587,12 +588,12 @@ int Server::main(const std::vector<std::string> & /*args*/)
return socket_address;
};
auto socket_bind_listen = [&](auto & socket, const std::string & host, UInt16 port, bool secure = 0)
auto socket_bind_listen = [&](auto & socket, const std::string & host, UInt16 port, [[maybe_unused]] bool secure = 0)
{
auto address = make_socket_address(host, port);
#if !defined(POCO_CLICKHOUSE_PATCH) || POCO_VERSION <= 0x02000000 // TODO: fill correct version
#if !defined(POCO_CLICKHOUSE_PATCH) || POCO_VERSION < 0x01090100
if (secure)
/// Bug in old poco, listen() after bind() with reusePort param will fail because have no implementation in SecureServerSocketImpl
/// Bug in old (<1.9.1) poco, listen() after bind() with reusePort param will fail because have no implementation in SecureServerSocketImpl
/// https://github.com/pocoproject/poco/pull/2257
socket.bind(address, /* reuseAddress = */ true);
else
@ -611,13 +612,15 @@ int Server::main(const std::vector<std::string> & /*args*/)
for (const auto & listen_host : listen_hosts)
{
/// For testing purposes, user may omit tcp_port or http_port or https_port in configuration file.
uint16_t listen_port = 0;
try
{
/// HTTP
if (config().has("http_port"))
{
Poco::Net::ServerSocket socket;
auto address = socket_bind_listen(socket, listen_host, config().getInt("http_port"));
listen_port = config().getInt("http_port");
auto address = socket_bind_listen(socket, listen_host, listen_port);
socket.setReceiveTimeout(settings.http_receive_timeout);
socket.setSendTimeout(settings.http_send_timeout);
servers.emplace_back(std::make_unique<Poco::Net::HTTPServer>(
@ -634,7 +637,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
{
#if USE_POCO_NETSSL
Poco::Net::SecureServerSocket socket;
auto address = socket_bind_listen(socket, listen_host, config().getInt("https_port"), /* secure = */ true);
listen_port = config().getInt("https_port");
auto address = socket_bind_listen(socket, listen_host, listen_port, /* secure = */ true);
socket.setReceiveTimeout(settings.http_receive_timeout);
socket.setSendTimeout(settings.http_send_timeout);
servers.emplace_back(std::make_unique<Poco::Net::HTTPServer>(
@ -654,7 +658,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
if (config().has("tcp_port"))
{
Poco::Net::ServerSocket socket;
auto address = socket_bind_listen(socket, listen_host, config().getInt("tcp_port"));
listen_port = config().getInt("tcp_port");
auto address = socket_bind_listen(socket, listen_host, listen_port);
socket.setReceiveTimeout(settings.receive_timeout);
socket.setSendTimeout(settings.send_timeout);
servers.emplace_back(std::make_unique<Poco::Net::TCPServer>(
@ -671,7 +676,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
{
#if USE_POCO_NETSSL
Poco::Net::SecureServerSocket socket;
auto address = socket_bind_listen(socket, listen_host, config().getInt("tcp_port_secure"), /* secure = */ true);
listen_port = config().getInt("tcp_port_secure");
auto address = socket_bind_listen(socket, listen_host, listen_port, /* secure = */ true);
socket.setReceiveTimeout(settings.receive_timeout);
socket.setSendTimeout(settings.send_timeout);
servers.emplace_back(std::make_unique<Poco::Net::TCPServer>(
@ -694,7 +700,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
if (config().has("interserver_http_port"))
{
Poco::Net::ServerSocket socket;
auto address = socket_bind_listen(socket, listen_host, config().getInt("interserver_http_port"));
listen_port = config().getInt("interserver_http_port");
auto address = socket_bind_listen(socket, listen_host, listen_port);
socket.setReceiveTimeout(settings.http_receive_timeout);
socket.setSendTimeout(settings.http_send_timeout);
servers.emplace_back(std::make_unique<Poco::Net::HTTPServer>(
@ -710,7 +717,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
{
#if USE_POCO_NETSSL
Poco::Net::SecureServerSocket socket;
auto address = socket_bind_listen(socket, listen_host, config().getInt("interserver_https_port"), /* secure = */ true);
listen_port = config().getInt("interserver_https_port");
auto address = socket_bind_listen(socket, listen_host, listen_port, /* secure = */ true);
socket.setReceiveTimeout(settings.http_receive_timeout);
socket.setSendTimeout(settings.http_send_timeout);
servers.emplace_back(std::make_unique<Poco::Net::HTTPServer>(
@ -726,16 +734,17 @@ int Server::main(const std::vector<std::string> & /*args*/)
#endif
}
}
catch (const Poco::Net::NetException & e)
catch (const Poco::Exception & e)
{
std::string message = "Listen [" + listen_host + "]:" + std::to_string(listen_port) + " failed: " + std::to_string(e.code()) + ": " + e.what() + ": " + e.message();
if (listen_try)
LOG_ERROR(log, "Listen [" << listen_host << "]: " << e.code() << ": " << e.what() << ": " << e.message()
LOG_ERROR(log, message
<< " If it is an IPv6 or IPv4 address and your host has disabled IPv6 or IPv4, then consider to "
"specify not disabled IPv4 or IPv6 address to listen in <listen_host> element of configuration "
"file. Example for disabled IPv6: <listen_host>0.0.0.0</listen_host> ."
" Example for disabled IPv4: <listen_host>::</listen_host>");
else
throw;
throw Exception{message, ErrorCodes::NETWORK_ERROR};
}
}

View File

@ -0,0 +1,474 @@
#include "AggregateFunctionMLMethod.h"
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/castColumn.h>
#include <Common/FieldVisitors.h>
#include <Common/typeid_cast.h>
#include "AggregateFunctionFactory.h"
#include "FactoryHelpers.h"
#include "Helpers.h"
namespace DB
{
namespace
{
using FuncLinearRegression = AggregateFunctionMLMethod<LinearModelData, NameLinearRegression>;
using FuncLogisticRegression = AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression>;
template <class Method>
AggregateFunctionPtr
createAggregateFunctionMLMethod(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
if (parameters.size() > 4)
throw Exception(
"Aggregate function " + name
+ " requires at most four parameters: learning_rate, l2_regularization_coef, mini-batch size and weights_updater "
"method",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (argument_types.size() < 2)
throw Exception(
"Aggregate function " + name + " requires at least two arguments: target and model's parameters",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (size_t i = 0; i < argument_types.size(); ++i)
{
if (!isNumber(argument_types[i]))
throw Exception(
"Argument " + std::to_string(i) + " of type " + argument_types[i]->getName()
+ " must be numeric for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
/// Such default parameters were picked because they did good on some tests,
/// though it still requires to fit parameters to achieve better result
auto learning_rate = Float64(0.01);
auto l2_reg_coef = Float64(0.01);
UInt32 batch_size = 1;
std::shared_ptr<IWeightsUpdater> weights_updater = std::make_shared<StochasticGradientDescent>();
std::shared_ptr<IGradientComputer> gradient_computer;
if (!parameters.empty())
{
learning_rate = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]);
}
if (parameters.size() > 1)
{
l2_reg_coef = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[1]);
}
if (parameters.size() > 2)
{
batch_size = applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]);
}
if (parameters.size() > 3)
{
if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'SGD\'")
{
weights_updater = std::make_shared<StochasticGradientDescent>();
}
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Momentum\'")
{
weights_updater = std::make_shared<Momentum>();
}
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Nesterov\'")
{
weights_updater = std::make_shared<Nesterov>();
}
else
{
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
}
if (std::is_same<Method, FuncLinearRegression>::value)
{
gradient_computer = std::make_shared<LinearRegression>();
}
else if (std::is_same<Method, FuncLogisticRegression>::value)
{
gradient_computer = std::make_shared<LogisticRegression>();
}
else
{
throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
return std::make_shared<Method>(
argument_types.size() - 1,
gradient_computer,
weights_updater,
learning_rate,
l2_reg_coef,
batch_size,
argument_types,
parameters);
}
}
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
{
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
}
LinearModelData::LinearModelData(
Float64 learning_rate,
Float64 l2_reg_coef,
UInt32 param_num,
UInt32 batch_capacity,
std::shared_ptr<DB::IGradientComputer> gradient_computer,
std::shared_ptr<DB::IWeightsUpdater> weights_updater)
: learning_rate(learning_rate)
, l2_reg_coef(l2_reg_coef)
, batch_capacity(batch_capacity)
, batch_size(0)
, gradient_computer(std::move(gradient_computer))
, weights_updater(std::move(weights_updater))
{
weights.resize(param_num, Float64{0.0});
gradient_batch.resize(param_num + 1, Float64{0.0});
}
void LinearModelData::update_state()
{
if (batch_size == 0)
return;
weights_updater->update(batch_size, weights, bias, gradient_batch);
batch_size = 0;
++iter_num;
gradient_batch.assign(gradient_batch.size(), Float64{0.0});
}
void LinearModelData::predict(
ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const
{
gradient_computer->predict(container, block, arguments, weights, bias, context);
}
void LinearModelData::read(ReadBuffer & buf)
{
readBinary(bias, buf);
readBinary(weights, buf);
readBinary(iter_num, buf);
readBinary(gradient_batch, buf);
readBinary(batch_size, buf);
weights_updater->read(buf);
}
void LinearModelData::write(WriteBuffer & buf) const
{
writeBinary(bias, buf);
writeBinary(weights, buf);
writeBinary(iter_num, buf);
writeBinary(gradient_batch, buf);
writeBinary(batch_size, buf);
weights_updater->write(buf);
}
void LinearModelData::merge(const DB::LinearModelData & rhs)
{
if (iter_num == 0 && rhs.iter_num == 0)
return;
update_state();
/// can't update rhs state because it's constant
Float64 frac = (static_cast<Float64>(iter_num) * iter_num) / (iter_num * iter_num + rhs.iter_num * rhs.iter_num);
for (size_t i = 0; i < weights.size(); ++i)
{
weights[i] = weights[i] * frac + rhs.weights[i] * (1 - frac);
}
bias = bias * frac + rhs.bias * (1 - frac);
iter_num += rhs.iter_num;
weights_updater->merge(*rhs.weights_updater, frac, 1 - frac);
}
void LinearModelData::add(const IColumn ** columns, size_t row_num)
{
/// first column stores target; features start from (columns + 1)
const auto target = (*columns[0])[row_num].get<Float64>();
/// Here we have columns + 1 as first column corresponds to target value, and others - to features
weights_updater->add_to_batch(
gradient_batch, *gradient_computer, weights, bias, learning_rate, l2_reg_coef, target, columns + 1, row_num);
++batch_size;
if (batch_size == batch_capacity)
{
update_state();
}
}
void Nesterov::read(ReadBuffer & buf)
{
readBinary(accumulated_gradient, buf);
}
void Nesterov::write(WriteBuffer & buf) const
{
writeBinary(accumulated_gradient, buf);
}
void Nesterov::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
{
auto & nesterov_rhs = static_cast<const Nesterov &>(rhs);
for (size_t i = 0; i < accumulated_gradient.size(); ++i)
{
accumulated_gradient[i] = accumulated_gradient[i] * frac + nesterov_rhs.accumulated_gradient[i] * rhs_frac;
}
}
void Nesterov::update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient)
{
if (accumulated_gradient.empty())
{
accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
}
for (size_t i = 0; i < batch_gradient.size(); ++i)
{
accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + batch_gradient[i] / batch_size;
}
for (size_t i = 0; i < weights.size(); ++i)
{
weights[i] += accumulated_gradient[i];
}
bias += accumulated_gradient[weights.size()];
}
void Nesterov::add_to_batch(
std::vector<Float64> & batch_gradient,
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num)
{
if (accumulated_gradient.empty())
{
accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
}
std::vector<Float64> shifted_weights(weights.size());
for (size_t i = 0; i != shifted_weights.size(); ++i)
{
shifted_weights[i] = weights[i] + accumulated_gradient[i] * alpha_;
}
auto shifted_bias = bias + accumulated_gradient[weights.size()] * alpha_;
gradient_computer.compute(batch_gradient, shifted_weights, shifted_bias, learning_rate, l2_reg_coef, target, columns, row_num);
}
void Momentum::read(ReadBuffer & buf)
{
readBinary(accumulated_gradient, buf);
}
void Momentum::write(WriteBuffer & buf) const
{
writeBinary(accumulated_gradient, buf);
}
void Momentum::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
{
auto & momentum_rhs = static_cast<const Momentum &>(rhs);
for (size_t i = 0; i < accumulated_gradient.size(); ++i)
{
accumulated_gradient[i] = accumulated_gradient[i] * frac + momentum_rhs.accumulated_gradient[i] * rhs_frac;
}
}
void Momentum::update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient)
{
/// batch_size is already checked to be greater than 0
if (accumulated_gradient.empty())
{
accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
}
for (size_t i = 0; i < batch_gradient.size(); ++i)
{
accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + batch_gradient[i] / batch_size;
}
for (size_t i = 0; i < weights.size(); ++i)
{
weights[i] += accumulated_gradient[i];
}
bias += accumulated_gradient[weights.size()];
}
void StochasticGradientDescent::update(
UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient)
{
/// batch_size is already checked to be greater than 0
for (size_t i = 0; i < weights.size(); ++i)
{
weights[i] += batch_gradient[i] / batch_size;
}
bias += batch_gradient[weights.size()] / batch_size;
}
void IWeightsUpdater::add_to_batch(
std::vector<Float64> & batch_gradient,
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num)
{
gradient_computer.compute(batch_gradient, weights, bias, learning_rate, l2_reg_coef, target, columns, row_num);
}
void LogisticRegression::predict(
ColumnVector<Float64>::Container & container,
Block & block,
const ColumnNumbers & arguments,
const std::vector<Float64> & weights,
Float64 bias,
const Context & context) const
{
size_t rows_num = block.rows();
std::vector<Float64> results(rows_num, bias);
for (size_t i = 1; i < arguments.size(); ++i)
{
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
if (!isNumber(cur_col.type))
{
throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
}
/// If column type is already Float64 then castColumn simply returns it
auto features_col_ptr = castColumn(cur_col, std::make_shared<DataTypeFloat64>(), context);
auto features_column = typeid_cast<const ColumnFloat64 *>(features_col_ptr.get());
if (!features_column)
{
throw Exception("Unexpectedly cannot dynamically cast features column " + std::to_string(i), ErrorCodes::LOGICAL_ERROR);
}
for (size_t row_num = 0; row_num != rows_num; ++row_num)
{
results[row_num] += weights[i - 1] * features_column->getElement(row_num);
}
}
container.reserve(rows_num);
for (size_t row_num = 0; row_num != rows_num; ++row_num)
{
container.emplace_back(1 / (1 + exp(-results[row_num])));
}
}
void LogisticRegression::compute(
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num)
{
Float64 derivative = bias;
for (size_t i = 0; i < weights.size(); ++i)
{
auto value = (*columns[i])[row_num].get<Float64>();
derivative += weights[i] * value;
}
derivative *= target;
derivative = exp(derivative);
batch_gradient[weights.size()] += learning_rate * target / (derivative + 1);
for (size_t i = 0; i < weights.size(); ++i)
{
auto value = (*columns[i])[row_num].get<Float64>();
batch_gradient[i] += learning_rate * target * value / (derivative + 1) - 2 * l2_reg_coef * weights[i];
}
}
void LinearRegression::predict(
ColumnVector<Float64>::Container & container,
Block & block,
const ColumnNumbers & arguments,
const std::vector<Float64> & weights,
Float64 bias,
const Context & context) const
{
if (weights.size() + 1 != arguments.size())
{
throw Exception("In predict function number of arguments differs from the size of weights vector", ErrorCodes::LOGICAL_ERROR);
}
size_t rows_num = block.rows();
std::vector<Float64> results(rows_num, bias);
for (size_t i = 1; i < arguments.size(); ++i)
{
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
if (!isNumber(cur_col.type))
{
throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
}
/// If column type is already Float64 then castColumn simply returns it
auto features_col_ptr = castColumn(cur_col, std::make_shared<DataTypeFloat64>(), context);
auto features_column = typeid_cast<const ColumnFloat64 *>(features_col_ptr.get());
if (!features_column)
{
throw Exception("Unexpectedly cannot dynamically cast features column " + std::to_string(i), ErrorCodes::LOGICAL_ERROR);
}
for (size_t row_num = 0; row_num != rows_num; ++row_num)
{
results[row_num] += weights[i - 1] * features_column->getElement(row_num);
}
}
container.reserve(rows_num);
for (size_t row_num = 0; row_num != rows_num; ++row_num)
{
container.emplace_back(results[row_num]);
}
}
void LinearRegression::compute(
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num)
{
Float64 derivative = (target - bias);
for (size_t i = 0; i < weights.size(); ++i)
{
auto value = (*columns[i])[row_num].get<Float64>();
derivative -= weights[i] * value;
}
derivative *= (2 * learning_rate);
batch_gradient[weights.size()] += derivative;
for (size_t i = 0; i < weights.size(); ++i)
{
auto value = (*columns[i])[row_num].get<Float64>();
batch_gradient[i] += derivative * value - 2 * l2_reg_coef * weights[i];
}
}
}

View File

@ -0,0 +1,330 @@
#pragma once
#include <Columns/ColumnVector.h>
#include <Columns/ColumnsCommon.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesNumber.h>
#include "IAggregateFunction.h"
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int BAD_ARGUMENTS;
}
/**
GradientComputer class computes gradient according to its loss function
*/
class IGradientComputer
{
public:
IGradientComputer() {}
virtual ~IGradientComputer() = default;
/// Adds computed gradient in new point (weights, bias) to batch_gradient
virtual void compute(
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num)
= 0;
virtual void predict(
ColumnVector<Float64>::Container & container,
Block & block,
const ColumnNumbers & arguments,
const std::vector<Float64> & weights,
Float64 bias,
const Context & context) const = 0;
};
class LinearRegression : public IGradientComputer
{
public:
LinearRegression() {}
void compute(
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num) override;
void predict(
ColumnVector<Float64>::Container & container,
Block & block,
const ColumnNumbers & arguments,
const std::vector<Float64> & weights,
Float64 bias,
const Context & context) const override;
};
class LogisticRegression : public IGradientComputer
{
public:
LogisticRegression() {}
void compute(
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num) override;
void predict(
ColumnVector<Float64>::Container & container,
Block & block,
const ColumnNumbers & arguments,
const std::vector<Float64> & weights,
Float64 bias,
const Context & context) const override;
};
/**
* IWeightsUpdater class defines the way to update current weights
* and uses GradientComputer class on each iteration
*/
class IWeightsUpdater
{
public:
virtual ~IWeightsUpdater() = default;
/// Calls GradientComputer to update current mini-batch
virtual void add_to_batch(
std::vector<Float64> & batch_gradient,
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num);
/// Updates current weights according to the gradient from the last mini-batch
virtual void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
/// Used during the merge of two states
virtual void merge(const IWeightsUpdater &, Float64, Float64) {}
/// Used for serialization when necessary
virtual void write(WriteBuffer &) const {}
/// Used for serialization when necessary
virtual void read(ReadBuffer &) {}
};
class StochasticGradientDescent : public IWeightsUpdater
{
public:
void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override;
};
class Momentum : public IWeightsUpdater
{
public:
Momentum() {}
Momentum(Float64 alpha) : alpha_(alpha) {}
void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override;
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
void write(WriteBuffer & buf) const override;
void read(ReadBuffer & buf) override;
private:
Float64 alpha_{0.1};
std::vector<Float64> accumulated_gradient;
};
class Nesterov : public IWeightsUpdater
{
public:
Nesterov() {}
Nesterov(Float64 alpha) : alpha_(alpha) {}
void add_to_batch(
std::vector<Float64> & batch_gradient,
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num) override;
void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override;
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
void write(WriteBuffer & buf) const override;
void read(ReadBuffer & buf) override;
private:
Float64 alpha_{0.1};
std::vector<Float64> accumulated_gradient;
};
/**
* LinearModelData is a class which manages current state of learning
*/
class LinearModelData
{
public:
LinearModelData() {}
LinearModelData(
Float64 learning_rate,
Float64 l2_reg_coef,
UInt32 param_num,
UInt32 batch_capacity,
std::shared_ptr<IGradientComputer> gradient_computer,
std::shared_ptr<IWeightsUpdater> weights_updater);
void add(const IColumn ** columns, size_t row_num);
void merge(const LinearModelData & rhs);
void write(WriteBuffer & buf) const;
void read(ReadBuffer & buf);
void
predict(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const;
private:
std::vector<Float64> weights;
Float64 bias{0.0};
Float64 learning_rate;
Float64 l2_reg_coef;
UInt32 batch_capacity;
UInt32 iter_num = 0;
std::vector<Float64> gradient_batch;
UInt32 batch_size;
std::shared_ptr<IGradientComputer> gradient_computer;
std::shared_ptr<IWeightsUpdater> weights_updater;
/**
* The function is called when we want to flush current batch and update our weights
*/
void update_state();
};
template <
/// Implemented Machine Learning method
typename Data,
/// Name of the method
typename Name>
class AggregateFunctionMLMethod final : public IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>
{
public:
String getName() const override { return Name::name; }
explicit AggregateFunctionMLMethod(
UInt32 param_num,
std::shared_ptr<IGradientComputer> gradient_computer,
std::shared_ptr<IWeightsUpdater> weights_updater,
Float64 learning_rate,
Float64 l2_reg_coef,
UInt32 batch_size,
const DataTypes & arguments_types,
const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>(arguments_types, params)
, param_num(param_num)
, learning_rate(learning_rate)
, l2_reg_coef(l2_reg_coef)
, batch_size(batch_size)
, gradient_computer(std::move(gradient_computer))
, weights_updater(std::move(weights_updater))
{
}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
void create(AggregateDataPtr place) const override
{
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, weights_updater);
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
this->data(place).add(columns, row_num);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(place).write(buf); }
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override { this->data(place).read(buf); }
void predictValues(
ConstAggregateDataPtr place, IColumn & to, Block & block, const ColumnNumbers & arguments, const Context & context) const override
{
if (arguments.size() != param_num + 1)
throw Exception(
"Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size())
+ ". Required: " + std::to_string(param_num + 1),
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
auto & column = dynamic_cast<ColumnVector<Float64> &>(to);
this->data(place).predict(column.getData(), block, arguments, context);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
std::ignore = place;
std::ignore = to;
throw std::runtime_error("not implemented");
}
const char * getHeaderFilePath() const override { return __FILE__; }
private:
UInt32 param_num;
Float64 learning_rate;
Float64 l2_reg_coef;
UInt32 batch_size;
std::shared_ptr<IGradientComputer> gradient_computer;
std::shared_ptr<IWeightsUpdater> weights_updater;
};
struct NameLinearRegression
{
static constexpr auto name = "LinearRegression";
};
struct NameLogisticRegression
{
static constexpr auto name = "LogisticRegression";
};
}

View File

@ -43,8 +43,12 @@ template <typename Value, bool FloatReturn> using FuncQuantilesTDigestWeighted =
template <template <typename, bool> class Function>
static constexpr bool supportDecimal()
{
return std::is_same_v<Function<Float32, false>, FuncQuantileExact<Float32, false>> ||
std::is_same_v<Function<Float32, false>, FuncQuantilesExact<Float32, false>>;
return std::is_same_v<Function<Float32, false>, FuncQuantile<Float32, false>> ||
std::is_same_v<Function<Float32, false>, FuncQuantiles<Float32, false>> ||
std::is_same_v<Function<Float32, false>, FuncQuantileExact<Float32, false>> ||
std::is_same_v<Function<Float32, false>, FuncQuantilesExact<Float32, false>> ||
std::is_same_v<Function<Float32, false>, FuncQuantileExactWeighted<Float32, false>> ||
std::is_same_v<Function<Float32, false>, FuncQuantilesExactWeighted<Float32, false>>;
}
@ -66,9 +70,9 @@ AggregateFunctionPtr createAggregateFunctionQuantile(const std::string & name, c
if constexpr (supportDecimal<Function>())
{
if (which.idx == TypeIndex::Decimal32) return std::make_shared<Function<Decimal32, true>>(argument_type, params);
if (which.idx == TypeIndex::Decimal64) return std::make_shared<Function<Decimal64, true>>(argument_type, params);
if (which.idx == TypeIndex::Decimal128) return std::make_shared<Function<Decimal128, true>>(argument_type, params);
if (which.idx == TypeIndex::Decimal32) return std::make_shared<Function<Decimal32, false>>(argument_type, params);
if (which.idx == TypeIndex::Decimal64) return std::make_shared<Function<Decimal64, false>>(argument_type, params);
if (which.idx == TypeIndex::Decimal128) return std::make_shared<Function<Decimal128, false>>(argument_type, params);
}
throw Exception("Illegal type " + argument_type->getName() + " of argument for aggregate function " + name,

View File

@ -0,0 +1,30 @@
#include "AggregateFunctionTSGroupSum.h"
#include "AggregateFunctionFactory.h"
#include "FactoryHelpers.h"
#include "Helpers.h"
namespace DB
{
namespace
{
template <bool rate>
AggregateFunctionPtr createAggregateFunctionTSgroupSum(const std::string & name, const DataTypes & arguments, const Array & params)
{
assertNoParameters(name, params);
if (arguments.size() < 3)
throw Exception("Not enough event arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return std::make_shared<AggregateFunctionTSgroupSum<rate>>(arguments);
}
}
void registerAggregateFunctionTSgroupSum(AggregateFunctionFactory & factory)
{
factory.registerFunction("TSgroupSum", createAggregateFunctionTSgroupSum<false>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("TSgroupRateSum", createAggregateFunctionTSgroupSum<true>, AggregateFunctionFactory::CaseInsensitive);
}
}

View File

@ -0,0 +1,287 @@
#pragma once
#include <bitset>
#include <iostream>
#include <map>
#include <queue>
#include <sstream>
#include <unordered_set>
#include <utility>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Common/ArenaAllocator.h>
#include <ext/range.h>
#include "IAggregateFunction.h"
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION;
}
template <bool rate>
struct AggregateFunctionTSgroupSumData
{
using DataPoint = std::pair<Int64, Float64>;
struct Points
{
using Dps = std::queue<DataPoint>;
Dps dps;
void add(Int64 t, Float64 v)
{
dps.push(std::make_pair(t, v));
if (dps.size() > 2)
dps.pop();
}
Float64 getval(Int64 t)
{
Int64 t1, t2;
Float64 v1, v2;
if (rate)
{
if (dps.size() < 2)
return 0;
t1 = dps.back().first;
t2 = dps.front().first;
v1 = dps.back().second;
v2 = dps.front().second;
return (v1 - v2) / Float64(t1 - t2);
}
else
{
if (dps.size() == 1 && t == dps.front().first)
return dps.front().second;
t1 = dps.back().first;
t2 = dps.front().first;
v1 = dps.back().second;
v2 = dps.front().second;
return v2 + ((v1 - v2) * Float64(t - t2)) / Float64(t1 - t2);
}
}
};
static constexpr size_t bytes_on_stack = 128;
typedef std::map<UInt64, Points> Series;
typedef PODArray<DataPoint, bytes_on_stack, AllocatorWithStackMemory<Allocator<false>, bytes_on_stack>> AggSeries;
Series ss;
AggSeries result;
void add(UInt64 uid, Int64 t, Float64 v)
{ //suppose t is coming asc
typename Series::iterator it_ss;
if (ss.count(uid) == 0)
{ //time series not exist, insert new one
Points tmp;
tmp.add(t, v);
ss.emplace(uid, tmp);
it_ss = ss.find(uid);
}
else
{
it_ss = ss.find(uid);
it_ss->second.add(t, v);
}
if (result.size() > 0 && t < result.back().first)
throw Exception{"TSgroupSum or TSgroupRateSum must order by timestamp asc!!!", ErrorCodes::LOGICAL_ERROR};
if (result.size() > 0 && t == result.back().first)
{
//do not add new point
if (rate)
result.back().second += it_ss->second.getval(t);
else
result.back().second += v;
}
else
{
if (rate)
result.emplace_back(std::make_pair(t, it_ss->second.getval(t)));
else
result.emplace_back(std::make_pair(t, v));
}
size_t i = result.size() - 1;
//reverse find out the index of timestamp that more than previous timestamp of t
while (result[i].first > it_ss->second.dps.front().first && i >= 0)
i--;
i++;
while (i < result.size() - 1)
{
result[i].second += it_ss->second.getval(result[i].first);
i++;
}
}
void merge(const AggregateFunctionTSgroupSumData & other)
{
//if ts has overlap, then aggregate two series by interpolation;
AggSeries tmp;
tmp.reserve(other.result.size() + result.size());
size_t i = 0, j = 0;
Int64 t1, t2;
Float64 v1, v2;
while (i < result.size() && j < other.result.size())
{
if (result[i].first < other.result[j].first)
{
if (j == 0)
{
tmp.emplace_back(result[i]);
}
else
{
t1 = other.result[j].first;
t2 = other.result[j - 1].first;
v1 = other.result[j].second;
v2 = other.result[j - 1].second;
Float64 value = result[i].second + v2 + (v1 - v2) * (Float64(result[i].first - t2)) / Float64(t1 - t2);
tmp.emplace_back(std::make_pair(result[i].first, value));
}
i++;
}
else if (result[i].first > other.result[j].first)
{
if (i == 0)
{
tmp.emplace_back(other.result[j]);
}
else
{
t1 = result[i].first;
t2 = result[i - 1].first;
v1 = result[i].second;
v2 = result[i - 1].second;
Float64 value = other.result[j].second + v2 + (v1 - v2) * (Float64(other.result[j].first - t2)) / Float64(t1 - t2);
tmp.emplace_back(std::make_pair(other.result[j].first, value));
}
j++;
}
else
{
tmp.emplace_back(std::make_pair(result[i].first, result[i].second + other.result[j].second));
i++;
j++;
}
}
while (i < result.size())
{
tmp.emplace_back(result[i]);
i++;
}
while (j < other.result.size())
{
tmp.push_back(other.result[j]);
j++;
}
swap(result, tmp);
}
void serialize(WriteBuffer & buf) const
{
size_t size = result.size();
writeVarUInt(size, buf);
buf.write(reinterpret_cast<const char *>(result.data()), sizeof(result[0]));
}
void deserialize(ReadBuffer & buf)
{
size_t size = 0;
readVarUInt(size, buf);
result.resize(size);
buf.read(reinterpret_cast<char *>(result.data()), size * sizeof(result[0]));
}
};
template <bool rate>
class AggregateFunctionTSgroupSum final
: public IAggregateFunctionDataHelper<AggregateFunctionTSgroupSumData<rate>, AggregateFunctionTSgroupSum<rate>>
{
private:
public:
String getName() const override { return rate ? "TSgroupRateSum" : "TSgroupSum"; }
AggregateFunctionTSgroupSum(const DataTypes & arguments)
: IAggregateFunctionDataHelper<AggregateFunctionTSgroupSumData<rate>, AggregateFunctionTSgroupSum<rate>>(arguments, {})
{
if (!WhichDataType(arguments[0].get()).isUInt64())
throw Exception{"Illegal type " + arguments[0].get()->getName() + " of argument 1 of aggregate function " + getName()
+ ", must be UInt64",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!WhichDataType(arguments[1].get()).isInt64())
throw Exception{"Illegal type " + arguments[1].get()->getName() + " of argument 2 of aggregate function " + getName()
+ ", must be Int64",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!WhichDataType(arguments[2].get()).isFloat64())
throw Exception{"Illegal type " + arguments[2].get()->getName() + " of argument 3 of aggregate function " + getName()
+ ", must be Float64",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
DataTypePtr getReturnType() const override
{
auto datatypes = std::vector<DataTypePtr>();
datatypes.push_back(std::make_shared<DataTypeInt64>());
datatypes.push_back(std::make_shared<DataTypeFloat64>());
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeTuple>(datatypes));
}
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
{
auto uid = static_cast<const ColumnVector<UInt64> *>(columns[0])->getData()[row_num];
auto ts = static_cast<const ColumnVector<Int64> *>(columns[1])->getData()[row_num];
auto val = static_cast<const ColumnVector<Float64> *>(columns[2])->getData()[row_num];
if (uid && ts && val)
{
this->data(place).add(uid, ts, val);
}
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(place).serialize(buf); }
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override { this->data(place).deserialize(buf); }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const auto & value = this->data(place).result;
size_t size = value.size();
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
size_t old_size = offsets_to.back();
offsets_to.push_back(offsets_to.back() + size);
if (size)
{
typename ColumnInt64::Container & ts_to
= static_cast<ColumnInt64 &>(static_cast<ColumnTuple &>(arr_to.getData()).getColumn(0)).getData();
typename ColumnFloat64::Container & val_to
= static_cast<ColumnFloat64 &>(static_cast<ColumnTuple &>(arr_to.getData()).getColumn(1)).getData();
ts_to.reserve(old_size + size);
val_to.reserve(old_size + size);
size_t i = 0;
while (i < this->data(place).result.size())
{
ts_to.push_back(this->data(place).result[i].first);
val_to.push_back(this->data(place).result[i].second);
i++;
}
}
}
bool allocatesMemoryInArena() const override { return true; }
const char * getHeaderFilePath() const override { return __FILE__; }
};
}

View File

@ -7,6 +7,8 @@
#include <Core/Types.h>
#include <Core/Field.h>
#include <Core/ColumnNumbers.h>
#include <Core/Block.h>
#include <Common/Exception.h>
@ -92,6 +94,13 @@ public:
/// Inserts results into a column.
virtual void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const = 0;
/// This function is used for machine learning methods
virtual void predictValues(ConstAggregateDataPtr /* place */, IColumn & /*to*/,
Block & /*block*/, const ColumnNumbers & /*arguments*/, const Context & /*context*/) const
{
throw Exception("Method predictValues is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
/** Returns true for aggregate functions of type -State.
* They are executed as other aggregate functions, but not finalized (return an aggregation state that can be combined with another).
*/
@ -149,7 +158,6 @@ protected:
static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data*>(place); }
public:
IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_)
: IAggregateFunctionHelper<Derived>(argument_types_, parameters_) {}

View File

@ -20,12 +20,22 @@ namespace ErrorCodes
template <typename Value>
struct QuantileExactWeighted
{
struct Int128Hash
{
size_t operator()(Int128 x) const
{
return CityHash_v1_0_2::Hash128to64({x >> 64, x & 0xffffffffffffffffll});
}
};
using Weight = UInt64;
using UnderlyingType = typename NativeType<Value>::Type;
using Hasher = std::conditional_t<std::is_same_v<Value, Decimal128>, Int128Hash, HashCRC32<UnderlyingType>>;
/// When creating, the hash table must be small.
using Map = HashMap<
Value, Weight,
HashCRC32<Value>,
UnderlyingType, Weight,
Hasher,
HashTableGrower<4>,
HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>
>;
@ -39,7 +49,7 @@ struct QuantileExactWeighted
++map[x];
}
void add(const Value & x, const Weight & weight)
void add(const Value & x, Weight weight)
{
if (!isNaN(x))
map[x] += weight;

View File

@ -28,6 +28,7 @@ void registerAggregateFunctionTopK(AggregateFunctionFactory &);
void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
void registerAggregateFunctionsBitmap(AggregateFunctionFactory &);
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory &);
@ -40,7 +41,7 @@ void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory
void registerAggregateFunctionHistogram(AggregateFunctionFactory & factory);
void registerAggregateFunctionRetention(AggregateFunctionFactory & factory);
void registerAggregateFunctionTSgroupSum(AggregateFunctionFactory & factory);
void registerAggregateFunctions()
{
{
@ -69,6 +70,8 @@ void registerAggregateFunctions()
registerAggregateFunctionsMaxIntersections(factory);
registerAggregateFunctionHistogram(factory);
registerAggregateFunctionRetention(factory);
registerAggregateFunctionTSgroupSum(factory);
registerAggregateFunctionMLMethod(factory);
registerAggregateFunctionEntropy(factory);
registerAggregateFunctionLeastSqr(factory);
}

View File

@ -12,8 +12,7 @@
#include <Core/Protocol.h>
#include <Core/QueryProcessingStage.h>
#include <DataStreams/IBlockInputStream.h>
#include <DataStreams/IBlockOutputStream.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <DataStreams/BlockStreamProfileInfo.h>
#include <IO/ConnectionTimeouts.h>

View File

@ -10,6 +10,7 @@
#include <Common/typeid_cast.h>
#include <Common/Arena.h>
#include <AggregateFunctions/AggregateFunctionMLMethod.h>
namespace DB
{
@ -18,6 +19,7 @@ namespace ErrorCodes
{
extern const int PARAMETER_OUT_OF_BOUND;
extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
@ -33,6 +35,25 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
arenas.push_back(arena_);
}
/// This function is used in convertToValues() and predictValues()
/// and is written here to avoid repetitions
bool ColumnAggregateFunction::tryFinalizeAggregateFunction(MutableColumnPtr *res_) const
{
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
{
auto res = createView();
res->set(function_state->getNestedFunction());
res->data.assign(data.begin(), data.end());
*res_ = std::move(res);
return true;
}
MutableColumnPtr res = func->getReturnType()->createColumn();
res->reserve(data.size());
*res_ = std::move(res);
return false;
}
MutableColumnPtr ColumnAggregateFunction::convertToValues() const
{
/** If the aggregate function returns an unfinalized/unfinished state,
@ -65,23 +86,46 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
* AggregateFunction(quantileTiming(0.5), UInt64)
* into UInt16 - already finished result of `quantileTiming`.
*/
if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
/** Convertion function is used in convertToValues and predictValues
* in the similar part of both functions
*/
MutableColumnPtr res;
if (tryFinalizeAggregateFunction(&res))
{
auto res = createView();
res->set(function_state->getNestedFunction());
res->data.assign(data.begin(), data.end());
return res;
}
MutableColumnPtr res = func->getReturnType()->createColumn();
res->reserve(data.size());
for (auto val : data)
func->insertResultInto(val, *res);
return res;
}
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const
{
MutableColumnPtr res;
tryFinalizeAggregateFunction(&res);
auto ML_function = func.get();
if (ML_function)
{
size_t row_num = 0;
for (auto val : data)
{
ML_function->predictValues(val, *res, block, arguments, context);
++row_num;
}
}
else
{
throw Exception("Illegal aggregate function is passed",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
return res;
}
void ColumnAggregateFunction::ensureOwnership()
{

View File

@ -10,6 +10,7 @@
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <Functions/FunctionHelpers.h>
namespace DB
{
@ -117,6 +118,9 @@ public:
std::string getName() const override { return "AggregateFunction(" + func->getName() + ")"; }
const char * getFamilyName() const override { return "AggregateFunction"; }
bool tryFinalizeAggregateFunction(MutableColumnPtr* res_) const;
MutableColumnPtr predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const;
size_t size() const override
{
return getData().size();

View File

@ -570,6 +570,8 @@ ConfigProcessor::LoadedConfig ConfigProcessor::loadConfigWithZooKeeperIncludes(
}
void ConfigProcessor::savePreprocessedConfig(const LoadedConfig & loaded_config, std::string preprocessed_dir)
{
try
{
if (preprocessed_path.empty())
{
@ -604,8 +606,6 @@ void ConfigProcessor::savePreprocessedConfig(const LoadedConfig & loaded_config,
if (!preprocessed_path_parent.toString().empty())
Poco::File(preprocessed_path_parent).createDirectories();
}
try
{
DOMWriter().writeNode(preprocessed_path, loaded_config.preprocessed_xml);
}
catch (Poco::Exception & e)

View File

@ -1,6 +1,7 @@
#pragma once
#include <Common/UTF8Helpers.h>
#include <Core/Defines.h>
#include <ext/range.h>
#include <Poco/UTF8Encoding.h>
#include <Poco/Unicode.h>

View File

@ -1430,6 +1430,8 @@ void ZooKeeper::pushRequest(RequestInfo && info)
if (!info.request->xid)
{
info.request->xid = next_xid.fetch_add(1);
if (info.request->xid == close_xid)
throw Exception("xid equal to close_xid", ZSESSIONEXPIRED);
if (info.request->xid < 0)
throw Exception("XID overflow", ZSESSIONEXPIRED);
}

View File

@ -25,6 +25,8 @@
#cmakedefine01 USE_BROTLI
#cmakedefine01 USE_SSL
#cmakedefine01 USE_HYPERSCAN
#cmakedefine01 USE_SIMDJSON
#cmakedefine01 USE_RAPIDJSON
#cmakedefine01 USE_LFALLOC
#cmakedefine01 USE_LFALLOC_RANDOM_HINT

View File

@ -1,14 +1,18 @@
include(${ClickHouse_SOURCE_DIR}/cmake/dbms_glob_sources.cmake)
add_headers_and_sources(clickhouse_compression .)
add_library(clickhouse_compression ${clickhouse_compression_headers} ${clickhouse_compression_sources})
target_link_libraries(clickhouse_compression PRIVATE clickhouse_parsers clickhouse_common_io ${ZSTD_LIBRARY} ${LZ4_LIBRARY} ${CITYHASH_LIBRARIES})
target_link_libraries(clickhouse_compression PRIVATE clickhouse_parsers clickhouse_common_io ${LZ4_LIBRARY} ${CITYHASH_LIBRARIES})
if(ZSTD_LIBRARY)
target_link_libraries(clickhouse_compression PRIVATE ${ZSTD_LIBRARY})
endif()
target_include_directories(clickhouse_compression PUBLIC ${DBMS_INCLUDE_DIR})
target_include_directories(clickhouse_compression SYSTEM PUBLIC ${PCG_RANDOM_INCLUDE_DIR})
if (NOT USE_INTERNAL_LZ4_LIBRARY)
target_include_directories(clickhouse_compression SYSTEM BEFORE PRIVATE ${LZ4_INCLUDE_DIR})
endif ()
if (NOT USE_INTERNAL_ZSTD_LIBRARY)
if (NOT USE_INTERNAL_ZSTD_LIBRARY AND ZSTD_INCLUDE_DIR)
target_include_directories(clickhouse_compression SYSTEM BEFORE PRIVATE ${ZSTD_INCLUDE_DIR})
endif ()

View File

@ -430,6 +430,13 @@ inline bool_if_safe_conversion<A, B> greaterOrEqualsOp(A a, B b)
template <typename From, typename To>
inline bool NO_SANITIZE_UNDEFINED convertNumeric(From value, To & result)
{
/// If the type is actually the same it's not necessary to do any checks.
if constexpr (std::is_same_v<From, To>)
{
result = value;
return true;
}
/// Note that NaNs doesn't compare equal to anything, but they are still in range of any Float type.
if (isNaN(value) && std::is_floating_point_v<To>)
{

View File

@ -123,3 +123,7 @@
#else
#define OPTIMIZE(x)
#endif
/// This number is only used for distributed version compatible.
/// It could be any magic number.
#define DBMS_DISTRIBUTED_SENDS_MAGIC_NUMBER 0xCAFECABE

View File

@ -109,5 +109,4 @@ void Settings::addProgramOptions(boost::program_options::options_description & o
Settings::getDescription(index).data)));
}
}
}

View File

@ -50,6 +50,7 @@ struct Settings : public SettingsCollection<Settings>
M(SettingUInt64, min_insert_block_size_rows, DEFAULT_INSERT_BLOCK_SIZE, "Squash blocks passed to INSERT query to specified size in rows, if blocks are not big enough.") \
M(SettingUInt64, min_insert_block_size_bytes, (DEFAULT_INSERT_BLOCK_SIZE * 256), "Squash blocks passed to INSERT query to specified size in bytes, if blocks are not big enough.") \
M(SettingMaxThreads, max_threads, 0, "The maximum number of threads to execute the request. By default, it is determined automatically.") \
M(SettingMaxThreads, max_alter_threads, 0, "The maximum number of threads to execute the ALTER requests. By default, it is determined automatically.") \
M(SettingUInt64, max_read_buffer_size, DBMS_DEFAULT_BUFFER_SIZE, "The maximum size of the buffer to read from the filesystem.") \
M(SettingUInt64, max_distributed_connections, 1024, "The maximum number of connections for distributed processing of one query (should be greater than max_threads).") \
M(SettingUInt64, max_query_size, 262144, "Which part of the query can be read into RAM for parsing (the remaining data for INSERT, if any, is read later)") \
@ -205,6 +206,7 @@ struct Settings : public SettingsCollection<Settings>
M(SettingUInt64, insert_distributed_timeout, 0, "Timeout for insert query into distributed. Setting is used only with insert_distributed_sync enabled. Zero value means no timeout.") \
M(SettingInt64, distributed_ddl_task_timeout, 180, "Timeout for DDL query responses from all hosts in cluster. Negative value means infinite.") \
M(SettingMilliseconds, stream_flush_interval_ms, 7500, "Timeout for flushing data from streaming storages.") \
M(SettingMilliseconds, stream_poll_timeout_ms, 500, "Timeout for polling data from streaming storages.") \
M(SettingString, format_schema, "", "Schema identifier (used by schema-based formats)") \
M(SettingBool, insert_allow_materialized_columns, 0, "If setting is enabled, Allow materialized columns in INSERT.") \
M(SettingSeconds, http_connection_timeout, DEFAULT_HTTP_READ_BUFFER_CONNECTION_TIMEOUT, "HTTP connection timeout.") \
@ -322,8 +324,9 @@ struct Settings : public SettingsCollection<Settings>
M(SettingBool, allow_experimental_data_skipping_indices, false, "If it is set to true, data skipping indices can be used in CREATE TABLE/ALTER TABLE queries.") \
\
M(SettingBool, allow_hyperscan, true, "Allow functions that use Hyperscan library. Disable to avoid potentially long compilation times and excessive resource usage.") \
M(SettingBool, allow_simdjson, 1, "Allow using simdjson library in 'JSON*' functions if AVX2 instructions are available. If disabled rapidjson will be used.") \
\
M(SettingUInt64, max_partitions_per_insert_block, 100, "Limit maximum number of partitions in single INSERTed block. Zero means unlimited. Throw exception if the block contains too many partitions. This setting is a safety threshold, because using large number of partitions is a common misconception.") \
M(SettingUInt64, max_partitions_per_insert_block, 100, "Limit maximum number of partitions in single INSERTed block. Zero means unlimited. Throw exception if the block contains too many partitions. This setting is a safety threshold, because using large number of partitions is a common misconception.")
DECLARE_SETTINGS_COLLECTION(LIST_OF_SETTINGS)

View File

@ -428,7 +428,7 @@ public:
const const_reference & operator *() const { return ref; }
const const_reference * operator ->() const { return &ref; }
const_iterator & operator ++() { ++ref.member; return *this; }
const_iterator & operator ++(int) { const_iterator tmp = *this; ++*this; return tmp; }
const_iterator operator ++(int) { const_iterator tmp = *this; ++*this; return tmp; }
bool operator ==(const const_iterator & rhs) const { return ref.member == rhs.ref.member && ref.collection == rhs.ref.collection; }
bool operator !=(const const_iterator & rhs) const { return !(*this == rhs); }
protected:
@ -445,7 +445,7 @@ public:
reference & operator *() const { return this->ref; }
reference * operator ->() const { return &this->ref; }
iterator & operator ++() { const_iterator::operator ++(); return *this; }
iterator & operator ++(int) { iterator tmp = *this; ++*this; return tmp; }
iterator operator ++(int) { iterator tmp = *this; ++*this; return tmp; }
};
/// Returns the number of settings.

View File

@ -165,6 +165,11 @@ template <> constexpr bool IsDecimalNumber<Decimal32> = true;
template <> constexpr bool IsDecimalNumber<Decimal64> = true;
template <> constexpr bool IsDecimalNumber<Decimal128> = true;
template <typename T> struct NativeType { using Type = T; };
template <> struct NativeType<Decimal32> { using Type = Int32; };
template <> struct NativeType<Decimal64> { using Type = Int64; };
template <> struct NativeType<Decimal128> { using Type = Int128; };
}
/// Specialization of `std::hash` for the Decimal<T> types.

View File

@ -1,6 +1,8 @@
#include <DataStreams/AggregatingSortedBlockInputStream.h>
#include <Common/typeid_cast.h>
#include <Common/StringUtils/StringUtils.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeCustomSimpleAggregateFunction.h>
namespace DB
@ -22,7 +24,7 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream(
ColumnWithTypeAndName & column = header.safeGetByPosition(i);
/// We leave only states of aggregate functions.
if (!startsWith(column.type->getName(), "AggregateFunction"))
if (!dynamic_cast<const DataTypeAggregateFunction *>(column.type.get()) && !dynamic_cast<const DataTypeCustomSimpleAggregateFunction *>(column.type->getCustomName()))
{
column_numbers_not_to_aggregate.push_back(i);
continue;
@ -40,9 +42,19 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream(
continue;
}
if (auto simple_aggr = dynamic_cast<const DataTypeCustomSimpleAggregateFunction *>(column.type->getCustomName()))
{
// simple aggregate function
SimpleAggregateDescription desc{simple_aggr->getFunction(), i};
columns_to_simple_aggregate.emplace_back(std::move(desc));
}
else
{
// standard aggregate function
column_numbers_to_aggregate.push_back(i);
}
}
}
Block AggregatingSortedBlockInputStream::readImpl()
@ -91,7 +103,11 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
/// if there are enough rows accumulated and the last one is calculated completely
if (key_differs && merged_rows >= max_block_size)
{
/// Write the simple aggregation result for the previous group.
insertSimpleAggregationResult(merged_columns);
return;
}
queue.pop();
@ -110,6 +126,14 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
for (auto & column_to_aggregate : columns_to_aggregate)
column_to_aggregate->insertDefault();
/// Write the simple aggregation result for the previous group.
if (merged_rows > 0)
insertSimpleAggregationResult(merged_columns);
/// Reset simple aggregation states for next row
for (auto & desc : columns_to_simple_aggregate)
desc.createState();
++merged_rows;
}
@ -127,6 +151,9 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
}
}
/// Write the simple aggregation result for the previous group.
insertSimpleAggregationResult(merged_columns);
finished = true;
}
@ -138,6 +165,21 @@ void AggregatingSortedBlockInputStream::addRow(SortCursor & cursor)
size_t j = column_numbers_to_aggregate[i];
columns_to_aggregate[i]->insertMergeFrom(*cursor->all_columns[j], cursor->pos);
}
for (auto & desc : columns_to_simple_aggregate)
{
auto & col = cursor->all_columns[desc.column_number];
desc.add_function(desc.function.get(), desc.state.data(), &col, cursor->pos, nullptr);
}
}
void AggregatingSortedBlockInputStream::insertSimpleAggregationResult(MutableColumns & merged_columns)
{
for (auto & desc : columns_to_simple_aggregate)
{
desc.function->insertResultInto(desc.state.data(), *merged_columns[desc.column_number]);
desc.destroyState();
}
}
}

View File

@ -7,6 +7,7 @@
#include <DataStreams/MergingSortedBlockInputStream.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Common/AlignedBuffer.h>
namespace DB
@ -38,10 +39,13 @@ private:
/// Read finished.
bool finished = false;
struct SimpleAggregateDescription;
/// Columns with which numbers should be aggregated.
ColumnNumbers column_numbers_to_aggregate;
ColumnNumbers column_numbers_not_to_aggregate;
std::vector<ColumnAggregateFunction *> columns_to_aggregate;
std::vector<SimpleAggregateDescription> columns_to_simple_aggregate;
RowRef current_key; /// The current primary key.
RowRef next_key; /// The primary key of the next row.
@ -54,6 +58,53 @@ private:
/** Extract all states of aggregate functions and merge them with the current group.
*/
void addRow(SortCursor & cursor);
/** Insert all values of current row for simple aggregate functions
*/
void insertSimpleAggregationResult(MutableColumns & merged_columns);
/// Stores information for aggregation of SimpleAggregateFunction columns
struct SimpleAggregateDescription
{
/// An aggregate function 'anyLast', 'sum'...
AggregateFunctionPtr function;
IAggregateFunction::AddFunc add_function;
size_t column_number;
AlignedBuffer state;
bool created = false;
SimpleAggregateDescription(const AggregateFunctionPtr & function_, const size_t column_number_) : function(function_), column_number(column_number_)
{
add_function = function->getAddressOfAddFunction();
state.reset(function->sizeOfData(), function->alignOfData());
}
void createState()
{
if (created)
return;
function->create(state.data());
created = true;
}
void destroyState()
{
if (!created)
return;
function->destroy(state.data());
created = false;
}
/// Explicitly destroy aggregation state if the stream is terminated
~SimpleAggregateDescription()
{
destroyState();
}
SimpleAggregateDescription() = default;
SimpleAggregateDescription(SimpleAggregateDescription &&) = default;
SimpleAggregateDescription(const SimpleAggregateDescription &) = delete;
};
};
}

View File

@ -1,7 +1,8 @@
#pragma once
#include <DataStreams/IBlockInputStream.h>
#include <DataStreams/IBlockOutputStream.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <functional>
namespace DB

View File

@ -1,8 +1,10 @@
#pragma once
#include <vector>
#include <Common/Stopwatch.h>
#include <Core/Types.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <Common/Stopwatch.h>
#include <vector>
namespace DB
{
@ -10,7 +12,6 @@ namespace DB
class Block;
class ReadBuffer;
class WriteBuffer;
class IBlockInputStream;
/// Information for profiling. See IBlockInputStream.h
struct BlockStreamProfileInfo

View File

@ -9,6 +9,8 @@ ExpressionBlockInputStream::ExpressionBlockInputStream(const BlockInputStreamPtr
: expression(expression_)
{
children.push_back(input);
cached_header = children.back()->getHeader();
expression->execute(cached_header, true);
}
String ExpressionBlockInputStream::getName() const { return "Expression"; }
@ -23,9 +25,7 @@ Block ExpressionBlockInputStream::getTotals()
Block ExpressionBlockInputStream::getHeader() const
{
Block res = children.back()->getHeader();
expression->execute(res, true);
return res;
return cached_header.cloneEmpty();
}
Block ExpressionBlockInputStream::readImpl()

View File

@ -30,6 +30,7 @@ protected:
private:
ExpressionActionsPtr expression;
Block cached_header;
};
}

View File

@ -2,6 +2,7 @@
#include <Core/Block.h>
#include <Core/SortDescription.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <DataStreams/BlockStreamProfileInfo.h>
#include <DataStreams/SizeLimits.h>
#include <IO/Progress.h>
@ -21,14 +22,10 @@ namespace ErrorCodes
extern const int QUERY_WAS_CANCELLED;
}
class IBlockInputStream;
class ProcessListElement;
class QuotaForIntervals;
class QueryStatus;
using BlockInputStreamPtr = std::shared_ptr<IBlockInputStream>;
using BlockInputStreams = std::vector<BlockInputStreamPtr>;
/** Callback to track the progress of the query.
* Used in IBlockInputStream and Context.
* The function takes the number of rows in the last block, the number of bytes in the last block.
@ -269,6 +266,11 @@ protected:
children.push_back(child);
}
/** Check limits.
* But only those that can be checked within each separate stream.
*/
bool checkTimeLimit();
private:
bool enabled_extremes = false;
@ -296,10 +298,9 @@ private:
void updateExtremes(Block & block);
/** Check limits and quotas.
/** Check quotas.
* But only those that can be checked within each separate stream.
*/
bool checkTimeLimit();
void checkQuota(Block & block);
size_t checkDepthImpl(size_t max_depth, size_t level) const;

View File

@ -1,11 +1,14 @@
#pragma once
#include <Core/Block.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <Storages/TableStructureLockHolder.h>
#include <boost/noncopyable.hpp>
#include <memory>
#include <string>
#include <vector>
#include <memory>
#include <boost/noncopyable.hpp>
#include <Core/Block.h>
#include <Storages/TableStructureLockHolder.h>
namespace DB
@ -64,6 +67,4 @@ private:
std::vector<TableStructureReadLockHolder> table_locks;
};
using BlockOutputStreamPtr = std::shared_ptr<IBlockOutputStream>;
}

View File

@ -0,0 +1,16 @@
#pragma once
#include <memory>
#include <vector>
namespace DB
{
class IBlockInputStream;
class IBlockOutputStream;
using BlockInputStreamPtr = std::shared_ptr<IBlockInputStream>;
using BlockInputStreams = std::vector<BlockInputStreamPtr>;
using BlockOutputStreamPtr = std::shared_ptr<IBlockOutputStream>;
}

View File

@ -8,7 +8,7 @@
#include <common/logger_useful.h>
#include <DataStreams/IBlockInputStream.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <Common/setThreadName.h>
#include <Common/CurrentMetrics.h>
#include <Common/CurrentThread.h>

View File

@ -1,5 +1,7 @@
#pragma once
#include <DataStreams/IBlockStream_fwd.h>
#include <atomic>
#include <functional>
@ -7,8 +9,6 @@
namespace DB
{
class IBlockInputStream;
class IBlockOutputStream;
class Block;
/** Copies data from the InputStream into the OutputStream

View File

@ -1,6 +1,8 @@
#pragma once
#include <memory>
#include <cstddef>
#include <Core/Types.h>
namespace DB
{
@ -10,21 +12,21 @@ class WriteBuffer;
struct FormatSettings;
class IColumn;
/** Further refinment of the properties of data type.
*
* Contains methods for serialization/deserialization.
* Implementations of this interface represent a data type domain (example: IPv4)
* which is a refinement of the exsitgin type with a name and specific text
* representation.
*
* IDataTypeDomain is totally immutable object. You can always share them.
/** Allow to customize an existing data type and set a different name and/or text serialization/deserialization methods.
* See use in IPv4 and IPv6 data types, and also in SimpleAggregateFunction.
*/
class IDataTypeDomain
class IDataTypeCustomName
{
public:
virtual ~IDataTypeDomain() {}
virtual ~IDataTypeCustomName() {}
virtual const char* getName() const = 0;
virtual String getName() const = 0;
};
class IDataTypeCustomTextSerialization
{
public:
virtual ~IDataTypeCustomTextSerialization() {}
/** Text serialization for displaying on a terminal or saving into a text file, and the like.
* Without escaping or quoting.
@ -56,4 +58,31 @@ public:
virtual void serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const = 0;
};
using DataTypeCustomNamePtr = std::unique_ptr<const IDataTypeCustomName>;
using DataTypeCustomTextSerializationPtr = std::unique_ptr<const IDataTypeCustomTextSerialization>;
/** Describe a data type customization
*/
struct DataTypeCustomDesc
{
DataTypeCustomNamePtr name;
DataTypeCustomTextSerializationPtr text_serialization;
DataTypeCustomDesc(DataTypeCustomNamePtr name_, DataTypeCustomTextSerializationPtr text_serialization_)
: name(std::move(name_)), text_serialization(std::move(text_serialization_)) {}
};
using DataTypeCustomDescPtr = std::unique_ptr<DataTypeCustomDesc>;
/** A simple implementation of IDataTypeCustomName
*/
class DataTypeCustomFixedName : public IDataTypeCustomName
{
private:
String name;
public:
DataTypeCustomFixedName(String name_) : name(name_) {}
String getName() const override { return name; }
};
} // namespace DB

View File

@ -1,9 +1,9 @@
#include <Columns/ColumnsNumber.h>
#include <Common/Exception.h>
#include <Common/formatIPv6.h>
#include <DataTypes/DataTypeDomainWithSimpleSerialization.h>
#include <DataTypes/DataTypeCustomSimpleTextSerialization.h>
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/IDataTypeDomain.h>
#include <DataTypes/DataTypeCustom.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionsCoding.h>
@ -20,20 +20,15 @@ namespace ErrorCodes
namespace
{
class DataTypeDomainIPv4 : public DataTypeDomainWithSimpleSerialization
class DataTypeCustomIPv4Serialization : public DataTypeCustomSimpleTextSerialization
{
public:
const char * getName() const override
{
return "IPv4";
}
void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override
{
const auto col = checkAndGetColumn<ColumnUInt32>(&column);
if (!col)
{
throw Exception(String(getName()) + " domain can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception("IPv4 type can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
}
char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'};
@ -48,7 +43,7 @@ public:
ColumnUInt32 * col = typeid_cast<ColumnUInt32 *>(&column);
if (!col)
{
throw Exception(String(getName()) + " domain can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception("IPv4 type can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
}
char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'};
@ -63,20 +58,16 @@ public:
}
};
class DataTypeDomainIPv6 : public DataTypeDomainWithSimpleSerialization
class DataTypeCustomIPv6Serialization : public DataTypeCustomSimpleTextSerialization
{
public:
const char * getName() const override
{
return "IPv6";
}
void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override
{
const auto col = checkAndGetColumn<ColumnFixedString>(&column);
if (!col)
{
throw Exception(String(getName()) + " domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception("IPv6 type domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
}
char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'};
@ -91,7 +82,7 @@ public:
ColumnFixedString * col = typeid_cast<ColumnFixedString *>(&column);
if (!col)
{
throw Exception(String(getName()) + " domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception("IPv6 type domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
}
char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'};
@ -100,7 +91,7 @@ public:
std::string ipv6_value(IPV6_BINARY_LENGTH, '\0');
if (!parseIPv6(buffer, reinterpret_cast<unsigned char *>(ipv6_value.data())))
{
throw Exception(String("Invalid ") + getName() + " value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING);
throw Exception("Invalid IPv6 value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING);
}
col->insertString(ipv6_value);
@ -111,8 +102,17 @@ public:
void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory)
{
factory.registerDataTypeDomain("UInt32", std::make_unique<DataTypeDomainIPv4>());
factory.registerDataTypeDomain("FixedString(16)", std::make_unique<DataTypeDomainIPv6>());
factory.registerSimpleDataTypeCustom("IPv4", []
{
return std::make_pair(DataTypeFactory::instance().get("UInt32"),
std::make_unique<DataTypeCustomDesc>(std::make_unique<DataTypeCustomFixedName>("IPv4"), std::make_unique<DataTypeCustomIPv4Serialization>()));
});
factory.registerSimpleDataTypeCustom("IPv6", []
{
return std::make_pair(DataTypeFactory::instance().get("FixedString(16)"),
std::make_unique<DataTypeCustomDesc>(std::make_unique<DataTypeCustomFixedName>("IPv6"), std::make_unique<DataTypeCustomIPv6Serialization>()));
});
}
} // namespace DB

View File

@ -0,0 +1,137 @@
#include <Common/FieldVisitors.h>
#include <Common/typeid_cast.h>
#include <IO/ReadHelpers.h>
#include <Columns/ColumnAggregateFunction.h>
#include <DataTypes/DataTypeCustomSimpleAggregateFunction.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeFactory.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTIdentifier.h>
#include <boost/algorithm/string/join.hpp>
namespace DB
{
namespace ErrorCodes
{
extern const int SYNTAX_ERROR;
extern const int BAD_ARGUMENTS;
extern const int PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int LOGICAL_ERROR;
}
static const std::vector<String> supported_functions{"any", "anyLast", "min", "max", "sum"};
String DataTypeCustomSimpleAggregateFunction::getName() const
{
std::stringstream stream;
stream << "SimpleAggregateFunction(" << function->getName();
if (!parameters.empty())
{
stream << "(";
for (size_t i = 0; i < parameters.size(); ++i)
{
if (i)
stream << ", ";
stream << applyVisitor(DB::FieldVisitorToString(), parameters[i]);
}
stream << ")";
}
for (const auto & argument_type : argument_types)
stream << ", " << argument_type->getName();
stream << ")";
return stream.str();
}
static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const ASTPtr & arguments)
{
String function_name;
AggregateFunctionPtr function;
DataTypes argument_types;
Array params_row;
if (!arguments || arguments->children.empty())
throw Exception("Data type SimpleAggregateFunction requires parameters: "
"name of aggregate function and list of data types for arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (const ASTFunction * parametric = arguments->children[0]->as<ASTFunction>())
{
if (parametric->parameters)
throw Exception("Unexpected level of parameters to aggregate function", ErrorCodes::SYNTAX_ERROR);
function_name = parametric->name;
const ASTs & parameters = parametric->arguments->as<ASTExpressionList &>().children;
params_row.resize(parameters.size());
for (size_t i = 0; i < parameters.size(); ++i)
{
const ASTLiteral * lit = parameters[i]->as<ASTLiteral>();
if (!lit)
throw Exception("Parameters to aggregate functions must be literals",
ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS);
params_row[i] = lit->value;
}
}
else if (auto opt_name = getIdentifierName(arguments->children[0]))
{
function_name = *opt_name;
}
else if (arguments->children[0]->as<ASTLiteral>())
{
throw Exception("Aggregate function name for data type SimpleAggregateFunction must be passed as identifier (without quotes) or function",
ErrorCodes::BAD_ARGUMENTS);
}
else
throw Exception("Unexpected AST element passed as aggregate function name for data type SimpleAggregateFunction. Must be identifier or function.",
ErrorCodes::BAD_ARGUMENTS);
for (size_t i = 1; i < arguments->children.size(); ++i)
argument_types.push_back(DataTypeFactory::instance().get(arguments->children[i]));
if (function_name.empty())
throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR);
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row);
// check function
if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions))
{
throw Exception("Unsupported aggregate function " + function->getName() + ", supported functions are " + boost::algorithm::join(supported_functions, ","),
ErrorCodes::BAD_ARGUMENTS);
}
DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName());
if (!function->getReturnType()->equals(*removeLowCardinality(storage_type)))
{
throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getReturnType()->getName() + " and column storage type " + storage_type->getName(),
ErrorCodes::BAD_ARGUMENTS);
}
DataTypeCustomNamePtr custom_name = std::make_unique<DataTypeCustomSimpleAggregateFunction>(function, argument_types, params_row);
return std::make_pair(storage_type, std::make_unique<DataTypeCustomDesc>(std::move(custom_name), nullptr));
}
void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory)
{
factory.registerDataTypeCustom("SimpleAggregateFunction", create);
}
}

View File

@ -0,0 +1,42 @@
#pragma once
#include <DataTypes/DataTypeCustom.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Common/FieldVisitors.h>
#include <IO/ReadHelpers.h>
namespace DB
{
/** The type SimpleAggregateFunction(fct, type) is meant to be used in an AggregatingMergeTree. It behaves like a standard
* data type but when rows are merged, an aggregation function is applied.
*
* The aggregation function is limited to simple functions whose merge state is the final result:
* any, anyLast, min, max, sum
*
* Examples:
*
* SimpleAggregateFunction(sum, Nullable(Float64))
* SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String)))
* SimpleAggregateFunction(anyLast, IPv4)
*
* Technically, a standard IDataType is instanciated and customized with IDataTypeCustomName and DataTypeCustomDesc.
*/
class DataTypeCustomSimpleAggregateFunction : public IDataTypeCustomName
{
private:
const AggregateFunctionPtr function;
const DataTypes argument_types;
const Array parameters;
public:
DataTypeCustomSimpleAggregateFunction(const AggregateFunctionPtr & function_, const DataTypes & argument_types_, const Array & parameters_)
: function(function_), argument_types(argument_types_), parameters(parameters_) {}
const AggregateFunctionPtr getFunction() const { return function; }
String getName() const override;
};
}

View File

@ -1,4 +1,4 @@
#include <DataTypes/DataTypeDomainWithSimpleSerialization.h>
#include <DataTypes/DataTypeCustomSimpleTextSerialization.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
@ -9,7 +9,7 @@ namespace
{
using namespace DB;
static String serializeToString(const DataTypeDomainWithSimpleSerialization & domain, const IColumn & column, size_t row_num, const FormatSettings & settings)
static String serializeToString(const DataTypeCustomSimpleTextSerialization & domain, const IColumn & column, size_t row_num, const FormatSettings & settings)
{
WriteBufferFromOwnString buffer;
domain.serializeText(column, row_num, buffer, settings);
@ -17,7 +17,7 @@ static String serializeToString(const DataTypeDomainWithSimpleSerialization & do
return buffer.str();
}
static void deserializeFromString(const DataTypeDomainWithSimpleSerialization & domain, IColumn & column, const String & s, const FormatSettings & settings)
static void deserializeFromString(const DataTypeCustomSimpleTextSerialization & domain, IColumn & column, const String & s, const FormatSettings & settings)
{
ReadBufferFromString istr(s);
domain.deserializeText(column, istr, settings);
@ -28,59 +28,59 @@ static void deserializeFromString(const DataTypeDomainWithSimpleSerialization &
namespace DB
{
DataTypeDomainWithSimpleSerialization::~DataTypeDomainWithSimpleSerialization()
DataTypeCustomSimpleTextSerialization::~DataTypeCustomSimpleTextSerialization()
{
}
void DataTypeDomainWithSimpleSerialization::serializeTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
void DataTypeCustomSimpleTextSerialization::serializeTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
writeEscapedString(serializeToString(*this, column, row_num, settings), ostr);
}
void DataTypeDomainWithSimpleSerialization::deserializeTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
void DataTypeCustomSimpleTextSerialization::deserializeTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
String str;
readEscapedString(str, istr);
deserializeFromString(*this, column, str, settings);
}
void DataTypeDomainWithSimpleSerialization::serializeTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
void DataTypeCustomSimpleTextSerialization::serializeTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
writeQuotedString(serializeToString(*this, column, row_num, settings), ostr);
}
void DataTypeDomainWithSimpleSerialization::deserializeTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
void DataTypeCustomSimpleTextSerialization::deserializeTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
String str;
readQuotedString(str, istr);
deserializeFromString(*this, column, str, settings);
}
void DataTypeDomainWithSimpleSerialization::serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
void DataTypeCustomSimpleTextSerialization::serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
writeCSVString(serializeToString(*this, column, row_num, settings), ostr);
}
void DataTypeDomainWithSimpleSerialization::deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
void DataTypeCustomSimpleTextSerialization::deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
String str;
readCSVString(str, istr, settings.csv);
deserializeFromString(*this, column, str, settings);
}
void DataTypeDomainWithSimpleSerialization::serializeTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
void DataTypeCustomSimpleTextSerialization::serializeTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
writeJSONString(serializeToString(*this, column, row_num, settings), ostr, settings);
}
void DataTypeDomainWithSimpleSerialization::deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
void DataTypeCustomSimpleTextSerialization::deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
String str;
readJSONString(str, istr);
deserializeFromString(*this, column, str, settings);
}
void DataTypeDomainWithSimpleSerialization::serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
void DataTypeCustomSimpleTextSerialization::serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
writeXMLString(serializeToString(*this, column, row_num, settings), ostr);
}

View File

@ -1,6 +1,6 @@
#pragma once
#include <DataTypes/IDataTypeDomain.h>
#include <DataTypes/DataTypeCustom.h>
namespace DB
{
@ -10,12 +10,12 @@ class WriteBuffer;
struct FormatSettings;
class IColumn;
/** Simple DataTypeDomain that uses serializeText/deserializeText
/** Simple IDataTypeCustomTextSerialization that uses serializeText/deserializeText
* for all serialization and deserialization. */
class DataTypeDomainWithSimpleSerialization : public IDataTypeDomain
class DataTypeCustomSimpleTextSerialization : public IDataTypeCustomTextSerialization
{
public:
virtual ~DataTypeDomainWithSimpleSerialization() override;
virtual ~DataTypeCustomSimpleTextSerialization() override;
// Methods that subclasses must override in order to get full serialization/deserialization support.
virtual void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override = 0;

View File

@ -1,5 +1,5 @@
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/IDataTypeDomain.h>
#include <DataTypes/DataTypeCustom.h>
#include <Parsers/parseQuery.h>
#include <Parsers/ParserCreateQuery.h>
#include <Parsers/ASTFunction.h>
@ -115,19 +115,23 @@ void DataTypeFactory::registerSimpleDataType(const String & name, SimpleCreator
}, case_sensitiveness);
}
void DataTypeFactory::registerDataTypeDomain(const String & type_name, DataTypeDomainPtr domain, CaseSensitiveness case_sensitiveness)
void DataTypeFactory::registerDataTypeCustom(const String & family_name, CreatorWithCustom creator, CaseSensitiveness case_sensitiveness)
{
all_domains.reserve(all_domains.size() + 1);
auto data_type = get(type_name);
setDataTypeDomain(*data_type, *domain);
registerDataType(domain->getName(), [data_type](const ASTPtr & /*ast*/)
registerDataType(family_name, [creator](const ASTPtr & ast)
{
return data_type;
auto res = creator(ast);
res.first->setCustomization(std::move(res.second));
return res.first;
}, case_sensitiveness);
}
all_domains.emplace_back(std::move(domain));
void DataTypeFactory::registerSimpleDataTypeCustom(const String &name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness)
{
registerDataTypeCustom(name, [creator](const ASTPtr & /*ast*/)
{
return creator();
}, case_sensitiveness);
}
const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String & family_name) const
@ -153,11 +157,6 @@ const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String
throw Exception("Unknown data type family: " + family_name, ErrorCodes::UNKNOWN_TYPE);
}
void DataTypeFactory::setDataTypeDomain(const IDataType & data_type, const IDataTypeDomain & domain)
{
data_type.setDomain(&domain);
}
void registerDataTypeNumbers(DataTypeFactory & factory);
void registerDataTypeDecimal(DataTypeFactory & factory);
void registerDataTypeDate(DataTypeFactory & factory);
@ -175,6 +174,7 @@ void registerDataTypeNested(DataTypeFactory & factory);
void registerDataTypeInterval(DataTypeFactory & factory);
void registerDataTypeLowCardinality(DataTypeFactory & factory);
void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory);
void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory);
DataTypeFactory::DataTypeFactory()
@ -196,6 +196,7 @@ DataTypeFactory::DataTypeFactory()
registerDataTypeInterval(*this);
registerDataTypeLowCardinality(*this);
registerDataTypeDomainIPv4AndIPv6(*this);
registerDataTypeDomainSimpleAggregateFunction(*this);
}
DataTypeFactory::~DataTypeFactory()

View File

@ -17,9 +17,6 @@ namespace DB
class IDataType;
using DataTypePtr = std::shared_ptr<const IDataType>;
class IDataTypeDomain;
using DataTypeDomainPtr = std::unique_ptr<const IDataTypeDomain>;
/** Creates a data type by name of data type family and parameters.
*/
@ -28,6 +25,8 @@ class DataTypeFactory final : public ext::singleton<DataTypeFactory>, public IFa
private:
using SimpleCreator = std::function<DataTypePtr()>;
using DataTypesDictionary = std::unordered_map<String, Creator>;
using CreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>(const ASTPtr & parameters)>;
using SimpleCreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>()>;
public:
DataTypePtr get(const String & full_name) const;
@ -40,11 +39,13 @@ public:
/// Register a simple data type, that have no parameters.
void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
// Register a domain - a refinement of existing type.
void registerDataTypeDomain(const String & type_name, DataTypeDomainPtr domain, CaseSensitiveness case_sensitiveness = CaseSensitive);
/// Register a customized type family
void registerDataTypeCustom(const String & family_name, CreatorWithCustom creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
/// Register a simple customized data type
void registerSimpleDataTypeCustom(const String & name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
private:
static void setDataTypeDomain(const IDataType & data_type, const IDataTypeDomain & domain);
const Creator& findCreatorByName(const String & family_name) const;
private:
@ -53,9 +54,6 @@ private:
/// Case insensitive data types will be additionally added here with lowercased name.
DataTypesDictionary case_insensitive_data_types;
// All domains are owned by factory and shared amongst DataType instances.
std::vector<DataTypeDomainPtr> all_domains;
DataTypeFactory();
~DataTypeFactory() override;

View File

@ -6,7 +6,7 @@
#include <Formats/ProtobufWriter.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <IO/readFloatText.h>
#include <IO/readDecimalText.h>
#include <Parsers/IAST.h>
#include <Parsers/ASTLiteral.h>
#include <Interpreters/Context.h>
@ -52,9 +52,21 @@ void DataTypeDecimal<T>::serializeText(const IColumn & column, size_t row_num, W
}
template <typename T>
void DataTypeDecimal<T>::readText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale)
bool DataTypeDecimal<T>::tryReadText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale)
{
UInt32 unread_scale = scale;
bool done = tryReadDecimalText(istr, x, precision, unread_scale);
x *= getScaleMultiplier(unread_scale);
return done;
}
template <typename T>
void DataTypeDecimal<T>::readText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale, bool csv)
{
UInt32 unread_scale = scale;
if (csv)
readCSVDecimalText(istr, x, precision, unread_scale);
else
readDecimalText(istr, x, precision, unread_scale);
x *= getScaleMultiplier(unread_scale);
}
@ -67,6 +79,13 @@ void DataTypeDecimal<T>::deserializeText(IColumn & column, ReadBuffer & istr, co
static_cast<ColumnType &>(column).getData().push_back(x);
}
template <typename T>
void DataTypeDecimal<T>::deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings &) const
{
T x;
readText(x, istr, true);
static_cast<ColumnType &>(column).getData().push_back(x);
}
template <typename T>
T DataTypeDecimal<T>::parseFromString(const String & str) const

View File

@ -1,4 +1,6 @@
#pragma once
#include <cmath>
#include <common/arithmeticOverflow.h>
#include <Common/typeid_cast.h>
#include <Columns/ColumnDecimal.h>
@ -91,6 +93,7 @@ public:
void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override;
void deserializeText(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override;
void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override;
void serializeBinary(const Field & field, WriteBuffer & ostr) const override;
void serializeBinary(const IColumn & column, size_t row_num, WriteBuffer & ostr) const override;
@ -175,8 +178,9 @@ public:
T parseFromString(const String & str) const;
void readText(T & x, ReadBuffer & istr) const { readText(x, istr, precision, scale); }
static void readText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale);
void readText(T & x, ReadBuffer & istr, bool csv = false) const { readText(x, istr, precision, scale, csv); }
static void readText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale, bool csv = false);
static bool tryReadText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale);
static T getScaleMultiplier(UInt32 scale);
private:
@ -318,7 +322,11 @@ convertToDecimal(const typename FromDataType::FieldType & value, UInt32 scale)
using FromFieldType = typename FromDataType::FieldType;
if constexpr (std::is_floating_point_v<FromFieldType>)
{
if (!std::isfinite(value))
throw Exception("Decimal convert overflow. Cannot convert infinity or NaN to decimal", ErrorCodes::DECIMAL_OVERFLOW);
return value * ToDataType::getScaleMultiplier(scale);
}
else
{
if constexpr (std::is_same_v<FromFieldType, UInt64>)

View File

@ -9,7 +9,7 @@
#include <IO/WriteHelpers.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/IDataTypeDomain.h>
#include <DataTypes/DataTypeCustom.h>
#include <DataTypes/NestedUtils.h>
@ -23,8 +23,7 @@ namespace ErrorCodes
extern const int DATA_TYPE_CANNOT_BE_PROMOTED;
}
IDataType::IDataType()
: domain(nullptr)
IDataType::IDataType() : custom_name(nullptr), custom_text_serialization(nullptr)
{
}
@ -34,9 +33,9 @@ IDataType::~IDataType()
String IDataType::getName() const
{
if (domain)
if (custom_name)
{
return domain->getName();
return custom_name->getName();
}
else
{
@ -142,9 +141,9 @@ void IDataType::insertDefaultInto(IColumn & column) const
void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
if (domain)
if (custom_text_serialization)
{
domain->serializeTextEscaped(column, row_num, ostr, settings);
custom_text_serialization->serializeTextEscaped(column, row_num, ostr, settings);
}
else
{
@ -154,9 +153,9 @@ void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, W
void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
if (domain)
if (custom_text_serialization)
{
domain->deserializeTextEscaped(column, istr, settings);
custom_text_serialization->deserializeTextEscaped(column, istr, settings);
}
else
{
@ -166,9 +165,9 @@ void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, co
void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
if (domain)
if (custom_text_serialization)
{
domain->serializeTextQuoted(column, row_num, ostr, settings);
custom_text_serialization->serializeTextQuoted(column, row_num, ostr, settings);
}
else
{
@ -178,9 +177,9 @@ void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, Wr
void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
if (domain)
if (custom_text_serialization)
{
domain->deserializeTextQuoted(column, istr, settings);
custom_text_serialization->deserializeTextQuoted(column, istr, settings);
}
else
{
@ -190,9 +189,9 @@ void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, con
void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
if (domain)
if (custom_text_serialization)
{
domain->serializeTextCSV(column, row_num, ostr, settings);
custom_text_serialization->serializeTextCSV(column, row_num, ostr, settings);
}
else
{
@ -202,9 +201,9 @@ void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, Write
void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
if (domain)
if (custom_text_serialization)
{
domain->deserializeTextCSV(column, istr, settings);
custom_text_serialization->deserializeTextCSV(column, istr, settings);
}
else
{
@ -214,9 +213,9 @@ void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const
void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
if (domain)
if (custom_text_serialization)
{
domain->serializeText(column, row_num, ostr, settings);
custom_text_serialization->serializeText(column, row_num, ostr, settings);
}
else
{
@ -226,9 +225,9 @@ void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuf
void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
if (domain)
if (custom_text_serialization)
{
domain->serializeTextJSON(column, row_num, ostr, settings);
custom_text_serialization->serializeTextJSON(column, row_num, ostr, settings);
}
else
{
@ -238,9 +237,9 @@ void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, Writ
void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
if (domain)
if (custom_text_serialization)
{
domain->deserializeTextJSON(column, istr, settings);
custom_text_serialization->deserializeTextJSON(column, istr, settings);
}
else
{
@ -250,9 +249,9 @@ void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const
void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
if (domain)
if (custom_text_serialization)
{
domain->serializeTextXML(column, row_num, ostr, settings);
custom_text_serialization->serializeTextXML(column, row_num, ostr, settings);
}
else
{
@ -260,13 +259,14 @@ void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, Write
}
}
void IDataType::setDomain(const IDataTypeDomain* const new_domain) const
void IDataType::setCustomization(DataTypeCustomDescPtr custom_desc_) const
{
if (domain != nullptr)
{
throw Exception("Type " + getName() + " already has a domain.", ErrorCodes::LOGICAL_ERROR);
}
domain = new_domain;
/// replace only if not null
if (custom_desc_->name)
custom_name = std::move(custom_desc_->name);
if (custom_desc_->text_serialization)
custom_text_serialization = std::move(custom_desc_->text_serialization);
}
}

View File

@ -4,6 +4,7 @@
#include <Common/COW.h>
#include <boost/noncopyable.hpp>
#include <Core/Field.h>
#include <DataTypes/DataTypeCustom.h>
namespace DB
@ -12,7 +13,6 @@ namespace DB
class ReadBuffer;
class WriteBuffer;
class IDataTypeDomain;
class IDataType;
struct FormatSettings;
@ -459,18 +459,19 @@ public:
private:
friend class DataTypeFactory;
/** Sets domain on existing DataType, can be considered as second phase
* of construction explicitly done by DataTypeFactory.
* Will throw an exception if domain is already set.
/** Customize this DataType
*/
void setDomain(const IDataTypeDomain* newDomain) const;
void setCustomization(DataTypeCustomDescPtr custom_desc_) const;
private:
/** This is mutable to allow setting domain on `const IDataType` post construction,
* simplifying creation of domains for all types, without them even knowing
* of domain existence.
/** This is mutable to allow setting custom name and serialization on `const IDataType` post construction.
*/
mutable IDataTypeDomain const* domain;
mutable DataTypeCustomNamePtr custom_name;
mutable DataTypeCustomTextSerializationPtr custom_text_serialization;
public:
const IDataTypeCustomName * getCustomName() const { return custom_name.get(); }
const IDataTypeCustomTextSerialization * getCustomTextSerialization() const { return custom_text_serialization.get(); }
};
@ -573,6 +574,13 @@ inline bool isInteger(const T & data_type)
return which.isInt() || which.isUInt();
}
template <typename T>
inline bool isFloat(const T & data_type)
{
WhichDataType which(data_type);
return which.isFloat();
}
template <typename T>
inline bool isNumber(const T & data_type)
{

View File

@ -4,7 +4,7 @@
#include <unordered_set>
#include <Databases/DatabasesCommon.h>
#include <Databases/IDatabase.h>
#include <Storages/IStorage.h>
#include <Storages/IStorage_fwd.h>
namespace Poco

View File

@ -1,27 +1,28 @@
#include <iomanip>
#include <Poco/Event.h>
#include <Poco/DirectoryIterator.h>
#include <common/logger_useful.h>
#include <Databases/DatabaseOrdinary.h>
#include <Databases/DatabaseMemory.h>
#include <Databases/DatabasesCommon.h>
#include <Common/typeid_cast.h>
#include <Common/escapeForFileName.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/Stopwatch.h>
#include <Common/ThreadPool.h>
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/parseQuery.h>
#include <Parsers/ParserCreateQuery.h>
#include <Interpreters/Context.h>
#include <Core/Settings.h>
#include <Interpreters/InterpreterCreateQuery.h>
#include <IO/WriteBufferFromFile.h>
#include <Databases/DatabaseMemory.h>
#include <Databases/DatabaseOrdinary.h>
#include <Databases/DatabasesCommon.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteBufferFromFile.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context.h>
#include <Interpreters/InterpreterCreateQuery.h>
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/ParserCreateQuery.h>
#include <Parsers/parseQuery.h>
#include <Storages/IStorage.h>
#include <Poco/DirectoryIterator.h>
#include <Poco/Event.h>
#include <Common/Stopwatch.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/ThreadPool.h>
#include <Common/escapeForFileName.h>
#include <Common/typeid_cast.h>
#include <common/logger_useful.h>
#include <ext/scope_guard.h>

View File

@ -1,14 +1,16 @@
#include <sstream>
#include <Databases/DatabasesCommon.h>
#include <Common/typeid_cast.h>
#include <Parsers/parseQuery.h>
#include <Parsers/ParserCreateQuery.h>
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/formatAST.h>
#include <Interpreters/Context.h>
#include <Interpreters/InterpreterCreateQuery.h>
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/ParserCreateQuery.h>
#include <Parsers/formatAST.h>
#include <Parsers/parseQuery.h>
#include <Storages/IStorage.h>
#include <Storages/StorageFactory.h>
#include <Databases/DatabasesCommon.h>
#include <Common/typeid_cast.h>
#include <sstream>
namespace DB

View File

@ -2,7 +2,7 @@
#include <Core/Types.h>
#include <Parsers/IAST.h>
#include <Storages/IStorage.h>
#include <Storages/IStorage_fwd.h>
#include <Databases/IDatabase.h>

View File

@ -6,6 +6,7 @@
#include <Parsers/IAST_fwd.h>
#include <Storages/ColumnsDescription.h>
#include <Storages/IndicesDescription.h>
#include <Storages/IStorage_fwd.h>
#include <Poco/File.h>
#include <Common/ThreadPool.h>
#include <Common/escapeForFileName.h>
@ -20,9 +21,6 @@ namespace DB
class Context;
class IStorage;
using StoragePtr = std::shared_ptr<IStorage>;
struct Settings;

View File

@ -3,6 +3,7 @@
#include <Columns/ColumnsNumber.h>
#include <Common/ProfilingScopedRWLock.h>
#include <Common/typeid_cast.h>
#include <DataStreams/IBlockInputStream.h>
#include <ext/map.h>
#include <ext/range.h>
#include <ext/size.h>

View File

@ -20,6 +20,7 @@
#include "DictionaryStructure.h"
#include "IDictionary.h"
#include "IDictionarySource.h"
#include <DataStreams/IBlockInputStream.h>
namespace ProfileEvents

View File

@ -7,6 +7,7 @@
#include <Columns/ColumnString.h>
#include <Common/Arena.h>
#include <Common/HashTable/HashMap.h>
#include <Core/Block.h>
#include <common/StringRef.h>
#include <ext/range.h>
#include "DictionaryStructure.h"

View File

@ -1,6 +1,7 @@
#pragma once
#include "IDictionarySource.h"
#include <Core/Block.h>
#include <unordered_map>
#include <ext/singleton.h>

View File

@ -2,12 +2,10 @@
#include "DictionaryStructure.h"
#include "IDictionarySource.h"
#include <Core/Block.h>
namespace Poco
{
class Logger;
}
namespace Poco { class Logger; }
namespace DB

View File

@ -2,6 +2,7 @@
#include <Poco/Timestamp.h>
#include "IDictionarySource.h"
#include <Core/Block.h>
namespace DB

View File

@ -6,6 +6,7 @@
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnString.h>
#include <Common/Arena.h>
#include <Core/Block.h>
#include <ext/range.h>
#include <ext/size.h>
#include "DictionaryStructure.h"

View File

@ -5,6 +5,7 @@
#include <variant>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnString.h>
#include <Core/Block.h>
#include <Common/HashTable/HashMap.h>
#include <ext/range.h>
#include "DictionaryStructure.h"

View File

@ -1,27 +1,27 @@
#pragma once
#include <chrono>
#include <memory>
#include <Core/Field.h>
#include <Core/Names.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <Interpreters/IExternalLoadable.h>
#include <Poco/Util/XMLConfiguration.h>
#include <Common/PODArray.h>
#include <common/StringRef.h>
#include "IDictionarySource.h"
#include <chrono>
#include <memory>
namespace DB
{
struct IDictionaryBase;
using DictionaryPtr = std::unique_ptr<IDictionaryBase>;
struct DictionaryStructure;
class ColumnString;
class IBlockInputStream;
using BlockInputStreamPtr = std::shared_ptr<IBlockInputStream>;
struct IDictionaryBase : public IExternalLoadable
{
using Key = UInt64;

View File

@ -1,7 +1,9 @@
#pragma once
#include <Columns/IColumn.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <vector>
#include <DataStreams/IBlockInputStream.h>
namespace DB

View File

@ -1,6 +1,7 @@
#pragma once
#include <Common/config.h>
#include <Core/Block.h>
#if USE_POCO_MONGODB
# include "DictionaryStructure.h"

View File

@ -1,6 +1,7 @@
#pragma once
#include <Common/config.h>
#include <Core/Block.h>
#if USE_MYSQL
# include <common/LocalDateTime.h>

View File

@ -1,11 +1,12 @@
#pragma once
#include <DataStreams/IBlockStream_fwd.h>
#include <string>
namespace DB
{
class IBlockInputStream;
/// Using in MySQLDictionarySource and XDBCDictionarySource after processing invalidate_query.
std::string readInvalidateQuery(IBlockInputStream & block_input_stream);

View File

@ -64,23 +64,25 @@ void registerInputFormatRowBinary(FormatFactory & factory)
const Block & sample,
const Context &,
UInt64 max_block_size,
UInt64 rows_portion_size,
const FormatSettings & settings)
{
return std::make_shared<BlockInputStreamFromRowInputStream>(
std::make_shared<BinaryRowInputStream>(buf, sample, false, false),
sample, max_block_size, settings);
sample, max_block_size, rows_portion_size, settings);
});
factory.registerInputFormat("RowBinaryWithNamesAndTypes", [](
ReadBuffer & buf,
const Block & sample,
const Context &,
size_t max_block_size,
UInt64 max_block_size,
UInt64 rows_portion_size,
const FormatSettings & settings)
{
return std::make_shared<BlockInputStreamFromRowInputStream>(
std::make_shared<BinaryRowInputStream>(buf, sample, true, true),
sample, max_block_size, settings);
sample, max_block_size, rows_portion_size, settings);
});
}

View File

@ -27,8 +27,9 @@ BlockInputStreamFromRowInputStream::BlockInputStreamFromRowInputStream(
const RowInputStreamPtr & row_input_,
const Block & sample_,
UInt64 max_block_size_,
UInt64 rows_portion_size_,
const FormatSettings & settings)
: row_input(row_input_), sample(sample_), max_block_size(max_block_size_),
: row_input(row_input_), sample(sample_), max_block_size(max_block_size_), rows_portion_size(rows_portion_size_),
allow_errors_num(settings.input_allow_errors_num), allow_errors_ratio(settings.input_allow_errors_ratio)
{
}
@ -57,8 +58,15 @@ Block BlockInputStreamFromRowInputStream::readImpl()
try
{
for (size_t rows = 0; rows < max_block_size; ++rows)
for (size_t rows = 0, batch = 0; rows < max_block_size; ++rows, ++batch)
{
if (rows_portion_size && batch == rows_portion_size)
{
batch = 0;
if (!checkTimeLimit() || isCancelled())
break;
}
try
{
++total_rows;

View File

@ -17,11 +17,13 @@ namespace DB
class BlockInputStreamFromRowInputStream : public IBlockInputStream
{
public:
/** sample_ - block with zero rows, that structure describes how to interpret values */
/// |sample| is a block with zero rows, that structure describes how to interpret values
/// |rows_portion_size| is a number of rows to read before break and check limits
BlockInputStreamFromRowInputStream(
const RowInputStreamPtr & row_input_,
const Block & sample_,
UInt64 max_block_size_,
UInt64 rows_portion_size_,
const FormatSettings & settings);
void readPrefix() override { row_input->readPrefix(); }
@ -42,6 +44,7 @@ private:
RowInputStreamPtr row_input;
Block sample;
UInt64 max_block_size;
UInt64 rows_portion_size;
BlockMissingValues block_missing_values;
UInt64 allow_errors_num;
@ -50,5 +53,4 @@ private:
size_t total_rows = 0;
size_t num_errors = 0;
};
}

View File

@ -478,11 +478,12 @@ void registerInputFormatCSV(FormatFactory & factory)
const Block & sample,
const Context &,
UInt64 max_block_size,
UInt64 rows_portion_size,
const FormatSettings & settings)
{
return std::make_shared<BlockInputStreamFromRowInputStream>(
std::make_shared<CSVRowInputStream>(buf, sample, with_names, settings),
sample, max_block_size, settings);
sample, max_block_size, rows_portion_size, settings);
});
}
}

View File

@ -302,12 +302,18 @@ void registerInputFormatCapnProto(FormatFactory & factory)
{
factory.registerInputFormat(
"CapnProto",
[](ReadBuffer & buf, const Block & sample, const Context & context, UInt64 max_block_size, const FormatSettings & settings)
[](ReadBuffer & buf,
const Block & sample,
const Context & context,
UInt64 max_block_size,
UInt64 rows_portion_size,
const FormatSettings & settings)
{
return std::make_shared<BlockInputStreamFromRowInputStream>(
std::make_shared<CapnProtoRowInputStream>(buf, sample, FormatSchemaInfo(context, "CapnProto")),
sample,
max_block_size,
rows_portion_size,
settings);
});
}

View File

@ -27,7 +27,7 @@ const FormatFactory::Creators & FormatFactory::getCreators(const String & name)
}
BlockInputStreamPtr FormatFactory::getInput(const String & name, ReadBuffer & buf, const Block & sample, const Context & context, UInt64 max_block_size) const
BlockInputStreamPtr FormatFactory::getInput(const String & name, ReadBuffer & buf, const Block & sample, const Context & context, UInt64 max_block_size, UInt64 rows_portion_size) const
{
const auto & input_getter = getCreators(name).first;
if (!input_getter)
@ -47,7 +47,7 @@ BlockInputStreamPtr FormatFactory::getInput(const String & name, ReadBuffer & bu
format_settings.input_allow_errors_num = settings.input_format_allow_errors_num;
format_settings.input_allow_errors_ratio = settings.input_format_allow_errors_ratio;
return input_getter(buf, sample, context, max_block_size, format_settings);
return input_getter(buf, sample, context, max_block_size, rows_portion_size, format_settings);
}

View File

@ -1,10 +1,12 @@
#pragma once
#include <memory>
#include <functional>
#include <unordered_map>
#include <ext/singleton.h>
#include <Core/Types.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <ext/singleton.h>
#include <functional>
#include <memory>
#include <unordered_map>
namespace DB
@ -17,13 +19,6 @@ struct FormatSettings;
class ReadBuffer;
class WriteBuffer;
class IBlockInputStream;
class IBlockOutputStream;
using BlockInputStreamPtr = std::shared_ptr<IBlockInputStream>;
using BlockOutputStreamPtr = std::shared_ptr<IBlockOutputStream>;
/** Allows to create an IBlockInputStream or IBlockOutputStream by the name of the format.
* Note: format and compression are independent things.
*/
@ -35,6 +30,7 @@ private:
const Block & sample,
const Context & context,
UInt64 max_block_size,
UInt64 rows_portion_size,
const FormatSettings & settings)>;
using OutputCreator = std::function<BlockOutputStreamPtr(
@ -49,7 +45,7 @@ private:
public:
BlockInputStreamPtr getInput(const String & name, ReadBuffer & buf,
const Block & sample, const Context & context, UInt64 max_block_size) const;
const Block & sample, const Context & context, UInt64 max_block_size, UInt64 rows_portion_size = 0) const;
BlockOutputStreamPtr getOutput(const String & name, WriteBuffer & buf,
const Block & sample, const Context & context) const;

View File

@ -259,11 +259,12 @@ void registerInputFormatJSONEachRow(FormatFactory & factory)
const Block & sample,
const Context &,
UInt64 max_block_size,
UInt64 rows_portion_size,
const FormatSettings & settings)
{
return std::make_shared<BlockInputStreamFromRowInputStream>(
std::make_shared<JSONEachRowRowInputStream>(buf, sample, settings),
sample, max_block_size, settings);
sample, max_block_size, rows_portion_size, settings);
});
}

Some files were not shown because too many files have changed in this diff Show More