mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-12 17:32:32 +00:00
Merge branch 'master' into table-constraints
This commit is contained in:
commit
3c905e24e2
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -79,3 +79,6 @@
|
||||
[submodule "contrib/hyperscan"]
|
||||
path = contrib/hyperscan
|
||||
url = https://github.com/ClickHouse-Extras/hyperscan.git
|
||||
[submodule "contrib/simdjson"]
|
||||
path = contrib/simdjson
|
||||
url = https://github.com/lemire/simdjson.git
|
||||
|
@ -1,6 +1,15 @@
|
||||
project(ClickHouse)
|
||||
cmake_minimum_required(VERSION 3.3)
|
||||
cmake_policy(SET CMP0023 NEW)
|
||||
|
||||
foreach(policy
|
||||
CMP0023
|
||||
CMP0074 # CMake 3.12
|
||||
)
|
||||
if(POLICY ${policy})
|
||||
cmake_policy(SET ${policy} NEW)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules/")
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS 1) # Write compile_commands.json
|
||||
set(CMAKE_LINK_DEPENDS_NO_SHARED 1) # Do not relink all depended targets on .so
|
||||
@ -301,6 +310,7 @@ include (cmake/find_rt.cmake)
|
||||
include (cmake/find_execinfo.cmake)
|
||||
include (cmake/find_readline_edit.cmake)
|
||||
include (cmake/find_re2.cmake)
|
||||
include (cmake/find_libgsasl.cmake)
|
||||
include (cmake/find_rdkafka.cmake)
|
||||
include (cmake/find_capnp.cmake)
|
||||
include (cmake/find_llvm.cmake)
|
||||
@ -308,7 +318,6 @@ include (cmake/find_cpuid.cmake) # Freebsd, bundled
|
||||
if (NOT USE_CPUID)
|
||||
include (cmake/find_cpuinfo.cmake) # Debian
|
||||
endif()
|
||||
include (cmake/find_libgsasl.cmake)
|
||||
include (cmake/find_libxml2.cmake)
|
||||
include (cmake/find_brotli.cmake)
|
||||
include (cmake/find_protobuf.cmake)
|
||||
@ -318,6 +327,7 @@ include (cmake/find_consistent-hashing.cmake)
|
||||
include (cmake/find_base64.cmake)
|
||||
include (cmake/find_hyperscan.cmake)
|
||||
include (cmake/find_lfalloc.cmake)
|
||||
include (cmake/find_simdjson.cmake)
|
||||
find_contrib_lib(cityhash)
|
||||
find_contrib_lib(farmhash)
|
||||
find_contrib_lib(metrohash)
|
||||
|
@ -12,8 +12,8 @@ ClickHouse is an open-source column-oriented database management system that all
|
||||
* You can also [fill this form](https://forms.yandex.com/surveys/meet-yandex-clickhouse-team/) to meet Yandex ClickHouse team in person.
|
||||
|
||||
## Upcoming Events
|
||||
* [ClickHouse Community Meetup in Limassol](https://www.facebook.com/events/386638262181785/) on May 7.
|
||||
* ClickHouse at [Percona Live 2019](https://www.percona.com/live/19/other-open-source-databases-track) in Austin on May 28-30.
|
||||
* [ClickHouse Community Meetup in San Francisco](https://www.meetup.com/San-Francisco-Bay-Area-ClickHouse-Meetup/events/261110652/) on June 4.
|
||||
* [ClickHouse Community Meetup in Beijing](https://www.huodongxing.com/event/2483759276200) on June 8.
|
||||
* [ClickHouse Community Meetup in Shenzhen](https://www.huodongxing.com/event/3483759917300) on October 20.
|
||||
* [ClickHouse Community Meetup in Shanghai](https://www.huodongxing.com/event/4483760336000) on October 27.
|
||||
|
@ -1,9 +1,12 @@
|
||||
option (USE_INTERNAL_BOOST_LIBRARY "Set to FALSE to use system boost library instead of bundled" ${NOT_UNBUNDLED})
|
||||
|
||||
# Test random file existing in all package variants
|
||||
if (USE_INTERNAL_BOOST_LIBRARY AND NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/boost/libs/system/src/error_code.cpp")
|
||||
message (WARNING "submodules in contrib/boost is missing. to fix try run: \n git submodule update --init --recursive")
|
||||
set (USE_INTERNAL_BOOST_LIBRARY 0)
|
||||
if (NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/boost/libs/system/src/error_code.cpp")
|
||||
if(USE_INTERNAL_BOOST_LIBRARY)
|
||||
message(WARNING "submodules in contrib/boost is missing. to fix try run: \n git submodule update --init --recursive")
|
||||
endif()
|
||||
set (USE_INTERNAL_BOOST_LIBRARY 0)
|
||||
set (MISSING_INTERNAL_BOOST_LIBRARY 1)
|
||||
endif ()
|
||||
|
||||
if (NOT USE_INTERNAL_BOOST_LIBRARY)
|
||||
@ -21,10 +24,9 @@ if (NOT USE_INTERNAL_BOOST_LIBRARY)
|
||||
set (Boost_INCLUDE_DIRS "")
|
||||
set (Boost_SYSTEM_LIBRARY "")
|
||||
endif ()
|
||||
|
||||
endif ()
|
||||
|
||||
if (NOT Boost_SYSTEM_LIBRARY)
|
||||
if (NOT Boost_SYSTEM_LIBRARY AND NOT MISSING_INTERNAL_BOOST_LIBRARY)
|
||||
set (USE_INTERNAL_BOOST_LIBRARY 1)
|
||||
set (Boost_SYSTEM_LIBRARY boost_system_internal)
|
||||
set (Boost_PROGRAM_OPTIONS_LIBRARY boost_program_options_internal)
|
||||
@ -44,7 +46,6 @@ if (NOT Boost_SYSTEM_LIBRARY)
|
||||
|
||||
# For packaged version:
|
||||
list (APPEND Boost_INCLUDE_DIRS "${ClickHouse_SOURCE_DIR}/contrib/boost")
|
||||
|
||||
endif ()
|
||||
|
||||
message (STATUS "Using Boost: ${Boost_INCLUDE_DIRS} : ${Boost_PROGRAM_OPTIONS_LIBRARY},${Boost_SYSTEM_LIBRARY},${Boost_FILESYSTEM_LIBRARY},${Boost_REGEX_LIBRARY}")
|
||||
|
@ -1,6 +1,9 @@
|
||||
option(ENABLE_ICU "Enable ICU" ON)
|
||||
|
||||
if(ENABLE_ICU)
|
||||
if (APPLE)
|
||||
set(ICU_ROOT "/usr/local/opt/icu4c" CACHE STRING "")
|
||||
endif()
|
||||
find_package(ICU COMPONENTS i18n uc data) # TODO: remove Modules/FindICU.cmake after cmake 3.7
|
||||
#set (ICU_LIBRARIES ${ICU_I18N_LIBRARY} ${ICU_UC_LIBRARY} ${ICU_DATA_LIBRARY} CACHE STRING "")
|
||||
if(ICU_FOUND)
|
||||
|
@ -1,4 +1,4 @@
|
||||
if (NOT SANITIZE AND NOT ARCH_ARM AND NOT ARCH_32 AND NOT ARCH_PPC64LE AND NOT OS_FREEBSD)
|
||||
if (NOT SANITIZE AND NOT ARCH_ARM AND NOT ARCH_32 AND NOT ARCH_PPC64LE AND NOT OS_FREEBSD AND NOT APPLE)
|
||||
option (ENABLE_LFALLOC "Set to FALSE to use system libgsasl library instead of bundled" ${NOT_UNBUNDLED})
|
||||
endif ()
|
||||
|
||||
|
@ -22,4 +22,8 @@ elseif (NOT MISSING_INTERNAL_LIBGSASL_LIBRARY AND NOT APPLE AND NOT ARCH_32)
|
||||
set (LIBGSASL_LIBRARY libgsasl)
|
||||
endif ()
|
||||
|
||||
message (STATUS "Using libgsasl: ${LIBGSASL_INCLUDE_DIR} : ${LIBGSASL_LIBRARY}")
|
||||
if(LIBGSASL_LIBRARY AND LIBGSASL_INCLUDE_DIR)
|
||||
set (USE_LIBGSASL 1)
|
||||
endif()
|
||||
|
||||
message (STATUS "Using libgsasl=${USE_LIBGSASL}: ${LIBGSASL_INCLUDE_DIR} : ${LIBGSASL_LIBRARY}")
|
||||
|
@ -10,7 +10,7 @@ endif ()
|
||||
|
||||
if (ENABLE_RDKAFKA)
|
||||
|
||||
if (OS_LINUX AND NOT ARCH_ARM)
|
||||
if (OS_LINUX AND NOT ARCH_ARM AND USE_LIBGSASL)
|
||||
option (USE_INTERNAL_RDKAFKA_LIBRARY "Set to FALSE to use system librdkafka instead of the bundled" ${NOT_UNBUNDLED})
|
||||
endif ()
|
||||
|
||||
|
@ -1,5 +1,13 @@
|
||||
option (USE_INTERNAL_RE2_LIBRARY "Set to FALSE to use system re2 library instead of bundled [slower]" ${NOT_UNBUNDLED})
|
||||
|
||||
if(NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/re2/CMakeLists.txt")
|
||||
if(USE_INTERNAL_RE2_LIBRARY)
|
||||
message(WARNING "submodule contrib/re2 is missing. to fix try run: \n git submodule update --init --recursive")
|
||||
endif()
|
||||
set(USE_INTERNAL_RE2_LIBRARY 0)
|
||||
set(MISSING_INTERNAL_RE2_LIBRARY 1)
|
||||
endif()
|
||||
|
||||
if (NOT USE_INTERNAL_RE2_LIBRARY)
|
||||
find_library (RE2_LIBRARY re2)
|
||||
find_path (RE2_INCLUDE_DIR NAMES re2/re2.h PATHS ${RE2_INCLUDE_PATHS})
|
||||
|
14
cmake/find_simdjson.cmake
Normal file
14
cmake/find_simdjson.cmake
Normal file
@ -0,0 +1,14 @@
|
||||
if (NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/simdjson/include/simdjson/jsonparser.h")
|
||||
message (WARNING "submodule contrib/simdjson is missing. to fix try run: \n git submodule update --init --recursive")
|
||||
return()
|
||||
endif ()
|
||||
|
||||
if (NOT HAVE_AVX2)
|
||||
message (WARNING "submodule contrib/simdjson requires AVX2 support")
|
||||
return()
|
||||
endif ()
|
||||
|
||||
option (USE_SIMDJSON "Use simdjson" ON)
|
||||
set (SIMDJSON_LIBRARY "simdjson")
|
||||
|
||||
message(STATUS "Using simdjson=${USE_SIMDJSON}: ${SIMDJSON_LIBRARY}")
|
@ -2,20 +2,28 @@ if (NOT OS_FREEBSD AND NOT ARCH_32)
|
||||
option (USE_INTERNAL_ZLIB_LIBRARY "Set to FALSE to use system zlib library instead of bundled" ${NOT_UNBUNDLED})
|
||||
endif ()
|
||||
|
||||
if (NOT MSVC)
|
||||
set (INTERNAL_ZLIB_NAME "zlib-ng" CACHE INTERNAL "")
|
||||
else ()
|
||||
set (INTERNAL_ZLIB_NAME "zlib" CACHE INTERNAL "")
|
||||
if (NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/${INTERNAL_ZLIB_NAME}")
|
||||
message (WARNING "Will use standard zlib, please clone manually:\n git clone https://github.com/madler/zlib.git ${ClickHouse_SOURCE_DIR}/contrib/${INTERNAL_ZLIB_NAME}")
|
||||
endif ()
|
||||
endif ()
|
||||
|
||||
if(NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/${INTERNAL_ZLIB_NAME}/zlib.h")
|
||||
if(USE_INTERNAL_ZLIB_LIBRARY)
|
||||
message(WARNING "submodule contrib/${INTERNAL_ZLIB_NAME} is missing. to fix try run: \n git submodule update --init --recursive")
|
||||
endif()
|
||||
set(USE_INTERNAL_ZLIB_LIBRARY 0)
|
||||
set(MISSING_INTERNAL_ZLIB_LIBRARY 1)
|
||||
endif()
|
||||
|
||||
if (NOT USE_INTERNAL_ZLIB_LIBRARY)
|
||||
find_package (ZLIB)
|
||||
endif ()
|
||||
|
||||
if (NOT ZLIB_FOUND)
|
||||
if (NOT MSVC)
|
||||
set (INTERNAL_ZLIB_NAME "zlib-ng" CACHE INTERNAL "")
|
||||
else ()
|
||||
set (INTERNAL_ZLIB_NAME "zlib" CACHE INTERNAL "")
|
||||
if (NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/${INTERNAL_ZLIB_NAME}")
|
||||
message (WARNING "Will use standard zlib, please clone manually:\n git clone https://github.com/madler/zlib.git ${ClickHouse_SOURCE_DIR}/contrib/${INTERNAL_ZLIB_NAME}")
|
||||
endif ()
|
||||
endif ()
|
||||
|
||||
if (NOT ZLIB_FOUND AND NOT MISSING_INTERNAL_ZLIB_LIBRARY)
|
||||
set (USE_INTERNAL_ZLIB_LIBRARY 1)
|
||||
set (ZLIB_INCLUDE_DIR "${ClickHouse_SOURCE_DIR}/contrib/${INTERNAL_ZLIB_NAME}" "${ClickHouse_BINARY_DIR}/contrib/${INTERNAL_ZLIB_NAME}" CACHE INTERNAL "") # generated zconf.h
|
||||
set (ZLIB_INCLUDE_DIRS ${ZLIB_INCLUDE_DIR}) # for poco
|
||||
|
@ -1,9 +1,12 @@
|
||||
option (USE_INTERNAL_ZSTD_LIBRARY "Set to FALSE to use system zstd library instead of bundled" ${NOT_UNBUNDLED})
|
||||
|
||||
if (USE_INTERNAL_ZSTD_LIBRARY AND NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/zstd/lib/zstd.h")
|
||||
message (WARNING "submodule contrib/zstd is missing. to fix try run: \n git submodule update --init --recursive")
|
||||
set (USE_INTERNAL_ZSTD_LIBRARY 0)
|
||||
endif ()
|
||||
if(NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/zstd/lib/zstd.h")
|
||||
if(USE_INTERNAL_ZSTD_LIBRARY)
|
||||
message(WARNING "submodule contrib/zstd is missing. to fix try run: \n git submodule update --init --recursive")
|
||||
endif()
|
||||
set(USE_INTERNAL_ZSTD_LIBRARY 0)
|
||||
set(MISSING_INTERNAL_ZSTD_LIBRARY 1)
|
||||
endif()
|
||||
|
||||
if (NOT USE_INTERNAL_ZSTD_LIBRARY)
|
||||
find_library (ZSTD_LIBRARY zstd)
|
||||
@ -11,7 +14,7 @@ if (NOT USE_INTERNAL_ZSTD_LIBRARY)
|
||||
endif ()
|
||||
|
||||
if (ZSTD_LIBRARY AND ZSTD_INCLUDE_DIR)
|
||||
else ()
|
||||
elseif (NOT MISSING_INTERNAL_ZSTD_LIBRARY)
|
||||
set (USE_INTERNAL_ZSTD_LIBRARY 1)
|
||||
set (ZSTD_LIBRARY zstd)
|
||||
set (ZSTD_INCLUDE_DIR ${ClickHouse_SOURCE_DIR}/contrib/zstd/lib)
|
||||
|
6
contrib/CMakeLists.txt
vendored
6
contrib/CMakeLists.txt
vendored
@ -227,7 +227,7 @@ if (USE_INTERNAL_POCO_LIBRARY)
|
||||
set (ENABLE_TESTS 0)
|
||||
set (POCO_ENABLE_TESTS 0)
|
||||
set (CMAKE_DISABLE_FIND_PACKAGE_ZLIB 1)
|
||||
if (MSVC)
|
||||
if (MSVC OR NOT USE_POCO_DATAODBC)
|
||||
set (ENABLE_DATA_ODBC 0 CACHE INTERNAL "") # TODO (build fail)
|
||||
endif ()
|
||||
add_subdirectory (poco)
|
||||
@ -313,3 +313,7 @@ endif()
|
||||
if (USE_INTERNAL_HYPERSCAN_LIBRARY)
|
||||
add_subdirectory (hyperscan)
|
||||
endif()
|
||||
|
||||
if (USE_SIMDJSON)
|
||||
add_subdirectory (simdjson-cmake)
|
||||
endif()
|
||||
|
2
contrib/boost
vendored
2
contrib/boost
vendored
@ -1 +1 @@
|
||||
Subproject commit 471ea208abb92a5cba7d3a08a819bb728f27e95f
|
||||
Subproject commit 79bf85ea99c05ba4fb6959474d4464ab126f8973
|
@ -33,6 +33,7 @@ set(SRCS
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_roundrobin_assignor.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_sasl.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_sasl_plain.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_sasl_scram.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_subscription.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_timer.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_topic.c
|
||||
@ -58,7 +59,7 @@ add_library(rdkafka ${SRCS})
|
||||
target_include_directories(rdkafka SYSTEM PUBLIC include)
|
||||
target_include_directories(rdkafka SYSTEM PUBLIC ${RDKAFKA_SOURCE_DIR}) # Because weird logic with "include_next" is used.
|
||||
target_include_directories(rdkafka SYSTEM PRIVATE ${ZSTD_INCLUDE_DIR}/common) # Because wrong path to "zstd_errors.h" is used.
|
||||
target_link_libraries(rdkafka PUBLIC ${ZLIB_LIBRARIES} ${ZSTD_LIBRARY} ${LZ4_LIBRARY})
|
||||
target_link_libraries(rdkafka PUBLIC ${ZLIB_LIBRARIES} ${ZSTD_LIBRARY} ${LZ4_LIBRARY} ${LIBGSASL_LIBRARY})
|
||||
if(OPENSSL_SSL_LIBRARY AND OPENSSL_CRYPTO_LIBRARY)
|
||||
target_link_libraries(rdkafka PUBLIC ${OPENSSL_SSL_LIBRARY} ${OPENSSL_CRYPTO_LIBRARY})
|
||||
endif()
|
||||
|
@ -12,7 +12,7 @@
|
||||
#define ENABLE_SHAREDPTR_DEBUG 0
|
||||
#define ENABLE_LZ4_EXT 1
|
||||
#define ENABLE_SSL 1
|
||||
//#define ENABLE_SASL 1
|
||||
#define ENABLE_SASL 1
|
||||
#define MKL_APP_NAME "librdkafka"
|
||||
#define MKL_APP_DESC_ONELINE "The Apache Kafka C/C++ library"
|
||||
// distro
|
||||
@ -62,7 +62,7 @@
|
||||
// libssl
|
||||
#define WITH_SSL 1
|
||||
// WITH_SASL_SCRAM
|
||||
//#define WITH_SASL_SCRAM 1
|
||||
#define WITH_SASL_SCRAM 1
|
||||
// crc32chw
|
||||
#if !defined(__PPC__)
|
||||
#define WITH_CRC32C_HW 1
|
||||
|
1
contrib/simdjson
vendored
Submodule
1
contrib/simdjson
vendored
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 681cd3369860f4eada49a387cbff93030f759c95
|
18
contrib/simdjson-cmake/CMakeLists.txt
Normal file
18
contrib/simdjson-cmake/CMakeLists.txt
Normal file
@ -0,0 +1,18 @@
|
||||
if (NOT HAVE_AVX2)
|
||||
message (FATAL_ERROR "No AVX2 support")
|
||||
endif ()
|
||||
set(SIMDJSON_INCLUDE_DIR "${ClickHouse_SOURCE_DIR}/contrib/simdjson/include")
|
||||
set(SIMDJSON_SRC_DIR "${SIMDJSON_INCLUDE_DIR}/../src")
|
||||
set(SIMDJSON_SRC
|
||||
${SIMDJSON_SRC_DIR}/jsonioutil.cpp
|
||||
${SIMDJSON_SRC_DIR}/jsonminifier.cpp
|
||||
${SIMDJSON_SRC_DIR}/jsonparser.cpp
|
||||
${SIMDJSON_SRC_DIR}/stage1_find_marks.cpp
|
||||
${SIMDJSON_SRC_DIR}/stage2_build_tape.cpp
|
||||
${SIMDJSON_SRC_DIR}/parsedjson.cpp
|
||||
${SIMDJSON_SRC_DIR}/parsedjsoniterator.cpp
|
||||
)
|
||||
|
||||
add_library(${SIMDJSON_LIBRARY} ${SIMDJSON_SRC})
|
||||
target_include_directories(${SIMDJSON_LIBRARY} PUBLIC "${SIMDJSON_INCLUDE_DIR}")
|
||||
target_compile_options(${SIMDJSON_LIBRARY} PRIVATE -mavx2 -mbmi -mbmi2 -mpclmul)
|
@ -189,8 +189,17 @@ target_link_libraries (clickhouse_common_io
|
||||
${Poco_Net_LIBRARY}
|
||||
${Poco_Util_LIBRARY}
|
||||
${Poco_Foundation_LIBRARY}
|
||||
${RE2_LIBRARY}
|
||||
${RE2_ST_LIBRARY}
|
||||
)
|
||||
|
||||
if(RE2_LIBRARY)
|
||||
target_link_libraries(clickhouse_common_io PUBLIC ${RE2_LIBRARY})
|
||||
endif()
|
||||
if(RE2_ST_LIBRARY)
|
||||
target_link_libraries(clickhouse_common_io PUBLIC ${RE2_ST_LIBRARY})
|
||||
endif()
|
||||
|
||||
target_link_libraries(clickhouse_common_io
|
||||
PUBLIC
|
||||
${CITYHASH_LIBRARIES}
|
||||
PRIVATE
|
||||
${ZLIB_LIBRARIES}
|
||||
@ -208,7 +217,9 @@ target_link_libraries (clickhouse_common_io
|
||||
)
|
||||
|
||||
|
||||
target_include_directories(clickhouse_common_io SYSTEM BEFORE PUBLIC ${RE2_INCLUDE_DIR})
|
||||
if(RE2_INCLUDE_DIR)
|
||||
target_include_directories(clickhouse_common_io SYSTEM BEFORE PUBLIC ${RE2_INCLUDE_DIR})
|
||||
endif()
|
||||
|
||||
if (USE_LFALLOC)
|
||||
target_include_directories (clickhouse_common_io SYSTEM BEFORE PUBLIC ${LFALLOC_INCLUDE_DIR})
|
||||
|
@ -209,6 +209,9 @@ else ()
|
||||
install (FILES ${CMAKE_CURRENT_BINARY_DIR}/clickhouse-obfuscator DESTINATION ${CMAKE_INSTALL_BINDIR} COMPONENT clickhouse)
|
||||
list(APPEND CLICKHOUSE_BUNDLE clickhouse-obfuscator)
|
||||
endif ()
|
||||
if(ENABLE_CLICKHOUSE_ODBC_BRIDGE)
|
||||
list(APPEND CLICKHOUSE_BUNDLE clickhouse-odbc-bridge)
|
||||
endif()
|
||||
|
||||
# install always because depian package want this files:
|
||||
add_custom_target (clickhouse-clang ALL COMMAND ${CMAKE_COMMAND} -E create_symlink clickhouse clickhouse-clang DEPENDS clickhouse)
|
||||
|
@ -1,5 +1,5 @@
|
||||
set(CLICKHOUSE_COPIER_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/ClusterCopier.cpp)
|
||||
set(CLICKHOUSE_COPIER_LINK PRIVATE clickhouse_functions clickhouse_table_functions clickhouse_aggregate_functions PUBLIC daemon)
|
||||
set(CLICKHOUSE_COPIER_LINK PRIVATE clickhouse_functions clickhouse_table_functions clickhouse_aggregate_functions clickhouse_dictionaries PUBLIC daemon)
|
||||
set(CLICKHOUSE_COPIER_INCLUDE SYSTEM PRIVATE ${PCG_RANDOM_INCLUDE_DIR})
|
||||
|
||||
clickhouse_program_add(copier)
|
||||
|
@ -63,6 +63,7 @@
|
||||
#include <AggregateFunctions/registerAggregateFunctions.h>
|
||||
#include <Storages/registerStorages.h>
|
||||
#include <Storages/StorageDistributed.h>
|
||||
#include <Dictionaries/registerDictionaries.h>
|
||||
#include <Databases/DatabaseMemory.h>
|
||||
#include <Common/StatusFile.h>
|
||||
|
||||
@ -2169,6 +2170,7 @@ void ClusterCopierApp::mainImpl()
|
||||
registerAggregateFunctions();
|
||||
registerTableFunctions();
|
||||
registerStorages();
|
||||
registerDictionaries();
|
||||
|
||||
static const std::string default_database = "_local";
|
||||
context->addDatabase(default_database, std::make_shared<DatabaseMemory>(default_database));
|
||||
|
@ -912,8 +912,8 @@ public:
|
||||
size_t columns = header.columns();
|
||||
models.reserve(columns);
|
||||
|
||||
for (size_t i = 0; i < columns; ++i)
|
||||
models.emplace_back(factory.get(*header.getByPosition(i).type, hash(seed, i), markov_model_params));
|
||||
for (const auto & elem : header)
|
||||
models.emplace_back(factory.get(*elem.type, hash(seed, elem.name), markov_model_params));
|
||||
}
|
||||
|
||||
void train(const Columns & columns)
|
||||
@ -954,7 +954,7 @@ try
|
||||
("structure,S", po::value<std::string>(), "structure of the initial table (list of column and type names)")
|
||||
("input-format", po::value<std::string>(), "input format of the initial table data")
|
||||
("output-format", po::value<std::string>(), "default output format")
|
||||
("seed", po::value<std::string>(), "seed (arbitrary string), must be random string with at least 10 bytes length")
|
||||
("seed", po::value<std::string>(), "seed (arbitrary string), must be random string with at least 10 bytes length; note that a seed for each column is derived from this seed and a column name: you can obfuscate data for different tables and as long as you use identical seed and identical column names, the data for corresponding non-text columns for different tables will be transformed in the same way, so the data for different tables can be JOINed after obfuscation")
|
||||
("limit", po::value<UInt64>(), "if specified - stop after generating that number of rows")
|
||||
("silent", po::value<bool>()->default_value(false), "don't print information messages to stderr")
|
||||
("order", po::value<UInt64>()->default_value(5), "order of markov model to generate strings")
|
||||
|
@ -370,7 +370,7 @@ try
|
||||
Poco::Logger * log = &Poco::Logger::get("PerformanceTestSuite");
|
||||
if (options.count("help"))
|
||||
{
|
||||
std::cout << "Usage: " << argv[0] << " [options] [test_file ...] [tests_folder]\n";
|
||||
std::cout << "Usage: " << argv[0] << " [options]\n";
|
||||
std::cout << desc << "\n";
|
||||
return 0;
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include "TestStats.h"
|
||||
#include <algorithm>
|
||||
namespace DB
|
||||
{
|
||||
|
||||
@ -92,11 +93,10 @@ void TestStats::update_average_speed(
|
||||
avg_speed_value /= number_of_info_batches;
|
||||
|
||||
if (avg_speed_first == 0)
|
||||
{
|
||||
avg_speed_first = avg_speed_value;
|
||||
}
|
||||
|
||||
if (std::abs(avg_speed_value - avg_speed_first) >= precision)
|
||||
auto [min, max] = std::minmax(avg_speed_value, avg_speed_first);
|
||||
if (1 - min / max >= precision)
|
||||
{
|
||||
avg_speed_first = avg_speed_value;
|
||||
avg_speed_watch.restart();
|
||||
|
@ -40,11 +40,11 @@ struct TestStats
|
||||
|
||||
double avg_rows_speed_value = 0;
|
||||
double avg_rows_speed_first = 0;
|
||||
static inline double avg_rows_speed_precision = 0.001;
|
||||
static inline double avg_rows_speed_precision = 0.005;
|
||||
|
||||
double avg_bytes_speed_value = 0;
|
||||
double avg_bytes_speed_first = 0;
|
||||
static inline double avg_bytes_speed_precision = 0.001;
|
||||
static inline double avg_bytes_speed_precision = 0.005;
|
||||
|
||||
size_t number_of_rows_speed_info_batches = 0;
|
||||
size_t number_of_bytes_speed_info_batches = 0;
|
||||
|
@ -79,6 +79,7 @@ namespace ErrorCodes
|
||||
extern const int SYSTEM_ERROR;
|
||||
extern const int FAILED_TO_GETPWUID;
|
||||
extern const int MISMATCHING_USERS_FOR_PROCESS_AND_DATA;
|
||||
extern const int NETWORK_ERROR;
|
||||
}
|
||||
|
||||
|
||||
@ -587,12 +588,12 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
return socket_address;
|
||||
};
|
||||
|
||||
auto socket_bind_listen = [&](auto & socket, const std::string & host, UInt16 port, bool secure = 0)
|
||||
auto socket_bind_listen = [&](auto & socket, const std::string & host, UInt16 port, [[maybe_unused]] bool secure = 0)
|
||||
{
|
||||
auto address = make_socket_address(host, port);
|
||||
#if !defined(POCO_CLICKHOUSE_PATCH) || POCO_VERSION <= 0x02000000 // TODO: fill correct version
|
||||
#if !defined(POCO_CLICKHOUSE_PATCH) || POCO_VERSION < 0x01090100
|
||||
if (secure)
|
||||
/// Bug in old poco, listen() after bind() with reusePort param will fail because have no implementation in SecureServerSocketImpl
|
||||
/// Bug in old (<1.9.1) poco, listen() after bind() with reusePort param will fail because have no implementation in SecureServerSocketImpl
|
||||
/// https://github.com/pocoproject/poco/pull/2257
|
||||
socket.bind(address, /* reuseAddress = */ true);
|
||||
else
|
||||
@ -611,13 +612,15 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
for (const auto & listen_host : listen_hosts)
|
||||
{
|
||||
/// For testing purposes, user may omit tcp_port or http_port or https_port in configuration file.
|
||||
uint16_t listen_port = 0;
|
||||
try
|
||||
{
|
||||
/// HTTP
|
||||
if (config().has("http_port"))
|
||||
{
|
||||
Poco::Net::ServerSocket socket;
|
||||
auto address = socket_bind_listen(socket, listen_host, config().getInt("http_port"));
|
||||
listen_port = config().getInt("http_port");
|
||||
auto address = socket_bind_listen(socket, listen_host, listen_port);
|
||||
socket.setReceiveTimeout(settings.http_receive_timeout);
|
||||
socket.setSendTimeout(settings.http_send_timeout);
|
||||
servers.emplace_back(std::make_unique<Poco::Net::HTTPServer>(
|
||||
@ -634,7 +637,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
{
|
||||
#if USE_POCO_NETSSL
|
||||
Poco::Net::SecureServerSocket socket;
|
||||
auto address = socket_bind_listen(socket, listen_host, config().getInt("https_port"), /* secure = */ true);
|
||||
listen_port = config().getInt("https_port");
|
||||
auto address = socket_bind_listen(socket, listen_host, listen_port, /* secure = */ true);
|
||||
socket.setReceiveTimeout(settings.http_receive_timeout);
|
||||
socket.setSendTimeout(settings.http_send_timeout);
|
||||
servers.emplace_back(std::make_unique<Poco::Net::HTTPServer>(
|
||||
@ -654,7 +658,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
if (config().has("tcp_port"))
|
||||
{
|
||||
Poco::Net::ServerSocket socket;
|
||||
auto address = socket_bind_listen(socket, listen_host, config().getInt("tcp_port"));
|
||||
listen_port = config().getInt("tcp_port");
|
||||
auto address = socket_bind_listen(socket, listen_host, listen_port);
|
||||
socket.setReceiveTimeout(settings.receive_timeout);
|
||||
socket.setSendTimeout(settings.send_timeout);
|
||||
servers.emplace_back(std::make_unique<Poco::Net::TCPServer>(
|
||||
@ -671,7 +676,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
{
|
||||
#if USE_POCO_NETSSL
|
||||
Poco::Net::SecureServerSocket socket;
|
||||
auto address = socket_bind_listen(socket, listen_host, config().getInt("tcp_port_secure"), /* secure = */ true);
|
||||
listen_port = config().getInt("tcp_port_secure");
|
||||
auto address = socket_bind_listen(socket, listen_host, listen_port, /* secure = */ true);
|
||||
socket.setReceiveTimeout(settings.receive_timeout);
|
||||
socket.setSendTimeout(settings.send_timeout);
|
||||
servers.emplace_back(std::make_unique<Poco::Net::TCPServer>(
|
||||
@ -694,7 +700,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
if (config().has("interserver_http_port"))
|
||||
{
|
||||
Poco::Net::ServerSocket socket;
|
||||
auto address = socket_bind_listen(socket, listen_host, config().getInt("interserver_http_port"));
|
||||
listen_port = config().getInt("interserver_http_port");
|
||||
auto address = socket_bind_listen(socket, listen_host, listen_port);
|
||||
socket.setReceiveTimeout(settings.http_receive_timeout);
|
||||
socket.setSendTimeout(settings.http_send_timeout);
|
||||
servers.emplace_back(std::make_unique<Poco::Net::HTTPServer>(
|
||||
@ -710,7 +717,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
{
|
||||
#if USE_POCO_NETSSL
|
||||
Poco::Net::SecureServerSocket socket;
|
||||
auto address = socket_bind_listen(socket, listen_host, config().getInt("interserver_https_port"), /* secure = */ true);
|
||||
listen_port = config().getInt("interserver_https_port");
|
||||
auto address = socket_bind_listen(socket, listen_host, listen_port, /* secure = */ true);
|
||||
socket.setReceiveTimeout(settings.http_receive_timeout);
|
||||
socket.setSendTimeout(settings.http_send_timeout);
|
||||
servers.emplace_back(std::make_unique<Poco::Net::HTTPServer>(
|
||||
@ -726,16 +734,17 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
#endif
|
||||
}
|
||||
}
|
||||
catch (const Poco::Net::NetException & e)
|
||||
catch (const Poco::Exception & e)
|
||||
{
|
||||
std::string message = "Listen [" + listen_host + "]:" + std::to_string(listen_port) + " failed: " + std::to_string(e.code()) + ": " + e.what() + ": " + e.message();
|
||||
if (listen_try)
|
||||
LOG_ERROR(log, "Listen [" << listen_host << "]: " << e.code() << ": " << e.what() << ": " << e.message()
|
||||
LOG_ERROR(log, message
|
||||
<< " If it is an IPv6 or IPv4 address and your host has disabled IPv6 or IPv4, then consider to "
|
||||
"specify not disabled IPv4 or IPv6 address to listen in <listen_host> element of configuration "
|
||||
"file. Example for disabled IPv6: <listen_host>0.0.0.0</listen_host> ."
|
||||
" Example for disabled IPv4: <listen_host>::</listen_host>");
|
||||
else
|
||||
throw;
|
||||
throw Exception{message, ErrorCodes::NETWORK_ERROR};
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -370,8 +370,8 @@ void TCPHandler::processInsertQuery(const Settings & global_settings)
|
||||
if (client_revision >= DBMS_MIN_REVISION_WITH_COLUMN_DEFAULTS_METADATA)
|
||||
{
|
||||
const auto & db_and_table = query_context->getInsertionTable();
|
||||
if (auto * columns = ColumnsDescription::loadFromContext(*query_context, db_and_table.first, db_and_table.second))
|
||||
sendTableColumns(*columns);
|
||||
if (query_context->getSettingsRef().input_format_defaults_for_omitted_fields)
|
||||
sendTableColumns(query_context->getTable(db_and_table.first, db_and_table.second)->getColumns());
|
||||
}
|
||||
|
||||
/// Send block to the client - table structure.
|
||||
|
474
dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp
Normal file
474
dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp
Normal file
@ -0,0 +1,474 @@
|
||||
#include "AggregateFunctionMLMethod.h"
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Interpreters/castColumn.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include "AggregateFunctionFactory.h"
|
||||
#include "FactoryHelpers.h"
|
||||
#include "Helpers.h"
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace
|
||||
{
|
||||
using FuncLinearRegression = AggregateFunctionMLMethod<LinearModelData, NameLinearRegression>;
|
||||
using FuncLogisticRegression = AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression>;
|
||||
template <class Method>
|
||||
AggregateFunctionPtr
|
||||
createAggregateFunctionMLMethod(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
{
|
||||
if (parameters.size() > 4)
|
||||
throw Exception(
|
||||
"Aggregate function " + name
|
||||
+ " requires at most four parameters: learning_rate, l2_regularization_coef, mini-batch size and weights_updater "
|
||||
"method",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
if (argument_types.size() < 2)
|
||||
throw Exception(
|
||||
"Aggregate function " + name + " requires at least two arguments: target and model's parameters",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
for (size_t i = 0; i < argument_types.size(); ++i)
|
||||
{
|
||||
if (!isNumber(argument_types[i]))
|
||||
throw Exception(
|
||||
"Argument " + std::to_string(i) + " of type " + argument_types[i]->getName()
|
||||
+ " must be numeric for aggregate function " + name,
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
/// Such default parameters were picked because they did good on some tests,
|
||||
/// though it still requires to fit parameters to achieve better result
|
||||
auto learning_rate = Float64(0.01);
|
||||
auto l2_reg_coef = Float64(0.01);
|
||||
UInt32 batch_size = 1;
|
||||
|
||||
std::shared_ptr<IWeightsUpdater> weights_updater = std::make_shared<StochasticGradientDescent>();
|
||||
std::shared_ptr<IGradientComputer> gradient_computer;
|
||||
|
||||
if (!parameters.empty())
|
||||
{
|
||||
learning_rate = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]);
|
||||
}
|
||||
if (parameters.size() > 1)
|
||||
{
|
||||
l2_reg_coef = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[1]);
|
||||
}
|
||||
if (parameters.size() > 2)
|
||||
{
|
||||
batch_size = applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]);
|
||||
}
|
||||
if (parameters.size() > 3)
|
||||
{
|
||||
if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'SGD\'")
|
||||
{
|
||||
weights_updater = std::make_shared<StochasticGradientDescent>();
|
||||
}
|
||||
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Momentum\'")
|
||||
{
|
||||
weights_updater = std::make_shared<Momentum>();
|
||||
}
|
||||
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Nesterov\'")
|
||||
{
|
||||
weights_updater = std::make_shared<Nesterov>();
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
}
|
||||
|
||||
if (std::is_same<Method, FuncLinearRegression>::value)
|
||||
{
|
||||
gradient_computer = std::make_shared<LinearRegression>();
|
||||
}
|
||||
else if (std::is_same<Method, FuncLogisticRegression>::value)
|
||||
{
|
||||
gradient_computer = std::make_shared<LogisticRegression>();
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
return std::make_shared<Method>(
|
||||
argument_types.size() - 1,
|
||||
gradient_computer,
|
||||
weights_updater,
|
||||
learning_rate,
|
||||
l2_reg_coef,
|
||||
batch_size,
|
||||
argument_types,
|
||||
parameters);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
|
||||
factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
|
||||
}
|
||||
|
||||
LinearModelData::LinearModelData(
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
UInt32 param_num,
|
||||
UInt32 batch_capacity,
|
||||
std::shared_ptr<DB::IGradientComputer> gradient_computer,
|
||||
std::shared_ptr<DB::IWeightsUpdater> weights_updater)
|
||||
: learning_rate(learning_rate)
|
||||
, l2_reg_coef(l2_reg_coef)
|
||||
, batch_capacity(batch_capacity)
|
||||
, batch_size(0)
|
||||
, gradient_computer(std::move(gradient_computer))
|
||||
, weights_updater(std::move(weights_updater))
|
||||
{
|
||||
weights.resize(param_num, Float64{0.0});
|
||||
gradient_batch.resize(param_num + 1, Float64{0.0});
|
||||
}
|
||||
|
||||
void LinearModelData::update_state()
|
||||
{
|
||||
if (batch_size == 0)
|
||||
return;
|
||||
|
||||
weights_updater->update(batch_size, weights, bias, gradient_batch);
|
||||
batch_size = 0;
|
||||
++iter_num;
|
||||
gradient_batch.assign(gradient_batch.size(), Float64{0.0});
|
||||
}
|
||||
|
||||
void LinearModelData::predict(
|
||||
ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const
|
||||
{
|
||||
gradient_computer->predict(container, block, arguments, weights, bias, context);
|
||||
}
|
||||
|
||||
void LinearModelData::read(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(bias, buf);
|
||||
readBinary(weights, buf);
|
||||
readBinary(iter_num, buf);
|
||||
readBinary(gradient_batch, buf);
|
||||
readBinary(batch_size, buf);
|
||||
weights_updater->read(buf);
|
||||
}
|
||||
|
||||
void LinearModelData::write(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(bias, buf);
|
||||
writeBinary(weights, buf);
|
||||
writeBinary(iter_num, buf);
|
||||
writeBinary(gradient_batch, buf);
|
||||
writeBinary(batch_size, buf);
|
||||
weights_updater->write(buf);
|
||||
}
|
||||
|
||||
void LinearModelData::merge(const DB::LinearModelData & rhs)
|
||||
{
|
||||
if (iter_num == 0 && rhs.iter_num == 0)
|
||||
return;
|
||||
|
||||
update_state();
|
||||
/// can't update rhs state because it's constant
|
||||
|
||||
Float64 frac = (static_cast<Float64>(iter_num) * iter_num) / (iter_num * iter_num + rhs.iter_num * rhs.iter_num);
|
||||
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
weights[i] = weights[i] * frac + rhs.weights[i] * (1 - frac);
|
||||
}
|
||||
bias = bias * frac + rhs.bias * (1 - frac);
|
||||
|
||||
iter_num += rhs.iter_num;
|
||||
weights_updater->merge(*rhs.weights_updater, frac, 1 - frac);
|
||||
}
|
||||
|
||||
void LinearModelData::add(const IColumn ** columns, size_t row_num)
|
||||
{
|
||||
/// first column stores target; features start from (columns + 1)
|
||||
const auto target = (*columns[0])[row_num].get<Float64>();
|
||||
/// Here we have columns + 1 as first column corresponds to target value, and others - to features
|
||||
weights_updater->add_to_batch(
|
||||
gradient_batch, *gradient_computer, weights, bias, learning_rate, l2_reg_coef, target, columns + 1, row_num);
|
||||
|
||||
++batch_size;
|
||||
if (batch_size == batch_capacity)
|
||||
{
|
||||
update_state();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Nesterov::read(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(accumulated_gradient, buf);
|
||||
}
|
||||
|
||||
void Nesterov::write(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(accumulated_gradient, buf);
|
||||
}
|
||||
|
||||
void Nesterov::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
|
||||
{
|
||||
auto & nesterov_rhs = static_cast<const Nesterov &>(rhs);
|
||||
for (size_t i = 0; i < accumulated_gradient.size(); ++i)
|
||||
{
|
||||
accumulated_gradient[i] = accumulated_gradient[i] * frac + nesterov_rhs.accumulated_gradient[i] * rhs_frac;
|
||||
}
|
||||
}
|
||||
|
||||
void Nesterov::update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient)
|
||||
{
|
||||
if (accumulated_gradient.empty())
|
||||
{
|
||||
accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < batch_gradient.size(); ++i)
|
||||
{
|
||||
accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + batch_gradient[i] / batch_size;
|
||||
}
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
weights[i] += accumulated_gradient[i];
|
||||
}
|
||||
bias += accumulated_gradient[weights.size()];
|
||||
}
|
||||
|
||||
void Nesterov::add_to_batch(
|
||||
std::vector<Float64> & batch_gradient,
|
||||
IGradientComputer & gradient_computer,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
Float64 target,
|
||||
const IColumn ** columns,
|
||||
size_t row_num)
|
||||
{
|
||||
if (accumulated_gradient.empty())
|
||||
{
|
||||
accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
|
||||
}
|
||||
|
||||
std::vector<Float64> shifted_weights(weights.size());
|
||||
for (size_t i = 0; i != shifted_weights.size(); ++i)
|
||||
{
|
||||
shifted_weights[i] = weights[i] + accumulated_gradient[i] * alpha_;
|
||||
}
|
||||
auto shifted_bias = bias + accumulated_gradient[weights.size()] * alpha_;
|
||||
|
||||
gradient_computer.compute(batch_gradient, shifted_weights, shifted_bias, learning_rate, l2_reg_coef, target, columns, row_num);
|
||||
}
|
||||
|
||||
void Momentum::read(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(accumulated_gradient, buf);
|
||||
}
|
||||
|
||||
void Momentum::write(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(accumulated_gradient, buf);
|
||||
}
|
||||
|
||||
void Momentum::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
|
||||
{
|
||||
auto & momentum_rhs = static_cast<const Momentum &>(rhs);
|
||||
for (size_t i = 0; i < accumulated_gradient.size(); ++i)
|
||||
{
|
||||
accumulated_gradient[i] = accumulated_gradient[i] * frac + momentum_rhs.accumulated_gradient[i] * rhs_frac;
|
||||
}
|
||||
}
|
||||
|
||||
void Momentum::update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient)
|
||||
{
|
||||
/// batch_size is already checked to be greater than 0
|
||||
if (accumulated_gradient.empty())
|
||||
{
|
||||
accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < batch_gradient.size(); ++i)
|
||||
{
|
||||
accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + batch_gradient[i] / batch_size;
|
||||
}
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
weights[i] += accumulated_gradient[i];
|
||||
}
|
||||
bias += accumulated_gradient[weights.size()];
|
||||
}
|
||||
|
||||
void StochasticGradientDescent::update(
|
||||
UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient)
|
||||
{
|
||||
/// batch_size is already checked to be greater than 0
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
weights[i] += batch_gradient[i] / batch_size;
|
||||
}
|
||||
bias += batch_gradient[weights.size()] / batch_size;
|
||||
}
|
||||
|
||||
void IWeightsUpdater::add_to_batch(
|
||||
std::vector<Float64> & batch_gradient,
|
||||
IGradientComputer & gradient_computer,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
Float64 target,
|
||||
const IColumn ** columns,
|
||||
size_t row_num)
|
||||
{
|
||||
gradient_computer.compute(batch_gradient, weights, bias, learning_rate, l2_reg_coef, target, columns, row_num);
|
||||
}
|
||||
|
||||
void LogisticRegression::predict(
|
||||
ColumnVector<Float64>::Container & container,
|
||||
Block & block,
|
||||
const ColumnNumbers & arguments,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
const Context & context) const
|
||||
{
|
||||
size_t rows_num = block.rows();
|
||||
std::vector<Float64> results(rows_num, bias);
|
||||
|
||||
for (size_t i = 1; i < arguments.size(); ++i)
|
||||
{
|
||||
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
|
||||
if (!isNumber(cur_col.type))
|
||||
{
|
||||
throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
/// If column type is already Float64 then castColumn simply returns it
|
||||
auto features_col_ptr = castColumn(cur_col, std::make_shared<DataTypeFloat64>(), context);
|
||||
auto features_column = typeid_cast<const ColumnFloat64 *>(features_col_ptr.get());
|
||||
|
||||
if (!features_column)
|
||||
{
|
||||
throw Exception("Unexpectedly cannot dynamically cast features column " + std::to_string(i), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
for (size_t row_num = 0; row_num != rows_num; ++row_num)
|
||||
{
|
||||
results[row_num] += weights[i - 1] * features_column->getElement(row_num);
|
||||
}
|
||||
}
|
||||
|
||||
container.reserve(rows_num);
|
||||
for (size_t row_num = 0; row_num != rows_num; ++row_num)
|
||||
{
|
||||
container.emplace_back(1 / (1 + exp(-results[row_num])));
|
||||
}
|
||||
}
|
||||
|
||||
void LogisticRegression::compute(
|
||||
std::vector<Float64> & batch_gradient,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
Float64 target,
|
||||
const IColumn ** columns,
|
||||
size_t row_num)
|
||||
{
|
||||
Float64 derivative = bias;
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
auto value = (*columns[i])[row_num].get<Float64>();
|
||||
derivative += weights[i] * value;
|
||||
}
|
||||
derivative *= target;
|
||||
derivative = exp(derivative);
|
||||
|
||||
batch_gradient[weights.size()] += learning_rate * target / (derivative + 1);
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
auto value = (*columns[i])[row_num].get<Float64>();
|
||||
batch_gradient[i] += learning_rate * target * value / (derivative + 1) - 2 * l2_reg_coef * weights[i];
|
||||
}
|
||||
}
|
||||
|
||||
void LinearRegression::predict(
|
||||
ColumnVector<Float64>::Container & container,
|
||||
Block & block,
|
||||
const ColumnNumbers & arguments,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
const Context & context) const
|
||||
{
|
||||
if (weights.size() + 1 != arguments.size())
|
||||
{
|
||||
throw Exception("In predict function number of arguments differs from the size of weights vector", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
size_t rows_num = block.rows();
|
||||
std::vector<Float64> results(rows_num, bias);
|
||||
|
||||
for (size_t i = 1; i < arguments.size(); ++i)
|
||||
{
|
||||
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
|
||||
if (!isNumber(cur_col.type))
|
||||
{
|
||||
throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
/// If column type is already Float64 then castColumn simply returns it
|
||||
auto features_col_ptr = castColumn(cur_col, std::make_shared<DataTypeFloat64>(), context);
|
||||
auto features_column = typeid_cast<const ColumnFloat64 *>(features_col_ptr.get());
|
||||
|
||||
if (!features_column)
|
||||
{
|
||||
throw Exception("Unexpectedly cannot dynamically cast features column " + std::to_string(i), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
for (size_t row_num = 0; row_num != rows_num; ++row_num)
|
||||
{
|
||||
results[row_num] += weights[i - 1] * features_column->getElement(row_num);
|
||||
}
|
||||
}
|
||||
|
||||
container.reserve(rows_num);
|
||||
for (size_t row_num = 0; row_num != rows_num; ++row_num)
|
||||
{
|
||||
container.emplace_back(results[row_num]);
|
||||
}
|
||||
}
|
||||
|
||||
void LinearRegression::compute(
|
||||
std::vector<Float64> & batch_gradient,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
Float64 target,
|
||||
const IColumn ** columns,
|
||||
size_t row_num)
|
||||
{
|
||||
Float64 derivative = (target - bias);
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
auto value = (*columns[i])[row_num].get<Float64>();
|
||||
derivative -= weights[i] * value;
|
||||
}
|
||||
derivative *= (2 * learning_rate);
|
||||
|
||||
batch_gradient[weights.size()] += derivative;
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
auto value = (*columns[i])[row_num].get<Float64>();
|
||||
batch_gradient[i] += derivative * value - 2 * l2_reg_coef * weights[i];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
330
dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h
Normal file
330
dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h
Normal file
@ -0,0 +1,330 @@
|
||||
#pragma once
|
||||
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnsCommon.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include "IAggregateFunction.h"
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
/**
|
||||
GradientComputer class computes gradient according to its loss function
|
||||
*/
|
||||
class IGradientComputer
|
||||
{
|
||||
public:
|
||||
IGradientComputer() {}
|
||||
|
||||
virtual ~IGradientComputer() = default;
|
||||
|
||||
/// Adds computed gradient in new point (weights, bias) to batch_gradient
|
||||
virtual void compute(
|
||||
std::vector<Float64> & batch_gradient,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
Float64 target,
|
||||
const IColumn ** columns,
|
||||
size_t row_num)
|
||||
= 0;
|
||||
|
||||
virtual void predict(
|
||||
ColumnVector<Float64>::Container & container,
|
||||
Block & block,
|
||||
const ColumnNumbers & arguments,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
const Context & context) const = 0;
|
||||
};
|
||||
|
||||
|
||||
class LinearRegression : public IGradientComputer
|
||||
{
|
||||
public:
|
||||
LinearRegression() {}
|
||||
|
||||
void compute(
|
||||
std::vector<Float64> & batch_gradient,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
Float64 target,
|
||||
const IColumn ** columns,
|
||||
size_t row_num) override;
|
||||
|
||||
void predict(
|
||||
ColumnVector<Float64>::Container & container,
|
||||
Block & block,
|
||||
const ColumnNumbers & arguments,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
const Context & context) const override;
|
||||
};
|
||||
|
||||
|
||||
class LogisticRegression : public IGradientComputer
|
||||
{
|
||||
public:
|
||||
LogisticRegression() {}
|
||||
|
||||
void compute(
|
||||
std::vector<Float64> & batch_gradient,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
Float64 target,
|
||||
const IColumn ** columns,
|
||||
size_t row_num) override;
|
||||
|
||||
void predict(
|
||||
ColumnVector<Float64>::Container & container,
|
||||
Block & block,
|
||||
const ColumnNumbers & arguments,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
const Context & context) const override;
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* IWeightsUpdater class defines the way to update current weights
|
||||
* and uses GradientComputer class on each iteration
|
||||
*/
|
||||
class IWeightsUpdater
|
||||
{
|
||||
public:
|
||||
virtual ~IWeightsUpdater() = default;
|
||||
|
||||
/// Calls GradientComputer to update current mini-batch
|
||||
virtual void add_to_batch(
|
||||
std::vector<Float64> & batch_gradient,
|
||||
IGradientComputer & gradient_computer,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
Float64 target,
|
||||
const IColumn ** columns,
|
||||
size_t row_num);
|
||||
|
||||
/// Updates current weights according to the gradient from the last mini-batch
|
||||
virtual void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
|
||||
|
||||
/// Used during the merge of two states
|
||||
virtual void merge(const IWeightsUpdater &, Float64, Float64) {}
|
||||
|
||||
/// Used for serialization when necessary
|
||||
virtual void write(WriteBuffer &) const {}
|
||||
|
||||
/// Used for serialization when necessary
|
||||
virtual void read(ReadBuffer &) {}
|
||||
};
|
||||
|
||||
|
||||
class StochasticGradientDescent : public IWeightsUpdater
|
||||
{
|
||||
public:
|
||||
void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override;
|
||||
};
|
||||
|
||||
|
||||
class Momentum : public IWeightsUpdater
|
||||
{
|
||||
public:
|
||||
Momentum() {}
|
||||
|
||||
Momentum(Float64 alpha) : alpha_(alpha) {}
|
||||
|
||||
void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override;
|
||||
|
||||
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
|
||||
|
||||
void write(WriteBuffer & buf) const override;
|
||||
|
||||
void read(ReadBuffer & buf) override;
|
||||
|
||||
private:
|
||||
Float64 alpha_{0.1};
|
||||
std::vector<Float64> accumulated_gradient;
|
||||
};
|
||||
|
||||
|
||||
class Nesterov : public IWeightsUpdater
|
||||
{
|
||||
public:
|
||||
Nesterov() {}
|
||||
|
||||
Nesterov(Float64 alpha) : alpha_(alpha) {}
|
||||
|
||||
void add_to_batch(
|
||||
std::vector<Float64> & batch_gradient,
|
||||
IGradientComputer & gradient_computer,
|
||||
const std::vector<Float64> & weights,
|
||||
Float64 bias,
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
Float64 target,
|
||||
const IColumn ** columns,
|
||||
size_t row_num) override;
|
||||
|
||||
void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override;
|
||||
|
||||
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
|
||||
|
||||
void write(WriteBuffer & buf) const override;
|
||||
|
||||
void read(ReadBuffer & buf) override;
|
||||
|
||||
private:
|
||||
Float64 alpha_{0.1};
|
||||
std::vector<Float64> accumulated_gradient;
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* LinearModelData is a class which manages current state of learning
|
||||
*/
|
||||
class LinearModelData
|
||||
{
|
||||
public:
|
||||
LinearModelData() {}
|
||||
|
||||
LinearModelData(
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
UInt32 param_num,
|
||||
UInt32 batch_capacity,
|
||||
std::shared_ptr<IGradientComputer> gradient_computer,
|
||||
std::shared_ptr<IWeightsUpdater> weights_updater);
|
||||
|
||||
void add(const IColumn ** columns, size_t row_num);
|
||||
|
||||
void merge(const LinearModelData & rhs);
|
||||
|
||||
void write(WriteBuffer & buf) const;
|
||||
|
||||
void read(ReadBuffer & buf);
|
||||
|
||||
void
|
||||
predict(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const;
|
||||
|
||||
private:
|
||||
std::vector<Float64> weights;
|
||||
Float64 bias{0.0};
|
||||
|
||||
Float64 learning_rate;
|
||||
Float64 l2_reg_coef;
|
||||
UInt32 batch_capacity;
|
||||
|
||||
UInt32 iter_num = 0;
|
||||
std::vector<Float64> gradient_batch;
|
||||
UInt32 batch_size;
|
||||
|
||||
std::shared_ptr<IGradientComputer> gradient_computer;
|
||||
std::shared_ptr<IWeightsUpdater> weights_updater;
|
||||
|
||||
/**
|
||||
* The function is called when we want to flush current batch and update our weights
|
||||
*/
|
||||
void update_state();
|
||||
};
|
||||
|
||||
|
||||
template <
|
||||
/// Implemented Machine Learning method
|
||||
typename Data,
|
||||
/// Name of the method
|
||||
typename Name>
|
||||
class AggregateFunctionMLMethod final : public IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>
|
||||
{
|
||||
public:
|
||||
String getName() const override { return Name::name; }
|
||||
|
||||
explicit AggregateFunctionMLMethod(
|
||||
UInt32 param_num,
|
||||
std::shared_ptr<IGradientComputer> gradient_computer,
|
||||
std::shared_ptr<IWeightsUpdater> weights_updater,
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
UInt32 batch_size,
|
||||
const DataTypes & arguments_types,
|
||||
const Array & params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>(arguments_types, params)
|
||||
, param_num(param_num)
|
||||
, learning_rate(learning_rate)
|
||||
, l2_reg_coef(l2_reg_coef)
|
||||
, batch_size(batch_size)
|
||||
, gradient_computer(std::move(gradient_computer))
|
||||
, weights_updater(std::move(weights_updater))
|
||||
{
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
|
||||
|
||||
void create(AggregateDataPtr place) const override
|
||||
{
|
||||
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, weights_updater);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
this->data(place).add(columns, row_num);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); }
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(place).write(buf); }
|
||||
|
||||
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override { this->data(place).read(buf); }
|
||||
|
||||
void predictValues(
|
||||
ConstAggregateDataPtr place, IColumn & to, Block & block, const ColumnNumbers & arguments, const Context & context) const override
|
||||
{
|
||||
if (arguments.size() != param_num + 1)
|
||||
throw Exception(
|
||||
"Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size())
|
||||
+ ". Required: " + std::to_string(param_num + 1),
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
auto & column = dynamic_cast<ColumnVector<Float64> &>(to);
|
||||
|
||||
this->data(place).predict(column.getData(), block, arguments, context);
|
||||
}
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
std::ignore = place;
|
||||
std::ignore = to;
|
||||
throw std::runtime_error("not implemented");
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||
|
||||
private:
|
||||
UInt32 param_num;
|
||||
Float64 learning_rate;
|
||||
Float64 l2_reg_coef;
|
||||
UInt32 batch_size;
|
||||
std::shared_ptr<IGradientComputer> gradient_computer;
|
||||
std::shared_ptr<IWeightsUpdater> weights_updater;
|
||||
};
|
||||
|
||||
struct NameLinearRegression
|
||||
{
|
||||
static constexpr auto name = "LinearRegression";
|
||||
};
|
||||
struct NameLogisticRegression
|
||||
{
|
||||
static constexpr auto name = "LogisticRegression";
|
||||
};
|
||||
}
|
30
dbms/src/AggregateFunctions/AggregateFunctionTSGroupSum.cpp
Normal file
30
dbms/src/AggregateFunctions/AggregateFunctionTSGroupSum.cpp
Normal file
@ -0,0 +1,30 @@
|
||||
#include "AggregateFunctionTSGroupSum.h"
|
||||
#include "AggregateFunctionFactory.h"
|
||||
#include "FactoryHelpers.h"
|
||||
#include "Helpers.h"
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace
|
||||
{
|
||||
template <bool rate>
|
||||
AggregateFunctionPtr createAggregateFunctionTSgroupSum(const std::string & name, const DataTypes & arguments, const Array & params)
|
||||
{
|
||||
assertNoParameters(name, params);
|
||||
|
||||
if (arguments.size() < 3)
|
||||
throw Exception("Not enough event arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
return std::make_shared<AggregateFunctionTSgroupSum<rate>>(arguments);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionTSgroupSum(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("TSgroupSum", createAggregateFunctionTSgroupSum<false>, AggregateFunctionFactory::CaseInsensitive);
|
||||
factory.registerFunction("TSgroupRateSum", createAggregateFunctionTSgroupSum<true>, AggregateFunctionFactory::CaseInsensitive);
|
||||
}
|
||||
|
||||
}
|
287
dbms/src/AggregateFunctions/AggregateFunctionTSGroupSum.h
Normal file
287
dbms/src/AggregateFunctions/AggregateFunctionTSGroupSum.h
Normal file
@ -0,0 +1,287 @@
|
||||
#pragma once
|
||||
|
||||
#include <bitset>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <queue>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Common/ArenaAllocator.h>
|
||||
#include <ext/range.h>
|
||||
#include "IAggregateFunction.h"
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION;
|
||||
}
|
||||
template <bool rate>
|
||||
struct AggregateFunctionTSgroupSumData
|
||||
{
|
||||
using DataPoint = std::pair<Int64, Float64>;
|
||||
struct Points
|
||||
{
|
||||
using Dps = std::queue<DataPoint>;
|
||||
Dps dps;
|
||||
void add(Int64 t, Float64 v)
|
||||
{
|
||||
dps.push(std::make_pair(t, v));
|
||||
if (dps.size() > 2)
|
||||
dps.pop();
|
||||
}
|
||||
Float64 getval(Int64 t)
|
||||
{
|
||||
Int64 t1, t2;
|
||||
Float64 v1, v2;
|
||||
if (rate)
|
||||
{
|
||||
if (dps.size() < 2)
|
||||
return 0;
|
||||
t1 = dps.back().first;
|
||||
t2 = dps.front().first;
|
||||
v1 = dps.back().second;
|
||||
v2 = dps.front().second;
|
||||
return (v1 - v2) / Float64(t1 - t2);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (dps.size() == 1 && t == dps.front().first)
|
||||
return dps.front().second;
|
||||
t1 = dps.back().first;
|
||||
t2 = dps.front().first;
|
||||
v1 = dps.back().second;
|
||||
v2 = dps.front().second;
|
||||
return v2 + ((v1 - v2) * Float64(t - t2)) / Float64(t1 - t2);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr size_t bytes_on_stack = 128;
|
||||
typedef std::map<UInt64, Points> Series;
|
||||
typedef PODArray<DataPoint, bytes_on_stack, AllocatorWithStackMemory<Allocator<false>, bytes_on_stack>> AggSeries;
|
||||
Series ss;
|
||||
AggSeries result;
|
||||
|
||||
void add(UInt64 uid, Int64 t, Float64 v)
|
||||
{ //suppose t is coming asc
|
||||
typename Series::iterator it_ss;
|
||||
if (ss.count(uid) == 0)
|
||||
{ //time series not exist, insert new one
|
||||
Points tmp;
|
||||
tmp.add(t, v);
|
||||
ss.emplace(uid, tmp);
|
||||
it_ss = ss.find(uid);
|
||||
}
|
||||
else
|
||||
{
|
||||
it_ss = ss.find(uid);
|
||||
it_ss->second.add(t, v);
|
||||
}
|
||||
if (result.size() > 0 && t < result.back().first)
|
||||
throw Exception{"TSgroupSum or TSgroupRateSum must order by timestamp asc!!!", ErrorCodes::LOGICAL_ERROR};
|
||||
if (result.size() > 0 && t == result.back().first)
|
||||
{
|
||||
//do not add new point
|
||||
if (rate)
|
||||
result.back().second += it_ss->second.getval(t);
|
||||
else
|
||||
result.back().second += v;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (rate)
|
||||
result.emplace_back(std::make_pair(t, it_ss->second.getval(t)));
|
||||
else
|
||||
result.emplace_back(std::make_pair(t, v));
|
||||
}
|
||||
size_t i = result.size() - 1;
|
||||
//reverse find out the index of timestamp that more than previous timestamp of t
|
||||
while (result[i].first > it_ss->second.dps.front().first && i >= 0)
|
||||
i--;
|
||||
|
||||
i++;
|
||||
while (i < result.size() - 1)
|
||||
{
|
||||
result[i].second += it_ss->second.getval(result[i].first);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionTSgroupSumData & other)
|
||||
{
|
||||
//if ts has overlap, then aggregate two series by interpolation;
|
||||
AggSeries tmp;
|
||||
tmp.reserve(other.result.size() + result.size());
|
||||
size_t i = 0, j = 0;
|
||||
Int64 t1, t2;
|
||||
Float64 v1, v2;
|
||||
while (i < result.size() && j < other.result.size())
|
||||
{
|
||||
if (result[i].first < other.result[j].first)
|
||||
{
|
||||
if (j == 0)
|
||||
{
|
||||
tmp.emplace_back(result[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
t1 = other.result[j].first;
|
||||
t2 = other.result[j - 1].first;
|
||||
v1 = other.result[j].second;
|
||||
v2 = other.result[j - 1].second;
|
||||
Float64 value = result[i].second + v2 + (v1 - v2) * (Float64(result[i].first - t2)) / Float64(t1 - t2);
|
||||
tmp.emplace_back(std::make_pair(result[i].first, value));
|
||||
}
|
||||
i++;
|
||||
}
|
||||
else if (result[i].first > other.result[j].first)
|
||||
{
|
||||
if (i == 0)
|
||||
{
|
||||
tmp.emplace_back(other.result[j]);
|
||||
}
|
||||
else
|
||||
{
|
||||
t1 = result[i].first;
|
||||
t2 = result[i - 1].first;
|
||||
v1 = result[i].second;
|
||||
v2 = result[i - 1].second;
|
||||
Float64 value = other.result[j].second + v2 + (v1 - v2) * (Float64(other.result[j].first - t2)) / Float64(t1 - t2);
|
||||
tmp.emplace_back(std::make_pair(other.result[j].first, value));
|
||||
}
|
||||
j++;
|
||||
}
|
||||
else
|
||||
{
|
||||
tmp.emplace_back(std::make_pair(result[i].first, result[i].second + other.result[j].second));
|
||||
i++;
|
||||
j++;
|
||||
}
|
||||
}
|
||||
while (i < result.size())
|
||||
{
|
||||
tmp.emplace_back(result[i]);
|
||||
i++;
|
||||
}
|
||||
while (j < other.result.size())
|
||||
{
|
||||
tmp.push_back(other.result[j]);
|
||||
j++;
|
||||
}
|
||||
swap(result, tmp);
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
size_t size = result.size();
|
||||
writeVarUInt(size, buf);
|
||||
buf.write(reinterpret_cast<const char *>(result.data()), sizeof(result[0]));
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
result.resize(size);
|
||||
buf.read(reinterpret_cast<char *>(result.data()), size * sizeof(result[0]));
|
||||
}
|
||||
};
|
||||
template <bool rate>
|
||||
class AggregateFunctionTSgroupSum final
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionTSgroupSumData<rate>, AggregateFunctionTSgroupSum<rate>>
|
||||
{
|
||||
private:
|
||||
public:
|
||||
String getName() const override { return rate ? "TSgroupRateSum" : "TSgroupSum"; }
|
||||
|
||||
AggregateFunctionTSgroupSum(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionTSgroupSumData<rate>, AggregateFunctionTSgroupSum<rate>>(arguments, {})
|
||||
{
|
||||
if (!WhichDataType(arguments[0].get()).isUInt64())
|
||||
throw Exception{"Illegal type " + arguments[0].get()->getName() + " of argument 1 of aggregate function " + getName()
|
||||
+ ", must be UInt64",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
if (!WhichDataType(arguments[1].get()).isInt64())
|
||||
throw Exception{"Illegal type " + arguments[1].get()->getName() + " of argument 2 of aggregate function " + getName()
|
||||
+ ", must be Int64",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
if (!WhichDataType(arguments[2].get()).isFloat64())
|
||||
throw Exception{"Illegal type " + arguments[2].get()->getName() + " of argument 3 of aggregate function " + getName()
|
||||
+ ", must be Float64",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
auto datatypes = std::vector<DataTypePtr>();
|
||||
datatypes.push_back(std::make_shared<DataTypeInt64>());
|
||||
datatypes.push_back(std::make_shared<DataTypeFloat64>());
|
||||
|
||||
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeTuple>(datatypes));
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
auto uid = static_cast<const ColumnVector<UInt64> *>(columns[0])->getData()[row_num];
|
||||
auto ts = static_cast<const ColumnVector<Int64> *>(columns[1])->getData()[row_num];
|
||||
auto val = static_cast<const ColumnVector<Float64> *>(columns[2])->getData()[row_num];
|
||||
if (uid && ts && val)
|
||||
{
|
||||
this->data(place).add(uid, ts, val);
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); }
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(place).serialize(buf); }
|
||||
|
||||
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override { this->data(place).deserialize(buf); }
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
const auto & value = this->data(place).result;
|
||||
size_t size = value.size();
|
||||
|
||||
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
size_t old_size = offsets_to.back();
|
||||
|
||||
offsets_to.push_back(offsets_to.back() + size);
|
||||
|
||||
if (size)
|
||||
{
|
||||
typename ColumnInt64::Container & ts_to
|
||||
= static_cast<ColumnInt64 &>(static_cast<ColumnTuple &>(arr_to.getData()).getColumn(0)).getData();
|
||||
typename ColumnFloat64::Container & val_to
|
||||
= static_cast<ColumnFloat64 &>(static_cast<ColumnTuple &>(arr_to.getData()).getColumn(1)).getData();
|
||||
ts_to.reserve(old_size + size);
|
||||
val_to.reserve(old_size + size);
|
||||
size_t i = 0;
|
||||
while (i < this->data(place).result.size())
|
||||
{
|
||||
ts_to.push_back(this->data(place).result[i].first);
|
||||
val_to.push_back(this->data(place).result[i].second);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||
};
|
||||
}
|
@ -7,6 +7,8 @@
|
||||
|
||||
#include <Core/Types.h>
|
||||
#include <Core/Field.h>
|
||||
#include <Core/ColumnNumbers.h>
|
||||
#include <Core/Block.h>
|
||||
#include <Common/Exception.h>
|
||||
|
||||
|
||||
@ -92,6 +94,13 @@ public:
|
||||
/// Inserts results into a column.
|
||||
virtual void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const = 0;
|
||||
|
||||
/// This function is used for machine learning methods
|
||||
virtual void predictValues(ConstAggregateDataPtr /* place */, IColumn & /*to*/,
|
||||
Block & /*block*/, const ColumnNumbers & /*arguments*/, const Context & /*context*/) const
|
||||
{
|
||||
throw Exception("Method predictValues is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
/** Returns true for aggregate functions of type -State.
|
||||
* They are executed as other aggregate functions, but not finalized (return an aggregation state that can be combined with another).
|
||||
*/
|
||||
@ -149,7 +158,6 @@ protected:
|
||||
static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data*>(place); }
|
||||
|
||||
public:
|
||||
|
||||
IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_)
|
||||
: IAggregateFunctionHelper<Derived>(argument_types_, parameters_) {}
|
||||
|
||||
|
@ -28,6 +28,7 @@ void registerAggregateFunctionTopK(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsBitmap(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory &);
|
||||
|
||||
@ -40,7 +41,7 @@ void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory
|
||||
|
||||
void registerAggregateFunctionHistogram(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionRetention(AggregateFunctionFactory & factory);
|
||||
|
||||
void registerAggregateFunctionTSgroupSum(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctions()
|
||||
{
|
||||
{
|
||||
@ -69,6 +70,8 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionsMaxIntersections(factory);
|
||||
registerAggregateFunctionHistogram(factory);
|
||||
registerAggregateFunctionRetention(factory);
|
||||
registerAggregateFunctionTSgroupSum(factory);
|
||||
registerAggregateFunctionMLMethod(factory);
|
||||
registerAggregateFunctionEntropy(factory);
|
||||
registerAggregateFunctionLeastSqr(factory);
|
||||
}
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Common/Arena.h>
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionMLMethod.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -18,6 +19,7 @@ namespace ErrorCodes
|
||||
{
|
||||
extern const int PARAMETER_OUT_OF_BOUND;
|
||||
extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
|
||||
@ -33,6 +35,25 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
|
||||
arenas.push_back(arena_);
|
||||
}
|
||||
|
||||
/// This function is used in convertToValues() and predictValues()
|
||||
/// and is written here to avoid repetitions
|
||||
bool ColumnAggregateFunction::tryFinalizeAggregateFunction(MutableColumnPtr *res_) const
|
||||
{
|
||||
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
|
||||
{
|
||||
auto res = createView();
|
||||
res->set(function_state->getNestedFunction());
|
||||
res->data.assign(data.begin(), data.end());
|
||||
*res_ = std::move(res);
|
||||
return true;
|
||||
}
|
||||
|
||||
MutableColumnPtr res = func->getReturnType()->createColumn();
|
||||
res->reserve(data.size());
|
||||
*res_ = std::move(res);
|
||||
return false;
|
||||
}
|
||||
|
||||
MutableColumnPtr ColumnAggregateFunction::convertToValues() const
|
||||
{
|
||||
/** If the aggregate function returns an unfinalized/unfinished state,
|
||||
@ -65,23 +86,46 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
|
||||
* AggregateFunction(quantileTiming(0.5), UInt64)
|
||||
* into UInt16 - already finished result of `quantileTiming`.
|
||||
*/
|
||||
if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
|
||||
|
||||
/** Convertion function is used in convertToValues and predictValues
|
||||
* in the similar part of both functions
|
||||
*/
|
||||
|
||||
MutableColumnPtr res;
|
||||
if (tryFinalizeAggregateFunction(&res))
|
||||
{
|
||||
auto res = createView();
|
||||
res->set(function_state->getNestedFunction());
|
||||
res->data.assign(data.begin(), data.end());
|
||||
return res;
|
||||
}
|
||||
|
||||
MutableColumnPtr res = func->getReturnType()->createColumn();
|
||||
res->reserve(data.size());
|
||||
|
||||
for (auto val : data)
|
||||
func->insertResultInto(val, *res);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const
|
||||
{
|
||||
MutableColumnPtr res;
|
||||
tryFinalizeAggregateFunction(&res);
|
||||
|
||||
auto ML_function = func.get();
|
||||
if (ML_function)
|
||||
{
|
||||
size_t row_num = 0;
|
||||
for (auto val : data)
|
||||
{
|
||||
ML_function->predictValues(val, *res, block, arguments, context);
|
||||
++row_num;
|
||||
}
|
||||
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception("Illegal aggregate function is passed",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
void ColumnAggregateFunction::ensureOwnership()
|
||||
{
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <IO/WriteBuffer.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -117,6 +118,9 @@ public:
|
||||
std::string getName() const override { return "AggregateFunction(" + func->getName() + ")"; }
|
||||
const char * getFamilyName() const override { return "AggregateFunction"; }
|
||||
|
||||
bool tryFinalizeAggregateFunction(MutableColumnPtr* res_) const;
|
||||
MutableColumnPtr predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const;
|
||||
|
||||
size_t size() const override
|
||||
{
|
||||
return getData().size();
|
||||
|
@ -571,41 +571,41 @@ ConfigProcessor::LoadedConfig ConfigProcessor::loadConfigWithZooKeeperIncludes(
|
||||
|
||||
void ConfigProcessor::savePreprocessedConfig(const LoadedConfig & loaded_config, std::string preprocessed_dir)
|
||||
{
|
||||
if (preprocessed_path.empty())
|
||||
try
|
||||
{
|
||||
auto new_path = loaded_config.config_path;
|
||||
if (new_path.substr(0, main_config_path.size()) == main_config_path)
|
||||
new_path.replace(0, main_config_path.size(), "");
|
||||
std::replace(new_path.begin(), new_path.end(), '/', '_');
|
||||
|
||||
if (preprocessed_dir.empty())
|
||||
if (preprocessed_path.empty())
|
||||
{
|
||||
if (!loaded_config.configuration->has("path"))
|
||||
auto new_path = loaded_config.config_path;
|
||||
if (new_path.substr(0, main_config_path.size()) == main_config_path)
|
||||
new_path.replace(0, main_config_path.size(), "");
|
||||
std::replace(new_path.begin(), new_path.end(), '/', '_');
|
||||
|
||||
if (preprocessed_dir.empty())
|
||||
{
|
||||
// Will use current directory
|
||||
auto parent_path = Poco::Path(loaded_config.config_path).makeParent();
|
||||
preprocessed_dir = parent_path.toString();
|
||||
Poco::Path poco_new_path(new_path);
|
||||
poco_new_path.setBaseName(poco_new_path.getBaseName() + PREPROCESSED_SUFFIX);
|
||||
new_path = poco_new_path.toString();
|
||||
if (!loaded_config.configuration->has("path"))
|
||||
{
|
||||
// Will use current directory
|
||||
auto parent_path = Poco::Path(loaded_config.config_path).makeParent();
|
||||
preprocessed_dir = parent_path.toString();
|
||||
Poco::Path poco_new_path(new_path);
|
||||
poco_new_path.setBaseName(poco_new_path.getBaseName() + PREPROCESSED_SUFFIX);
|
||||
new_path = poco_new_path.toString();
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocessed_dir = loaded_config.configuration->getString("path") + "/preprocessed_configs/";
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocessed_dir = loaded_config.configuration->getString("path") + "/preprocessed_configs/";
|
||||
preprocessed_dir += "/preprocessed_configs/";
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocessed_dir += "/preprocessed_configs/";
|
||||
}
|
||||
|
||||
preprocessed_path = preprocessed_dir + new_path;
|
||||
auto preprocessed_path_parent = Poco::Path(preprocessed_path).makeParent();
|
||||
if (!preprocessed_path_parent.toString().empty())
|
||||
Poco::File(preprocessed_path_parent).createDirectories();
|
||||
}
|
||||
try
|
||||
{
|
||||
preprocessed_path = preprocessed_dir + new_path;
|
||||
auto preprocessed_path_parent = Poco::Path(preprocessed_path).makeParent();
|
||||
if (!preprocessed_path_parent.toString().empty())
|
||||
Poco::File(preprocessed_path_parent).createDirectories();
|
||||
}
|
||||
DOMWriter().writeNode(preprocessed_path, loaded_config.preprocessed_xml);
|
||||
}
|
||||
catch (Poco::Exception & e)
|
||||
|
@ -1,7 +1,6 @@
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/OptimizedRegularExpression.h>
|
||||
|
||||
|
||||
#define MIN_LENGTH_FOR_STRSTR 3
|
||||
#define MAX_SUBPATTERNS 5
|
||||
|
||||
@ -211,20 +210,18 @@ void OptimizedRegularExpressionImpl<thread_safe>::analyze(
|
||||
{
|
||||
if (!has_alternative_on_depth_0)
|
||||
{
|
||||
/** We choose the non-alternative substring of the maximum length, among the prefixes,
|
||||
* or a non-alternative substring of maximum length.
|
||||
*/
|
||||
/// We choose the non-alternative substring of the maximum length for first search.
|
||||
|
||||
/// Tuning for typical usage domain
|
||||
auto tuning_strings_condition = [](const std::string & str)
|
||||
{
|
||||
return str != "://" && str != "http://" && str != "www" && str != "Windows ";
|
||||
};
|
||||
size_t max_length = 0;
|
||||
Substrings::const_iterator candidate_it = trivial_substrings.begin();
|
||||
for (Substrings::const_iterator it = trivial_substrings.begin(); it != trivial_substrings.end(); ++it)
|
||||
{
|
||||
if (((it->second == 0 && candidate_it->second != 0)
|
||||
|| ((it->second == 0) == (candidate_it->second == 0) && it->first.size() > max_length))
|
||||
/// Tuning for typical usage domain
|
||||
&& (it->first.size() > strlen("://") || strncmp(it->first.data(), "://", strlen("://")))
|
||||
&& (it->first.size() > strlen("http://") || strncmp(it->first.data(), "http", strlen("http")))
|
||||
&& (it->first.size() > strlen("www.") || strncmp(it->first.data(), "www", strlen("www")))
|
||||
&& (it->first.size() > strlen("Windows ") || strncmp(it->first.data(), "Windows ", strlen("Windows "))))
|
||||
if (it->first.size() > max_length && tuning_strings_condition(it->first))
|
||||
{
|
||||
max_length = it->first.size();
|
||||
candidate_it = it;
|
||||
|
@ -122,6 +122,9 @@ RWLockImpl::LockHolder RWLockImpl::getLock(RWLockImpl::Type type, const String &
|
||||
|
||||
LockHolder res(new LockHolderImpl(shared_from_this(), it_group, it_client));
|
||||
|
||||
/// Wait a notification until we will be the only in the group.
|
||||
it_group->cv.wait(lock, [&] () { return it_group == queue.begin(); });
|
||||
|
||||
/// Insert myself (weak_ptr to the holder) to threads set to implement recursive lock
|
||||
thread_to_holder.emplace(this_thread_id, res);
|
||||
res->thread_id = this_thread_id;
|
||||
@ -130,17 +133,6 @@ RWLockImpl::LockHolder RWLockImpl::getLock(RWLockImpl::Type type, const String &
|
||||
query_id_to_holder.emplace(query_id, res);
|
||||
res->query_id = query_id;
|
||||
|
||||
/// We are first, we should not wait anything
|
||||
/// If we are not the first client in the group, a notification could be already sent
|
||||
if (it_group == queue.begin())
|
||||
{
|
||||
finalize_metrics();
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Wait a notification
|
||||
it_group->cv.wait(lock, [&] () { return it_group == queue.begin(); });
|
||||
|
||||
finalize_metrics();
|
||||
return res;
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <Core/Types.h>
|
||||
#include <boost/core/noncopyable.hpp>
|
||||
|
||||
#include <list>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
|
@ -43,70 +43,6 @@ struct RadixSortMallocAllocator
|
||||
};
|
||||
|
||||
|
||||
template <typename KeyBits>
|
||||
struct RadixSortIdentityTransform
|
||||
{
|
||||
static constexpr bool transform_is_simple = true;
|
||||
|
||||
static KeyBits forward(KeyBits x) { return x; }
|
||||
static KeyBits backward(KeyBits x) { return x; }
|
||||
};
|
||||
|
||||
|
||||
|
||||
template <typename TElement>
|
||||
struct RadixSortUIntTraits
|
||||
{
|
||||
using Element = TElement;
|
||||
using Key = Element;
|
||||
using CountType = uint32_t;
|
||||
using KeyBits = Key;
|
||||
|
||||
static constexpr size_t PART_SIZE_BITS = 8;
|
||||
|
||||
using Transform = RadixSortIdentityTransform<KeyBits>;
|
||||
using Allocator = RadixSortMallocAllocator;
|
||||
|
||||
static Key & extractKey(Element & elem) { return elem; }
|
||||
|
||||
static bool less(Key x, Key y)
|
||||
{
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename KeyBits>
|
||||
struct RadixSortSignedTransform
|
||||
{
|
||||
static constexpr bool transform_is_simple = true;
|
||||
|
||||
static KeyBits forward(KeyBits x) { return x ^ (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)); }
|
||||
static KeyBits backward(KeyBits x) { return x ^ (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)); }
|
||||
};
|
||||
|
||||
template <typename TElement>
|
||||
struct RadixSortIntTraits
|
||||
{
|
||||
using Element = TElement;
|
||||
using Key = Element;
|
||||
using CountType = uint32_t;
|
||||
using KeyBits = std::make_unsigned_t<Key>;
|
||||
|
||||
static constexpr size_t PART_SIZE_BITS = 8;
|
||||
|
||||
using Transform = RadixSortSignedTransform<KeyBits>;
|
||||
using Allocator = RadixSortMallocAllocator;
|
||||
|
||||
static Key & extractKey(Element & elem) { return elem; }
|
||||
|
||||
static bool less(Key x, Key y)
|
||||
{
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/** A transformation that transforms the bit representation of a key into an unsigned integer number,
|
||||
* that the order relation over the keys will match the order relation over the obtained unsigned numbers.
|
||||
* For floats this conversion does the following:
|
||||
@ -154,7 +90,73 @@ struct RadixSortFloatTraits
|
||||
/// The function to get the key from an array element.
|
||||
static Key & extractKey(Element & elem) { return elem; }
|
||||
|
||||
// TODO: Correct handling of NaNs, NULLs, etc
|
||||
/// Used when fallback to comparison based sorting is needed.
|
||||
/// TODO: Correct handling of NaNs, NULLs, etc
|
||||
static bool less(Key x, Key y)
|
||||
{
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename KeyBits>
|
||||
struct RadixSortIdentityTransform
|
||||
{
|
||||
static constexpr bool transform_is_simple = true;
|
||||
|
||||
static KeyBits forward(KeyBits x) { return x; }
|
||||
static KeyBits backward(KeyBits x) { return x; }
|
||||
};
|
||||
|
||||
|
||||
|
||||
template <typename TElement>
|
||||
struct RadixSortUIntTraits
|
||||
{
|
||||
using Element = TElement;
|
||||
using Key = Element;
|
||||
using CountType = uint32_t;
|
||||
using KeyBits = Key;
|
||||
|
||||
static constexpr size_t PART_SIZE_BITS = 8;
|
||||
|
||||
using Transform = RadixSortIdentityTransform<KeyBits>;
|
||||
using Allocator = RadixSortMallocAllocator;
|
||||
|
||||
static Key & extractKey(Element & elem) { return elem; }
|
||||
|
||||
static bool less(Key x, Key y)
|
||||
{
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename KeyBits>
|
||||
struct RadixSortSignedTransform
|
||||
{
|
||||
static constexpr bool transform_is_simple = true;
|
||||
|
||||
static KeyBits forward(KeyBits x) { return x ^ (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)); }
|
||||
static KeyBits backward(KeyBits x) { return x ^ (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)); }
|
||||
};
|
||||
|
||||
|
||||
template <typename TElement>
|
||||
struct RadixSortIntTraits
|
||||
{
|
||||
using Element = TElement;
|
||||
using Key = Element;
|
||||
using CountType = uint32_t;
|
||||
using KeyBits = std::make_unsigned_t<Key>;
|
||||
|
||||
static constexpr size_t PART_SIZE_BITS = 8;
|
||||
|
||||
using Transform = RadixSortSignedTransform<KeyBits>;
|
||||
using Allocator = RadixSortMallocAllocator;
|
||||
|
||||
static Key & extractKey(Element & elem) { return elem; }
|
||||
|
||||
static bool less(Key x, Key y)
|
||||
{
|
||||
return x < y;
|
||||
@ -220,9 +222,9 @@ private:
|
||||
* Puts elements to buckets based on PASS-th digit, then recursively calls insertion sort or itself on the buckets
|
||||
*/
|
||||
template <size_t PASS>
|
||||
static inline void radixSortMSDInternal(Element *arr, size_t size, size_t limit)
|
||||
static inline void radixSortMSDInternal(Element * arr, size_t size, size_t limit)
|
||||
{
|
||||
Element *last_list[HISTOGRAM_SIZE + 1];
|
||||
Element * last_list[HISTOGRAM_SIZE + 1];
|
||||
Element ** last = last_list + 1;
|
||||
size_t count[HISTOGRAM_SIZE] = {0};
|
||||
|
||||
@ -295,7 +297,7 @@ private:
|
||||
|
||||
// A helper to choose sorting algorithm based on array length
|
||||
template <size_t PASS>
|
||||
static inline void radixSortMSDInternalHelper(Element *arr, size_t size, size_t limit)
|
||||
static inline void radixSortMSDInternalHelper(Element * arr, size_t size, size_t limit)
|
||||
{
|
||||
if (size <= INSERTION_SORT_THRESHOLD)
|
||||
insertionSortInternal(arr, size);
|
||||
@ -304,8 +306,8 @@ private:
|
||||
}
|
||||
|
||||
public:
|
||||
// Least significant digit radix sort (stable)
|
||||
static void executeLSD(Element *arr, size_t size)
|
||||
/// Least significant digit radix sort (stable)
|
||||
static void executeLSD(Element * arr, size_t size)
|
||||
{
|
||||
/// If the array is smaller than 256, then it is better to use another algorithm.
|
||||
|
||||
@ -400,7 +402,7 @@ public:
|
||||
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*/
|
||||
static void executeMSD(Element *arr, size_t size, size_t limit)
|
||||
static void executeMSD(Element * arr, size_t size, size_t limit)
|
||||
{
|
||||
limit = std::min(limit, size);
|
||||
radixSortMSDInternalHelper<NUM_PASSES - 1>(arr, size, limit);
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <Common/UTF8Helpers.h>
|
||||
#include <Core/Defines.h>
|
||||
#include <ext/range.h>
|
||||
#include <Poco/UTF8Encoding.h>
|
||||
#include <Poco/Unicode.h>
|
||||
|
@ -25,6 +25,7 @@
|
||||
#cmakedefine01 USE_BROTLI
|
||||
#cmakedefine01 USE_SSL
|
||||
#cmakedefine01 USE_HYPERSCAN
|
||||
#cmakedefine01 USE_SIMDJSON
|
||||
#cmakedefine01 USE_LFALLOC
|
||||
#cmakedefine01 USE_LFALLOC_RANDOM_HINT
|
||||
|
||||
|
@ -1,14 +1,18 @@
|
||||
include(${ClickHouse_SOURCE_DIR}/cmake/dbms_glob_sources.cmake)
|
||||
add_headers_and_sources(clickhouse_compression .)
|
||||
add_library(clickhouse_compression ${clickhouse_compression_headers} ${clickhouse_compression_sources})
|
||||
target_link_libraries(clickhouse_compression PRIVATE clickhouse_parsers clickhouse_common_io ${ZSTD_LIBRARY} ${LZ4_LIBRARY} ${CITYHASH_LIBRARIES})
|
||||
target_link_libraries(clickhouse_compression PRIVATE clickhouse_parsers clickhouse_common_io ${LZ4_LIBRARY} ${CITYHASH_LIBRARIES})
|
||||
if(ZSTD_LIBRARY)
|
||||
target_link_libraries(clickhouse_compression PRIVATE ${ZSTD_LIBRARY})
|
||||
endif()
|
||||
|
||||
target_include_directories(clickhouse_compression PUBLIC ${DBMS_INCLUDE_DIR})
|
||||
target_include_directories(clickhouse_compression SYSTEM PUBLIC ${PCG_RANDOM_INCLUDE_DIR})
|
||||
|
||||
if (NOT USE_INTERNAL_LZ4_LIBRARY)
|
||||
target_include_directories(clickhouse_compression SYSTEM BEFORE PRIVATE ${LZ4_INCLUDE_DIR})
|
||||
endif ()
|
||||
if (NOT USE_INTERNAL_ZSTD_LIBRARY)
|
||||
if (NOT USE_INTERNAL_ZSTD_LIBRARY AND ZSTD_INCLUDE_DIR)
|
||||
target_include_directories(clickhouse_compression SYSTEM BEFORE PRIVATE ${ZSTD_INCLUDE_DIR})
|
||||
endif ()
|
||||
|
||||
|
@ -123,3 +123,7 @@
|
||||
#else
|
||||
#define OPTIMIZE(x)
|
||||
#endif
|
||||
|
||||
/// This number is only used for distributed version compatible.
|
||||
/// It could be any magic number.
|
||||
#define DBMS_DISTRIBUTED_SENDS_MAGIC_NUMBER 0xCAFECABE
|
||||
|
@ -109,5 +109,4 @@ void Settings::addProgramOptions(boost::program_options::options_description & o
|
||||
Settings::getDescription(index).data)));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -50,6 +50,7 @@ struct Settings : public SettingsCollection<Settings>
|
||||
M(SettingUInt64, min_insert_block_size_rows, DEFAULT_INSERT_BLOCK_SIZE, "Squash blocks passed to INSERT query to specified size in rows, if blocks are not big enough.") \
|
||||
M(SettingUInt64, min_insert_block_size_bytes, (DEFAULT_INSERT_BLOCK_SIZE * 256), "Squash blocks passed to INSERT query to specified size in bytes, if blocks are not big enough.") \
|
||||
M(SettingMaxThreads, max_threads, 0, "The maximum number of threads to execute the request. By default, it is determined automatically.") \
|
||||
M(SettingMaxThreads, max_alter_threads, 0, "The maximum number of threads to execute the ALTER requests. By default, it is determined automatically.") \
|
||||
M(SettingUInt64, max_read_buffer_size, DBMS_DEFAULT_BUFFER_SIZE, "The maximum size of the buffer to read from the filesystem.") \
|
||||
M(SettingUInt64, max_distributed_connections, 1024, "The maximum number of connections for distributed processing of one query (should be greater than max_threads).") \
|
||||
M(SettingUInt64, max_query_size, 262144, "Which part of the query can be read into RAM for parsing (the remaining data for INSERT, if any, is read later)") \
|
||||
@ -205,6 +206,7 @@ struct Settings : public SettingsCollection<Settings>
|
||||
M(SettingUInt64, insert_distributed_timeout, 0, "Timeout for insert query into distributed. Setting is used only with insert_distributed_sync enabled. Zero value means no timeout.") \
|
||||
M(SettingInt64, distributed_ddl_task_timeout, 180, "Timeout for DDL query responses from all hosts in cluster. Negative value means infinite.") \
|
||||
M(SettingMilliseconds, stream_flush_interval_ms, 7500, "Timeout for flushing data from streaming storages.") \
|
||||
M(SettingMilliseconds, stream_poll_timeout_ms, 500, "Timeout for polling data from streaming storages.") \
|
||||
M(SettingString, format_schema, "", "Schema identifier (used by schema-based formats)") \
|
||||
M(SettingBool, insert_allow_materialized_columns, 0, "If setting is enabled, Allow materialized columns in INSERT.") \
|
||||
M(SettingSeconds, http_connection_timeout, DEFAULT_HTTP_READ_BUFFER_CONNECTION_TIMEOUT, "HTTP connection timeout.") \
|
||||
|
@ -428,7 +428,7 @@ public:
|
||||
const const_reference & operator *() const { return ref; }
|
||||
const const_reference * operator ->() const { return &ref; }
|
||||
const_iterator & operator ++() { ++ref.member; return *this; }
|
||||
const_iterator & operator ++(int) { const_iterator tmp = *this; ++*this; return tmp; }
|
||||
const_iterator operator ++(int) { const_iterator tmp = *this; ++*this; return tmp; }
|
||||
bool operator ==(const const_iterator & rhs) const { return ref.member == rhs.ref.member && ref.collection == rhs.ref.collection; }
|
||||
bool operator !=(const const_iterator & rhs) const { return !(*this == rhs); }
|
||||
protected:
|
||||
@ -445,7 +445,7 @@ public:
|
||||
reference & operator *() const { return this->ref; }
|
||||
reference * operator ->() const { return &this->ref; }
|
||||
iterator & operator ++() { const_iterator::operator ++(); return *this; }
|
||||
iterator & operator ++(int) { iterator tmp = *this; ++*this; return tmp; }
|
||||
iterator operator ++(int) { iterator tmp = *this; ++*this; return tmp; }
|
||||
};
|
||||
|
||||
/// Returns the number of settings.
|
||||
|
@ -1,6 +1,8 @@
|
||||
#include <DataStreams/AggregatingSortedBlockInputStream.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Common/StringUtils/StringUtils.h>
|
||||
#include <DataTypes/DataTypeAggregateFunction.h>
|
||||
#include <DataTypes/DataTypeCustomSimpleAggregateFunction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -22,7 +24,7 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream(
|
||||
ColumnWithTypeAndName & column = header.safeGetByPosition(i);
|
||||
|
||||
/// We leave only states of aggregate functions.
|
||||
if (!startsWith(column.type->getName(), "AggregateFunction"))
|
||||
if (!dynamic_cast<const DataTypeAggregateFunction *>(column.type.get()) && !dynamic_cast<const DataTypeCustomSimpleAggregateFunction *>(column.type->getCustomName()))
|
||||
{
|
||||
column_numbers_not_to_aggregate.push_back(i);
|
||||
continue;
|
||||
@ -40,7 +42,17 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream(
|
||||
continue;
|
||||
}
|
||||
|
||||
column_numbers_to_aggregate.push_back(i);
|
||||
if (auto simple_aggr = dynamic_cast<const DataTypeCustomSimpleAggregateFunction *>(column.type->getCustomName()))
|
||||
{
|
||||
// simple aggregate function
|
||||
SimpleAggregateDescription desc{simple_aggr->getFunction(), i};
|
||||
columns_to_simple_aggregate.emplace_back(std::move(desc));
|
||||
}
|
||||
else
|
||||
{
|
||||
// standard aggregate function
|
||||
column_numbers_to_aggregate.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,7 +103,11 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
|
||||
|
||||
/// if there are enough rows accumulated and the last one is calculated completely
|
||||
if (key_differs && merged_rows >= max_block_size)
|
||||
{
|
||||
/// Write the simple aggregation result for the previous group.
|
||||
insertSimpleAggregationResult(merged_columns);
|
||||
return;
|
||||
}
|
||||
|
||||
queue.pop();
|
||||
|
||||
@ -110,6 +126,14 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
|
||||
for (auto & column_to_aggregate : columns_to_aggregate)
|
||||
column_to_aggregate->insertDefault();
|
||||
|
||||
/// Write the simple aggregation result for the previous group.
|
||||
if (merged_rows > 0)
|
||||
insertSimpleAggregationResult(merged_columns);
|
||||
|
||||
/// Reset simple aggregation states for next row
|
||||
for (auto & desc : columns_to_simple_aggregate)
|
||||
desc.createState();
|
||||
|
||||
++merged_rows;
|
||||
}
|
||||
|
||||
@ -127,6 +151,9 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
|
||||
}
|
||||
}
|
||||
|
||||
/// Write the simple aggregation result for the previous group.
|
||||
insertSimpleAggregationResult(merged_columns);
|
||||
|
||||
finished = true;
|
||||
}
|
||||
|
||||
@ -138,6 +165,21 @@ void AggregatingSortedBlockInputStream::addRow(SortCursor & cursor)
|
||||
size_t j = column_numbers_to_aggregate[i];
|
||||
columns_to_aggregate[i]->insertMergeFrom(*cursor->all_columns[j], cursor->pos);
|
||||
}
|
||||
|
||||
for (auto & desc : columns_to_simple_aggregate)
|
||||
{
|
||||
auto & col = cursor->all_columns[desc.column_number];
|
||||
desc.add_function(desc.function.get(), desc.state.data(), &col, cursor->pos, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
void AggregatingSortedBlockInputStream::insertSimpleAggregationResult(MutableColumns & merged_columns)
|
||||
{
|
||||
for (auto & desc : columns_to_simple_aggregate)
|
||||
{
|
||||
desc.function->insertResultInto(desc.state.data(), *merged_columns[desc.column_number]);
|
||||
desc.destroyState();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <DataStreams/MergingSortedBlockInputStream.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnAggregateFunction.h>
|
||||
#include <Common/AlignedBuffer.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -38,10 +39,13 @@ private:
|
||||
/// Read finished.
|
||||
bool finished = false;
|
||||
|
||||
struct SimpleAggregateDescription;
|
||||
|
||||
/// Columns with which numbers should be aggregated.
|
||||
ColumnNumbers column_numbers_to_aggregate;
|
||||
ColumnNumbers column_numbers_not_to_aggregate;
|
||||
std::vector<ColumnAggregateFunction *> columns_to_aggregate;
|
||||
std::vector<SimpleAggregateDescription> columns_to_simple_aggregate;
|
||||
|
||||
RowRef current_key; /// The current primary key.
|
||||
RowRef next_key; /// The primary key of the next row.
|
||||
@ -54,6 +58,53 @@ private:
|
||||
/** Extract all states of aggregate functions and merge them with the current group.
|
||||
*/
|
||||
void addRow(SortCursor & cursor);
|
||||
|
||||
/** Insert all values of current row for simple aggregate functions
|
||||
*/
|
||||
void insertSimpleAggregationResult(MutableColumns & merged_columns);
|
||||
|
||||
/// Stores information for aggregation of SimpleAggregateFunction columns
|
||||
struct SimpleAggregateDescription
|
||||
{
|
||||
/// An aggregate function 'anyLast', 'sum'...
|
||||
AggregateFunctionPtr function;
|
||||
IAggregateFunction::AddFunc add_function;
|
||||
size_t column_number;
|
||||
AlignedBuffer state;
|
||||
bool created = false;
|
||||
|
||||
SimpleAggregateDescription(const AggregateFunctionPtr & function_, const size_t column_number_) : function(function_), column_number(column_number_)
|
||||
{
|
||||
add_function = function->getAddressOfAddFunction();
|
||||
state.reset(function->sizeOfData(), function->alignOfData());
|
||||
}
|
||||
|
||||
void createState()
|
||||
{
|
||||
if (created)
|
||||
return;
|
||||
function->create(state.data());
|
||||
created = true;
|
||||
}
|
||||
|
||||
void destroyState()
|
||||
{
|
||||
if (!created)
|
||||
return;
|
||||
function->destroy(state.data());
|
||||
created = false;
|
||||
}
|
||||
|
||||
/// Explicitly destroy aggregation state if the stream is terminated
|
||||
~SimpleAggregateDescription()
|
||||
{
|
||||
destroyState();
|
||||
}
|
||||
|
||||
SimpleAggregateDescription() = default;
|
||||
SimpleAggregateDescription(SimpleAggregateDescription &&) = default;
|
||||
SimpleAggregateDescription(const SimpleAggregateDescription &) = delete;
|
||||
};
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -269,6 +269,11 @@ protected:
|
||||
children.push_back(child);
|
||||
}
|
||||
|
||||
/** Check limits.
|
||||
* But only those that can be checked within each separate stream.
|
||||
*/
|
||||
bool checkTimeLimit();
|
||||
|
||||
private:
|
||||
bool enabled_extremes = false;
|
||||
|
||||
@ -296,10 +301,9 @@ private:
|
||||
|
||||
void updateExtremes(Block & block);
|
||||
|
||||
/** Check limits and quotas.
|
||||
/** Check quotas.
|
||||
* But only those that can be checked within each separate stream.
|
||||
*/
|
||||
bool checkTimeLimit();
|
||||
void checkQuota(Block & block);
|
||||
|
||||
size_t checkDepthImpl(size_t max_depth, size_t level) const;
|
||||
|
@ -7,6 +7,8 @@
|
||||
#include <DataStreams/InputStreamFromASTInsertQuery.h>
|
||||
#include <DataStreams/AddingDefaultsBlockInputStream.h>
|
||||
#include <Storages/ColumnsDescription.h>
|
||||
#include <Storages/IStorage.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -49,10 +51,10 @@ InputStreamFromASTInsertQuery::InputStreamFromASTInsertQuery(
|
||||
|
||||
res_stream = context.getInputFormat(format, *input_buffer_contacenated, header, context.getSettings().max_insert_block_size);
|
||||
|
||||
auto columns_description = ColumnsDescription::loadFromContext(context, ast_insert_query->database, ast_insert_query->table);
|
||||
if (columns_description)
|
||||
if (context.getSettingsRef().input_format_defaults_for_omitted_fields)
|
||||
{
|
||||
auto column_defaults = columns_description->getDefaults();
|
||||
StoragePtr storage = context.getTable(ast_insert_query->database, ast_insert_query->table);
|
||||
auto column_defaults = storage->getColumns().getDefaults();
|
||||
if (!column_defaults.empty())
|
||||
res_stream = std::make_shared<AddingDefaultsBlockInputStream>(res_stream, column_defaults, context);
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <cstddef>
|
||||
#include <Core/Types.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -10,21 +12,21 @@ class WriteBuffer;
|
||||
struct FormatSettings;
|
||||
class IColumn;
|
||||
|
||||
/** Further refinment of the properties of data type.
|
||||
*
|
||||
* Contains methods for serialization/deserialization.
|
||||
* Implementations of this interface represent a data type domain (example: IPv4)
|
||||
* which is a refinement of the exsitgin type with a name and specific text
|
||||
* representation.
|
||||
*
|
||||
* IDataTypeDomain is totally immutable object. You can always share them.
|
||||
/** Allow to customize an existing data type and set a different name and/or text serialization/deserialization methods.
|
||||
* See use in IPv4 and IPv6 data types, and also in SimpleAggregateFunction.
|
||||
*/
|
||||
class IDataTypeDomain
|
||||
class IDataTypeCustomName
|
||||
{
|
||||
public:
|
||||
virtual ~IDataTypeDomain() {}
|
||||
virtual ~IDataTypeCustomName() {}
|
||||
|
||||
virtual const char* getName() const = 0;
|
||||
virtual String getName() const = 0;
|
||||
};
|
||||
|
||||
class IDataTypeCustomTextSerialization
|
||||
{
|
||||
public:
|
||||
virtual ~IDataTypeCustomTextSerialization() {}
|
||||
|
||||
/** Text serialization for displaying on a terminal or saving into a text file, and the like.
|
||||
* Without escaping or quoting.
|
||||
@ -56,4 +58,31 @@ public:
|
||||
virtual void serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const = 0;
|
||||
};
|
||||
|
||||
using DataTypeCustomNamePtr = std::unique_ptr<const IDataTypeCustomName>;
|
||||
using DataTypeCustomTextSerializationPtr = std::unique_ptr<const IDataTypeCustomTextSerialization>;
|
||||
|
||||
/** Describe a data type customization
|
||||
*/
|
||||
struct DataTypeCustomDesc
|
||||
{
|
||||
DataTypeCustomNamePtr name;
|
||||
DataTypeCustomTextSerializationPtr text_serialization;
|
||||
|
||||
DataTypeCustomDesc(DataTypeCustomNamePtr name_, DataTypeCustomTextSerializationPtr text_serialization_)
|
||||
: name(std::move(name_)), text_serialization(std::move(text_serialization_)) {}
|
||||
};
|
||||
|
||||
using DataTypeCustomDescPtr = std::unique_ptr<DataTypeCustomDesc>;
|
||||
|
||||
/** A simple implementation of IDataTypeCustomName
|
||||
*/
|
||||
class DataTypeCustomFixedName : public IDataTypeCustomName
|
||||
{
|
||||
private:
|
||||
String name;
|
||||
public:
|
||||
DataTypeCustomFixedName(String name_) : name(name_) {}
|
||||
String getName() const override { return name; }
|
||||
};
|
||||
|
||||
} // namespace DB
|
@ -1,9 +1,9 @@
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/formatIPv6.h>
|
||||
#include <DataTypes/DataTypeDomainWithSimpleSerialization.h>
|
||||
#include <DataTypes/DataTypeCustomSimpleTextSerialization.h>
|
||||
#include <DataTypes/DataTypeFactory.h>
|
||||
#include <DataTypes/IDataTypeDomain.h>
|
||||
#include <DataTypes/DataTypeCustom.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Functions/FunctionsCoding.h>
|
||||
|
||||
@ -20,20 +20,15 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
class DataTypeDomainIPv4 : public DataTypeDomainWithSimpleSerialization
|
||||
class DataTypeCustomIPv4Serialization : public DataTypeCustomSimpleTextSerialization
|
||||
{
|
||||
public:
|
||||
const char * getName() const override
|
||||
{
|
||||
return "IPv4";
|
||||
}
|
||||
|
||||
void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override
|
||||
{
|
||||
const auto col = checkAndGetColumn<ColumnUInt32>(&column);
|
||||
if (!col)
|
||||
{
|
||||
throw Exception(String(getName()) + " domain can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
throw Exception("IPv4 type can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
|
||||
char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'};
|
||||
@ -48,7 +43,7 @@ public:
|
||||
ColumnUInt32 * col = typeid_cast<ColumnUInt32 *>(&column);
|
||||
if (!col)
|
||||
{
|
||||
throw Exception(String(getName()) + " domain can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
throw Exception("IPv4 type can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
|
||||
char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'};
|
||||
@ -63,20 +58,16 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class DataTypeDomainIPv6 : public DataTypeDomainWithSimpleSerialization
|
||||
class DataTypeCustomIPv6Serialization : public DataTypeCustomSimpleTextSerialization
|
||||
{
|
||||
public:
|
||||
const char * getName() const override
|
||||
{
|
||||
return "IPv6";
|
||||
}
|
||||
|
||||
void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override
|
||||
{
|
||||
const auto col = checkAndGetColumn<ColumnFixedString>(&column);
|
||||
if (!col)
|
||||
{
|
||||
throw Exception(String(getName()) + " domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
throw Exception("IPv6 type domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
|
||||
char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'};
|
||||
@ -91,7 +82,7 @@ public:
|
||||
ColumnFixedString * col = typeid_cast<ColumnFixedString *>(&column);
|
||||
if (!col)
|
||||
{
|
||||
throw Exception(String(getName()) + " domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
throw Exception("IPv6 type domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
|
||||
char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'};
|
||||
@ -100,7 +91,7 @@ public:
|
||||
std::string ipv6_value(IPV6_BINARY_LENGTH, '\0');
|
||||
if (!parseIPv6(buffer, reinterpret_cast<unsigned char *>(ipv6_value.data())))
|
||||
{
|
||||
throw Exception(String("Invalid ") + getName() + " value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING);
|
||||
throw Exception("Invalid IPv6 value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING);
|
||||
}
|
||||
|
||||
col->insertString(ipv6_value);
|
||||
@ -111,8 +102,17 @@ public:
|
||||
|
||||
void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory)
|
||||
{
|
||||
factory.registerDataTypeDomain("UInt32", std::make_unique<DataTypeDomainIPv4>());
|
||||
factory.registerDataTypeDomain("FixedString(16)", std::make_unique<DataTypeDomainIPv6>());
|
||||
factory.registerSimpleDataTypeCustom("IPv4", []
|
||||
{
|
||||
return std::make_pair(DataTypeFactory::instance().get("UInt32"),
|
||||
std::make_unique<DataTypeCustomDesc>(std::make_unique<DataTypeCustomFixedName>("IPv4"), std::make_unique<DataTypeCustomIPv4Serialization>()));
|
||||
});
|
||||
|
||||
factory.registerSimpleDataTypeCustom("IPv6", []
|
||||
{
|
||||
return std::make_pair(DataTypeFactory::instance().get("FixedString(16)"),
|
||||
std::make_unique<DataTypeCustomDesc>(std::make_unique<DataTypeCustomFixedName>("IPv6"), std::make_unique<DataTypeCustomIPv6Serialization>()));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace DB
|
137
dbms/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp
Normal file
137
dbms/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp
Normal file
@ -0,0 +1,137 @@
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
#include <Columns/ColumnAggregateFunction.h>
|
||||
|
||||
#include <DataTypes/DataTypeCustomSimpleAggregateFunction.h>
|
||||
#include <DataTypes/DataTypeLowCardinality.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeFactory.h>
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <Parsers/ASTFunction.h>
|
||||
#include <Parsers/ASTLiteral.h>
|
||||
#include <Parsers/ASTIdentifier.h>
|
||||
|
||||
#include <boost/algorithm/string/join.hpp>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int SYNTAX_ERROR;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
static const std::vector<String> supported_functions{"any", "anyLast", "min", "max", "sum"};
|
||||
|
||||
|
||||
String DataTypeCustomSimpleAggregateFunction::getName() const
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "SimpleAggregateFunction(" << function->getName();
|
||||
|
||||
if (!parameters.empty())
|
||||
{
|
||||
stream << "(";
|
||||
for (size_t i = 0; i < parameters.size(); ++i)
|
||||
{
|
||||
if (i)
|
||||
stream << ", ";
|
||||
stream << applyVisitor(DB::FieldVisitorToString(), parameters[i]);
|
||||
}
|
||||
stream << ")";
|
||||
}
|
||||
|
||||
for (const auto & argument_type : argument_types)
|
||||
stream << ", " << argument_type->getName();
|
||||
|
||||
stream << ")";
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
|
||||
static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const ASTPtr & arguments)
|
||||
{
|
||||
String function_name;
|
||||
AggregateFunctionPtr function;
|
||||
DataTypes argument_types;
|
||||
Array params_row;
|
||||
|
||||
if (!arguments || arguments->children.empty())
|
||||
throw Exception("Data type SimpleAggregateFunction requires parameters: "
|
||||
"name of aggregate function and list of data types for arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
if (const ASTFunction * parametric = arguments->children[0]->as<ASTFunction>())
|
||||
{
|
||||
if (parametric->parameters)
|
||||
throw Exception("Unexpected level of parameters to aggregate function", ErrorCodes::SYNTAX_ERROR);
|
||||
function_name = parametric->name;
|
||||
|
||||
const ASTs & parameters = parametric->arguments->as<ASTExpressionList &>().children;
|
||||
params_row.resize(parameters.size());
|
||||
|
||||
for (size_t i = 0; i < parameters.size(); ++i)
|
||||
{
|
||||
const ASTLiteral * lit = parameters[i]->as<ASTLiteral>();
|
||||
if (!lit)
|
||||
throw Exception("Parameters to aggregate functions must be literals",
|
||||
ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS);
|
||||
|
||||
params_row[i] = lit->value;
|
||||
}
|
||||
}
|
||||
else if (auto opt_name = getIdentifierName(arguments->children[0]))
|
||||
{
|
||||
function_name = *opt_name;
|
||||
}
|
||||
else if (arguments->children[0]->as<ASTLiteral>())
|
||||
{
|
||||
throw Exception("Aggregate function name for data type SimpleAggregateFunction must be passed as identifier (without quotes) or function",
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
else
|
||||
throw Exception("Unexpected AST element passed as aggregate function name for data type SimpleAggregateFunction. Must be identifier or function.",
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
for (size_t i = 1; i < arguments->children.size(); ++i)
|
||||
argument_types.push_back(DataTypeFactory::instance().get(arguments->children[i]));
|
||||
|
||||
if (function_name.empty())
|
||||
throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row);
|
||||
|
||||
// check function
|
||||
if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions))
|
||||
{
|
||||
throw Exception("Unsupported aggregate function " + function->getName() + ", supported functions are " + boost::algorithm::join(supported_functions, ","),
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName());
|
||||
|
||||
if (!function->getReturnType()->equals(*removeLowCardinality(storage_type)))
|
||||
{
|
||||
throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getReturnType()->getName() + " and column storage type " + storage_type->getName(),
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
DataTypeCustomNamePtr custom_name = std::make_unique<DataTypeCustomSimpleAggregateFunction>(function, argument_types, params_row);
|
||||
|
||||
return std::make_pair(storage_type, std::make_unique<DataTypeCustomDesc>(std::move(custom_name), nullptr));
|
||||
}
|
||||
|
||||
void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory)
|
||||
{
|
||||
factory.registerDataTypeCustom("SimpleAggregateFunction", create);
|
||||
}
|
||||
|
||||
}
|
42
dbms/src/DataTypes/DataTypeCustomSimpleAggregateFunction.h
Normal file
42
dbms/src/DataTypes/DataTypeCustomSimpleAggregateFunction.h
Normal file
@ -0,0 +1,42 @@
|
||||
#pragma once
|
||||
|
||||
#include <DataTypes/DataTypeCustom.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/** The type SimpleAggregateFunction(fct, type) is meant to be used in an AggregatingMergeTree. It behaves like a standard
|
||||
* data type but when rows are merged, an aggregation function is applied.
|
||||
*
|
||||
* The aggregation function is limited to simple functions whose merge state is the final result:
|
||||
* any, anyLast, min, max, sum
|
||||
*
|
||||
* Examples:
|
||||
*
|
||||
* SimpleAggregateFunction(sum, Nullable(Float64))
|
||||
* SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String)))
|
||||
* SimpleAggregateFunction(anyLast, IPv4)
|
||||
*
|
||||
* Technically, a standard IDataType is instanciated and customized with IDataTypeCustomName and DataTypeCustomDesc.
|
||||
*/
|
||||
|
||||
class DataTypeCustomSimpleAggregateFunction : public IDataTypeCustomName
|
||||
{
|
||||
private:
|
||||
const AggregateFunctionPtr function;
|
||||
const DataTypes argument_types;
|
||||
const Array parameters;
|
||||
|
||||
public:
|
||||
DataTypeCustomSimpleAggregateFunction(const AggregateFunctionPtr & function_, const DataTypes & argument_types_, const Array & parameters_)
|
||||
: function(function_), argument_types(argument_types_), parameters(parameters_) {}
|
||||
|
||||
const AggregateFunctionPtr getFunction() const { return function; }
|
||||
String getName() const override;
|
||||
};
|
||||
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
#include <DataTypes/DataTypeDomainWithSimpleSerialization.h>
|
||||
#include <DataTypes/DataTypeCustomSimpleTextSerialization.h>
|
||||
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
@ -9,7 +9,7 @@ namespace
|
||||
{
|
||||
using namespace DB;
|
||||
|
||||
static String serializeToString(const DataTypeDomainWithSimpleSerialization & domain, const IColumn & column, size_t row_num, const FormatSettings & settings)
|
||||
static String serializeToString(const DataTypeCustomSimpleTextSerialization & domain, const IColumn & column, size_t row_num, const FormatSettings & settings)
|
||||
{
|
||||
WriteBufferFromOwnString buffer;
|
||||
domain.serializeText(column, row_num, buffer, settings);
|
||||
@ -17,7 +17,7 @@ static String serializeToString(const DataTypeDomainWithSimpleSerialization & do
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
static void deserializeFromString(const DataTypeDomainWithSimpleSerialization & domain, IColumn & column, const String & s, const FormatSettings & settings)
|
||||
static void deserializeFromString(const DataTypeCustomSimpleTextSerialization & domain, IColumn & column, const String & s, const FormatSettings & settings)
|
||||
{
|
||||
ReadBufferFromString istr(s);
|
||||
domain.deserializeText(column, istr, settings);
|
||||
@ -28,59 +28,59 @@ static void deserializeFromString(const DataTypeDomainWithSimpleSerialization &
|
||||
namespace DB
|
||||
{
|
||||
|
||||
DataTypeDomainWithSimpleSerialization::~DataTypeDomainWithSimpleSerialization()
|
||||
DataTypeCustomSimpleTextSerialization::~DataTypeCustomSimpleTextSerialization()
|
||||
{
|
||||
}
|
||||
|
||||
void DataTypeDomainWithSimpleSerialization::serializeTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
void DataTypeCustomSimpleTextSerialization::serializeTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
writeEscapedString(serializeToString(*this, column, row_num, settings), ostr);
|
||||
}
|
||||
|
||||
void DataTypeDomainWithSimpleSerialization::deserializeTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
void DataTypeCustomSimpleTextSerialization::deserializeTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
{
|
||||
String str;
|
||||
readEscapedString(str, istr);
|
||||
deserializeFromString(*this, column, str, settings);
|
||||
}
|
||||
|
||||
void DataTypeDomainWithSimpleSerialization::serializeTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
void DataTypeCustomSimpleTextSerialization::serializeTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
writeQuotedString(serializeToString(*this, column, row_num, settings), ostr);
|
||||
}
|
||||
|
||||
void DataTypeDomainWithSimpleSerialization::deserializeTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
void DataTypeCustomSimpleTextSerialization::deserializeTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
{
|
||||
String str;
|
||||
readQuotedString(str, istr);
|
||||
deserializeFromString(*this, column, str, settings);
|
||||
}
|
||||
|
||||
void DataTypeDomainWithSimpleSerialization::serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
void DataTypeCustomSimpleTextSerialization::serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
writeCSVString(serializeToString(*this, column, row_num, settings), ostr);
|
||||
}
|
||||
|
||||
void DataTypeDomainWithSimpleSerialization::deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
void DataTypeCustomSimpleTextSerialization::deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
{
|
||||
String str;
|
||||
readCSVString(str, istr, settings.csv);
|
||||
deserializeFromString(*this, column, str, settings);
|
||||
}
|
||||
|
||||
void DataTypeDomainWithSimpleSerialization::serializeTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
void DataTypeCustomSimpleTextSerialization::serializeTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
writeJSONString(serializeToString(*this, column, row_num, settings), ostr, settings);
|
||||
}
|
||||
|
||||
void DataTypeDomainWithSimpleSerialization::deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
void DataTypeCustomSimpleTextSerialization::deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
{
|
||||
String str;
|
||||
readJSONString(str, istr);
|
||||
deserializeFromString(*this, column, str, settings);
|
||||
}
|
||||
|
||||
void DataTypeDomainWithSimpleSerialization::serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
void DataTypeCustomSimpleTextSerialization::serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
writeXMLString(serializeToString(*this, column, row_num, settings), ostr);
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <DataTypes/IDataTypeDomain.h>
|
||||
#include <DataTypes/DataTypeCustom.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -10,12 +10,12 @@ class WriteBuffer;
|
||||
struct FormatSettings;
|
||||
class IColumn;
|
||||
|
||||
/** Simple DataTypeDomain that uses serializeText/deserializeText
|
||||
/** Simple IDataTypeCustomTextSerialization that uses serializeText/deserializeText
|
||||
* for all serialization and deserialization. */
|
||||
class DataTypeDomainWithSimpleSerialization : public IDataTypeDomain
|
||||
class DataTypeCustomSimpleTextSerialization : public IDataTypeCustomTextSerialization
|
||||
{
|
||||
public:
|
||||
virtual ~DataTypeDomainWithSimpleSerialization() override;
|
||||
virtual ~DataTypeCustomSimpleTextSerialization() override;
|
||||
|
||||
// Methods that subclasses must override in order to get full serialization/deserialization support.
|
||||
virtual void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override = 0;
|
@ -1,5 +1,5 @@
|
||||
#include <DataTypes/DataTypeFactory.h>
|
||||
#include <DataTypes/IDataTypeDomain.h>
|
||||
#include <DataTypes/DataTypeCustom.h>
|
||||
#include <Parsers/parseQuery.h>
|
||||
#include <Parsers/ParserCreateQuery.h>
|
||||
#include <Parsers/ASTFunction.h>
|
||||
@ -115,19 +115,23 @@ void DataTypeFactory::registerSimpleDataType(const String & name, SimpleCreator
|
||||
}, case_sensitiveness);
|
||||
}
|
||||
|
||||
void DataTypeFactory::registerDataTypeDomain(const String & type_name, DataTypeDomainPtr domain, CaseSensitiveness case_sensitiveness)
|
||||
void DataTypeFactory::registerDataTypeCustom(const String & family_name, CreatorWithCustom creator, CaseSensitiveness case_sensitiveness)
|
||||
{
|
||||
all_domains.reserve(all_domains.size() + 1);
|
||||
|
||||
auto data_type = get(type_name);
|
||||
setDataTypeDomain(*data_type, *domain);
|
||||
|
||||
registerDataType(domain->getName(), [data_type](const ASTPtr & /*ast*/)
|
||||
registerDataType(family_name, [creator](const ASTPtr & ast)
|
||||
{
|
||||
return data_type;
|
||||
}, case_sensitiveness);
|
||||
auto res = creator(ast);
|
||||
res.first->setCustomization(std::move(res.second));
|
||||
|
||||
all_domains.emplace_back(std::move(domain));
|
||||
return res.first;
|
||||
}, case_sensitiveness);
|
||||
}
|
||||
|
||||
void DataTypeFactory::registerSimpleDataTypeCustom(const String &name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness)
|
||||
{
|
||||
registerDataTypeCustom(name, [creator](const ASTPtr & /*ast*/)
|
||||
{
|
||||
return creator();
|
||||
}, case_sensitiveness);
|
||||
}
|
||||
|
||||
const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String & family_name) const
|
||||
@ -153,11 +157,6 @@ const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String
|
||||
throw Exception("Unknown data type family: " + family_name, ErrorCodes::UNKNOWN_TYPE);
|
||||
}
|
||||
|
||||
void DataTypeFactory::setDataTypeDomain(const IDataType & data_type, const IDataTypeDomain & domain)
|
||||
{
|
||||
data_type.setDomain(&domain);
|
||||
}
|
||||
|
||||
void registerDataTypeNumbers(DataTypeFactory & factory);
|
||||
void registerDataTypeDecimal(DataTypeFactory & factory);
|
||||
void registerDataTypeDate(DataTypeFactory & factory);
|
||||
@ -175,6 +174,7 @@ void registerDataTypeNested(DataTypeFactory & factory);
|
||||
void registerDataTypeInterval(DataTypeFactory & factory);
|
||||
void registerDataTypeLowCardinality(DataTypeFactory & factory);
|
||||
void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory);
|
||||
void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory);
|
||||
|
||||
|
||||
DataTypeFactory::DataTypeFactory()
|
||||
@ -196,6 +196,7 @@ DataTypeFactory::DataTypeFactory()
|
||||
registerDataTypeInterval(*this);
|
||||
registerDataTypeLowCardinality(*this);
|
||||
registerDataTypeDomainIPv4AndIPv6(*this);
|
||||
registerDataTypeDomainSimpleAggregateFunction(*this);
|
||||
}
|
||||
|
||||
DataTypeFactory::~DataTypeFactory()
|
||||
|
@ -17,9 +17,6 @@ namespace DB
|
||||
class IDataType;
|
||||
using DataTypePtr = std::shared_ptr<const IDataType>;
|
||||
|
||||
class IDataTypeDomain;
|
||||
using DataTypeDomainPtr = std::unique_ptr<const IDataTypeDomain>;
|
||||
|
||||
|
||||
/** Creates a data type by name of data type family and parameters.
|
||||
*/
|
||||
@ -28,6 +25,8 @@ class DataTypeFactory final : public ext::singleton<DataTypeFactory>, public IFa
|
||||
private:
|
||||
using SimpleCreator = std::function<DataTypePtr()>;
|
||||
using DataTypesDictionary = std::unordered_map<String, Creator>;
|
||||
using CreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>(const ASTPtr & parameters)>;
|
||||
using SimpleCreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>()>;
|
||||
|
||||
public:
|
||||
DataTypePtr get(const String & full_name) const;
|
||||
@ -40,11 +39,13 @@ public:
|
||||
/// Register a simple data type, that have no parameters.
|
||||
void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
|
||||
|
||||
// Register a domain - a refinement of existing type.
|
||||
void registerDataTypeDomain(const String & type_name, DataTypeDomainPtr domain, CaseSensitiveness case_sensitiveness = CaseSensitive);
|
||||
/// Register a customized type family
|
||||
void registerDataTypeCustom(const String & family_name, CreatorWithCustom creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
|
||||
|
||||
/// Register a simple customized data type
|
||||
void registerSimpleDataTypeCustom(const String & name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
|
||||
|
||||
private:
|
||||
static void setDataTypeDomain(const IDataType & data_type, const IDataTypeDomain & domain);
|
||||
const Creator& findCreatorByName(const String & family_name) const;
|
||||
|
||||
private:
|
||||
@ -53,9 +54,6 @@ private:
|
||||
/// Case insensitive data types will be additionally added here with lowercased name.
|
||||
DataTypesDictionary case_insensitive_data_types;
|
||||
|
||||
// All domains are owned by factory and shared amongst DataType instances.
|
||||
std::vector<DataTypeDomainPtr> all_domains;
|
||||
|
||||
DataTypeFactory();
|
||||
~DataTypeFactory() override;
|
||||
|
||||
|
@ -6,7 +6,7 @@
|
||||
#include <Formats/ProtobufWriter.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/readFloatText.h>
|
||||
#include <IO/readDecimalText.h>
|
||||
#include <Parsers/IAST.h>
|
||||
#include <Parsers/ASTLiteral.h>
|
||||
#include <Interpreters/Context.h>
|
||||
@ -52,10 +52,22 @@ void DataTypeDecimal<T>::serializeText(const IColumn & column, size_t row_num, W
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DataTypeDecimal<T>::readText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale)
|
||||
bool DataTypeDecimal<T>::tryReadText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale)
|
||||
{
|
||||
UInt32 unread_scale = scale;
|
||||
readDecimalText(istr, x, precision, unread_scale);
|
||||
bool done = tryReadDecimalText(istr, x, precision, unread_scale);
|
||||
x *= getScaleMultiplier(unread_scale);
|
||||
return done;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DataTypeDecimal<T>::readText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale, bool csv)
|
||||
{
|
||||
UInt32 unread_scale = scale;
|
||||
if (csv)
|
||||
readCSVDecimalText(istr, x, precision, unread_scale);
|
||||
else
|
||||
readDecimalText(istr, x, precision, unread_scale);
|
||||
x *= getScaleMultiplier(unread_scale);
|
||||
}
|
||||
|
||||
@ -67,6 +79,13 @@ void DataTypeDecimal<T>::deserializeText(IColumn & column, ReadBuffer & istr, co
|
||||
static_cast<ColumnType &>(column).getData().push_back(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DataTypeDecimal<T>::deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings &) const
|
||||
{
|
||||
T x;
|
||||
readText(x, istr, true);
|
||||
static_cast<ColumnType &>(column).getData().push_back(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T DataTypeDecimal<T>::parseFromString(const String & str) const
|
||||
|
@ -1,4 +1,6 @@
|
||||
#pragma once
|
||||
#include <cmath>
|
||||
|
||||
#include <common/arithmeticOverflow.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Columns/ColumnDecimal.h>
|
||||
@ -91,6 +93,7 @@ public:
|
||||
|
||||
void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override;
|
||||
void deserializeText(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override;
|
||||
void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override;
|
||||
|
||||
void serializeBinary(const Field & field, WriteBuffer & ostr) const override;
|
||||
void serializeBinary(const IColumn & column, size_t row_num, WriteBuffer & ostr) const override;
|
||||
@ -175,8 +178,9 @@ public:
|
||||
|
||||
T parseFromString(const String & str) const;
|
||||
|
||||
void readText(T & x, ReadBuffer & istr) const { readText(x, istr, precision, scale); }
|
||||
static void readText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale);
|
||||
void readText(T & x, ReadBuffer & istr, bool csv = false) const { readText(x, istr, precision, scale, csv); }
|
||||
static void readText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale, bool csv = false);
|
||||
static bool tryReadText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale);
|
||||
static T getScaleMultiplier(UInt32 scale);
|
||||
|
||||
private:
|
||||
@ -318,7 +322,11 @@ convertToDecimal(const typename FromDataType::FieldType & value, UInt32 scale)
|
||||
using FromFieldType = typename FromDataType::FieldType;
|
||||
|
||||
if constexpr (std::is_floating_point_v<FromFieldType>)
|
||||
{
|
||||
if (!std::isfinite(value))
|
||||
throw Exception("Decimal convert overflow. Cannot convert infinity or NaN to decimal", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
return value * ToDataType::getScaleMultiplier(scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr (std::is_same_v<FromFieldType, UInt64>)
|
||||
|
@ -9,7 +9,7 @@
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <DataTypes/IDataType.h>
|
||||
#include <DataTypes/IDataTypeDomain.h>
|
||||
#include <DataTypes/DataTypeCustom.h>
|
||||
#include <DataTypes/NestedUtils.h>
|
||||
|
||||
|
||||
@ -23,8 +23,7 @@ namespace ErrorCodes
|
||||
extern const int DATA_TYPE_CANNOT_BE_PROMOTED;
|
||||
}
|
||||
|
||||
IDataType::IDataType()
|
||||
: domain(nullptr)
|
||||
IDataType::IDataType() : custom_name(nullptr), custom_text_serialization(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
@ -34,9 +33,9 @@ IDataType::~IDataType()
|
||||
|
||||
String IDataType::getName() const
|
||||
{
|
||||
if (domain)
|
||||
if (custom_name)
|
||||
{
|
||||
return domain->getName();
|
||||
return custom_name->getName();
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -142,9 +141,9 @@ void IDataType::insertDefaultInto(IColumn & column) const
|
||||
|
||||
void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (custom_text_serialization)
|
||||
{
|
||||
domain->serializeTextEscaped(column, row_num, ostr, settings);
|
||||
custom_text_serialization->serializeTextEscaped(column, row_num, ostr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -154,9 +153,9 @@ void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, W
|
||||
|
||||
void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (custom_text_serialization)
|
||||
{
|
||||
domain->deserializeTextEscaped(column, istr, settings);
|
||||
custom_text_serialization->deserializeTextEscaped(column, istr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -166,9 +165,9 @@ void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, co
|
||||
|
||||
void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (custom_text_serialization)
|
||||
{
|
||||
domain->serializeTextQuoted(column, row_num, ostr, settings);
|
||||
custom_text_serialization->serializeTextQuoted(column, row_num, ostr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -178,9 +177,9 @@ void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, Wr
|
||||
|
||||
void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (custom_text_serialization)
|
||||
{
|
||||
domain->deserializeTextQuoted(column, istr, settings);
|
||||
custom_text_serialization->deserializeTextQuoted(column, istr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -190,9 +189,9 @@ void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, con
|
||||
|
||||
void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (custom_text_serialization)
|
||||
{
|
||||
domain->serializeTextCSV(column, row_num, ostr, settings);
|
||||
custom_text_serialization->serializeTextCSV(column, row_num, ostr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -202,9 +201,9 @@ void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, Write
|
||||
|
||||
void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (custom_text_serialization)
|
||||
{
|
||||
domain->deserializeTextCSV(column, istr, settings);
|
||||
custom_text_serialization->deserializeTextCSV(column, istr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -214,9 +213,9 @@ void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const
|
||||
|
||||
void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (custom_text_serialization)
|
||||
{
|
||||
domain->serializeText(column, row_num, ostr, settings);
|
||||
custom_text_serialization->serializeText(column, row_num, ostr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -226,9 +225,9 @@ void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuf
|
||||
|
||||
void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (custom_text_serialization)
|
||||
{
|
||||
domain->serializeTextJSON(column, row_num, ostr, settings);
|
||||
custom_text_serialization->serializeTextJSON(column, row_num, ostr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -238,9 +237,9 @@ void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, Writ
|
||||
|
||||
void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (custom_text_serialization)
|
||||
{
|
||||
domain->deserializeTextJSON(column, istr, settings);
|
||||
custom_text_serialization->deserializeTextJSON(column, istr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -250,9 +249,9 @@ void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const
|
||||
|
||||
void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (custom_text_serialization)
|
||||
{
|
||||
domain->serializeTextXML(column, row_num, ostr, settings);
|
||||
custom_text_serialization->serializeTextXML(column, row_num, ostr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -260,13 +259,14 @@ void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, Write
|
||||
}
|
||||
}
|
||||
|
||||
void IDataType::setDomain(const IDataTypeDomain* const new_domain) const
|
||||
void IDataType::setCustomization(DataTypeCustomDescPtr custom_desc_) const
|
||||
{
|
||||
if (domain != nullptr)
|
||||
{
|
||||
throw Exception("Type " + getName() + " already has a domain.", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
domain = new_domain;
|
||||
/// replace only if not null
|
||||
if (custom_desc_->name)
|
||||
custom_name = std::move(custom_desc_->name);
|
||||
|
||||
if (custom_desc_->text_serialization)
|
||||
custom_text_serialization = std::move(custom_desc_->text_serialization);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <Common/COW.h>
|
||||
#include <boost/noncopyable.hpp>
|
||||
#include <Core/Field.h>
|
||||
#include <DataTypes/DataTypeCustom.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -12,7 +13,6 @@ namespace DB
|
||||
class ReadBuffer;
|
||||
class WriteBuffer;
|
||||
|
||||
class IDataTypeDomain;
|
||||
class IDataType;
|
||||
struct FormatSettings;
|
||||
|
||||
@ -459,18 +459,19 @@ public:
|
||||
|
||||
private:
|
||||
friend class DataTypeFactory;
|
||||
/** Sets domain on existing DataType, can be considered as second phase
|
||||
* of construction explicitly done by DataTypeFactory.
|
||||
* Will throw an exception if domain is already set.
|
||||
/** Customize this DataType
|
||||
*/
|
||||
void setDomain(const IDataTypeDomain* newDomain) const;
|
||||
void setCustomization(DataTypeCustomDescPtr custom_desc_) const;
|
||||
|
||||
private:
|
||||
/** This is mutable to allow setting domain on `const IDataType` post construction,
|
||||
* simplifying creation of domains for all types, without them even knowing
|
||||
* of domain existence.
|
||||
/** This is mutable to allow setting custom name and serialization on `const IDataType` post construction.
|
||||
*/
|
||||
mutable IDataTypeDomain const* domain;
|
||||
mutable DataTypeCustomNamePtr custom_name;
|
||||
mutable DataTypeCustomTextSerializationPtr custom_text_serialization;
|
||||
|
||||
public:
|
||||
const IDataTypeCustomName * getCustomName() const { return custom_name.get(); }
|
||||
const IDataTypeCustomTextSerialization * getCustomTextSerialization() const { return custom_text_serialization.get(); }
|
||||
};
|
||||
|
||||
|
||||
@ -573,6 +574,13 @@ inline bool isInteger(const T & data_type)
|
||||
return which.isInt() || which.isUInt();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool isFloat(const T & data_type)
|
||||
{
|
||||
WhichDataType which(data_type);
|
||||
return which.isFloat();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool isNumber(const T & data_type)
|
||||
{
|
||||
|
@ -64,23 +64,25 @@ void registerInputFormatRowBinary(FormatFactory & factory)
|
||||
const Block & sample,
|
||||
const Context &,
|
||||
UInt64 max_block_size,
|
||||
UInt64 rows_portion_size,
|
||||
const FormatSettings & settings)
|
||||
{
|
||||
return std::make_shared<BlockInputStreamFromRowInputStream>(
|
||||
std::make_shared<BinaryRowInputStream>(buf, sample, false, false),
|
||||
sample, max_block_size, settings);
|
||||
sample, max_block_size, rows_portion_size, settings);
|
||||
});
|
||||
|
||||
factory.registerInputFormat("RowBinaryWithNamesAndTypes", [](
|
||||
ReadBuffer & buf,
|
||||
const Block & sample,
|
||||
const Context &,
|
||||
size_t max_block_size,
|
||||
UInt64 max_block_size,
|
||||
UInt64 rows_portion_size,
|
||||
const FormatSettings & settings)
|
||||
{
|
||||
return std::make_shared<BlockInputStreamFromRowInputStream>(
|
||||
std::make_shared<BinaryRowInputStream>(buf, sample, true, true),
|
||||
sample, max_block_size, settings);
|
||||
sample, max_block_size, rows_portion_size, settings);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -27,8 +27,9 @@ BlockInputStreamFromRowInputStream::BlockInputStreamFromRowInputStream(
|
||||
const RowInputStreamPtr & row_input_,
|
||||
const Block & sample_,
|
||||
UInt64 max_block_size_,
|
||||
UInt64 rows_portion_size_,
|
||||
const FormatSettings & settings)
|
||||
: row_input(row_input_), sample(sample_), max_block_size(max_block_size_),
|
||||
: row_input(row_input_), sample(sample_), max_block_size(max_block_size_), rows_portion_size(rows_portion_size_),
|
||||
allow_errors_num(settings.input_allow_errors_num), allow_errors_ratio(settings.input_allow_errors_ratio)
|
||||
{
|
||||
}
|
||||
@ -57,8 +58,15 @@ Block BlockInputStreamFromRowInputStream::readImpl()
|
||||
|
||||
try
|
||||
{
|
||||
for (size_t rows = 0; rows < max_block_size; ++rows)
|
||||
for (size_t rows = 0, batch = 0; rows < max_block_size; ++rows, ++batch)
|
||||
{
|
||||
if (rows_portion_size && batch == rows_portion_size)
|
||||
{
|
||||
batch = 0;
|
||||
if (!checkTimeLimit() || isCancelled())
|
||||
break;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
++total_rows;
|
||||
|
@ -17,11 +17,13 @@ namespace DB
|
||||
class BlockInputStreamFromRowInputStream : public IBlockInputStream
|
||||
{
|
||||
public:
|
||||
/** sample_ - block with zero rows, that structure describes how to interpret values */
|
||||
/// |sample| is a block with zero rows, that structure describes how to interpret values
|
||||
/// |rows_portion_size| is a number of rows to read before break and check limits
|
||||
BlockInputStreamFromRowInputStream(
|
||||
const RowInputStreamPtr & row_input_,
|
||||
const Block & sample_,
|
||||
UInt64 max_block_size_,
|
||||
UInt64 rows_portion_size_,
|
||||
const FormatSettings & settings);
|
||||
|
||||
void readPrefix() override { row_input->readPrefix(); }
|
||||
@ -42,6 +44,7 @@ private:
|
||||
RowInputStreamPtr row_input;
|
||||
Block sample;
|
||||
UInt64 max_block_size;
|
||||
UInt64 rows_portion_size;
|
||||
BlockMissingValues block_missing_values;
|
||||
|
||||
UInt64 allow_errors_num;
|
||||
@ -50,5 +53,4 @@ private:
|
||||
size_t total_rows = 0;
|
||||
size_t num_errors = 0;
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -478,11 +478,12 @@ void registerInputFormatCSV(FormatFactory & factory)
|
||||
const Block & sample,
|
||||
const Context &,
|
||||
UInt64 max_block_size,
|
||||
UInt64 rows_portion_size,
|
||||
const FormatSettings & settings)
|
||||
{
|
||||
return std::make_shared<BlockInputStreamFromRowInputStream>(
|
||||
std::make_shared<CSVRowInputStream>(buf, sample, with_names, settings),
|
||||
sample, max_block_size, settings);
|
||||
sample, max_block_size, rows_portion_size, settings);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -302,12 +302,18 @@ void registerInputFormatCapnProto(FormatFactory & factory)
|
||||
{
|
||||
factory.registerInputFormat(
|
||||
"CapnProto",
|
||||
[](ReadBuffer & buf, const Block & sample, const Context & context, UInt64 max_block_size, const FormatSettings & settings)
|
||||
[](ReadBuffer & buf,
|
||||
const Block & sample,
|
||||
const Context & context,
|
||||
UInt64 max_block_size,
|
||||
UInt64 rows_portion_size,
|
||||
const FormatSettings & settings)
|
||||
{
|
||||
return std::make_shared<BlockInputStreamFromRowInputStream>(
|
||||
std::make_shared<CapnProtoRowInputStream>(buf, sample, FormatSchemaInfo(context, "CapnProto")),
|
||||
sample,
|
||||
max_block_size,
|
||||
rows_portion_size,
|
||||
settings);
|
||||
});
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ const FormatFactory::Creators & FormatFactory::getCreators(const String & name)
|
||||
}
|
||||
|
||||
|
||||
BlockInputStreamPtr FormatFactory::getInput(const String & name, ReadBuffer & buf, const Block & sample, const Context & context, UInt64 max_block_size) const
|
||||
BlockInputStreamPtr FormatFactory::getInput(const String & name, ReadBuffer & buf, const Block & sample, const Context & context, UInt64 max_block_size, UInt64 rows_portion_size) const
|
||||
{
|
||||
const auto & input_getter = getCreators(name).first;
|
||||
if (!input_getter)
|
||||
@ -47,7 +47,7 @@ BlockInputStreamPtr FormatFactory::getInput(const String & name, ReadBuffer & bu
|
||||
format_settings.input_allow_errors_num = settings.input_format_allow_errors_num;
|
||||
format_settings.input_allow_errors_ratio = settings.input_format_allow_errors_ratio;
|
||||
|
||||
return input_getter(buf, sample, context, max_block_size, format_settings);
|
||||
return input_getter(buf, sample, context, max_block_size, rows_portion_size, format_settings);
|
||||
}
|
||||
|
||||
|
||||
|
@ -35,6 +35,7 @@ private:
|
||||
const Block & sample,
|
||||
const Context & context,
|
||||
UInt64 max_block_size,
|
||||
UInt64 rows_portion_size,
|
||||
const FormatSettings & settings)>;
|
||||
|
||||
using OutputCreator = std::function<BlockOutputStreamPtr(
|
||||
@ -49,7 +50,7 @@ private:
|
||||
|
||||
public:
|
||||
BlockInputStreamPtr getInput(const String & name, ReadBuffer & buf,
|
||||
const Block & sample, const Context & context, UInt64 max_block_size) const;
|
||||
const Block & sample, const Context & context, UInt64 max_block_size, UInt64 rows_portion_size = 0) const;
|
||||
|
||||
BlockOutputStreamPtr getOutput(const String & name, WriteBuffer & buf,
|
||||
const Block & sample, const Context & context) const;
|
||||
|
@ -259,11 +259,12 @@ void registerInputFormatJSONEachRow(FormatFactory & factory)
|
||||
const Block & sample,
|
||||
const Context &,
|
||||
UInt64 max_block_size,
|
||||
UInt64 rows_portion_size,
|
||||
const FormatSettings & settings)
|
||||
{
|
||||
return std::make_shared<BlockInputStreamFromRowInputStream>(
|
||||
std::make_shared<JSONEachRowRowInputStream>(buf, sample, settings),
|
||||
sample, max_block_size, settings);
|
||||
sample, max_block_size, rows_portion_size, settings);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,8 @@ void registerInputFormatNative(FormatFactory & factory)
|
||||
ReadBuffer & buf,
|
||||
const Block & sample,
|
||||
const Context &,
|
||||
size_t,
|
||||
UInt64 /* max_block_size */,
|
||||
UInt64 /* min_read_rows */,
|
||||
const FormatSettings &)
|
||||
{
|
||||
return std::make_shared<NativeBlockInputStream>(buf, sample, 0);
|
||||
|
@ -475,7 +475,8 @@ void registerInputFormatParquet(FormatFactory & factory)
|
||||
[](ReadBuffer & buf,
|
||||
const Block & sample,
|
||||
const Context & context,
|
||||
size_t /*max_block_size */,
|
||||
UInt64 /* max_block_size */,
|
||||
UInt64 /* rows_portion_size */,
|
||||
const FormatSettings & /* settings */) { return std::make_shared<ParquetBlockInputStream>(buf, sample, context); });
|
||||
}
|
||||
|
||||
|
@ -72,11 +72,12 @@ void registerInputFormatProtobuf(FormatFactory & factory)
|
||||
const Block & sample,
|
||||
const Context & context,
|
||||
UInt64 max_block_size,
|
||||
UInt64 rows_portion_size,
|
||||
const FormatSettings & settings)
|
||||
{
|
||||
return std::make_shared<BlockInputStreamFromRowInputStream>(
|
||||
std::make_shared<ProtobufRowInputStream>(buf, sample, FormatSchemaInfo(context, "Protobuf")),
|
||||
sample, max_block_size, settings);
|
||||
sample, max_block_size, rows_portion_size, settings);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -198,11 +198,12 @@ void registerInputFormatTSKV(FormatFactory & factory)
|
||||
const Block & sample,
|
||||
const Context &,
|
||||
UInt64 max_block_size,
|
||||
UInt64 rows_portion_size,
|
||||
const FormatSettings & settings)
|
||||
{
|
||||
return std::make_shared<BlockInputStreamFromRowInputStream>(
|
||||
std::make_shared<TSKVRowInputStream>(buf, sample, settings),
|
||||
sample, max_block_size, settings);
|
||||
sample, max_block_size, rows_portion_size, settings);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -456,11 +456,12 @@ void registerInputFormatTabSeparated(FormatFactory & factory)
|
||||
const Block & sample,
|
||||
const Context &,
|
||||
UInt64 max_block_size,
|
||||
UInt64 rows_portion_size,
|
||||
const FormatSettings & settings)
|
||||
{
|
||||
return std::make_shared<BlockInputStreamFromRowInputStream>(
|
||||
std::make_shared<TabSeparatedRowInputStream>(buf, sample, false, false, settings),
|
||||
sample, max_block_size, settings);
|
||||
sample, max_block_size, rows_portion_size, settings);
|
||||
});
|
||||
}
|
||||
|
||||
@ -471,11 +472,12 @@ void registerInputFormatTabSeparated(FormatFactory & factory)
|
||||
const Block & sample,
|
||||
const Context &,
|
||||
UInt64 max_block_size,
|
||||
UInt64 rows_portion_size,
|
||||
const FormatSettings & settings)
|
||||
{
|
||||
return std::make_shared<BlockInputStreamFromRowInputStream>(
|
||||
std::make_shared<TabSeparatedRowInputStream>(buf, sample, true, false, settings),
|
||||
sample, max_block_size, settings);
|
||||
sample, max_block_size, rows_portion_size, settings);
|
||||
});
|
||||
}
|
||||
|
||||
@ -486,11 +488,12 @@ void registerInputFormatTabSeparated(FormatFactory & factory)
|
||||
const Block & sample,
|
||||
const Context &,
|
||||
UInt64 max_block_size,
|
||||
UInt64 rows_portion_size,
|
||||
const FormatSettings & settings)
|
||||
{
|
||||
return std::make_shared<BlockInputStreamFromRowInputStream>(
|
||||
std::make_shared<TabSeparatedRowInputStream>(buf, sample, true, true, settings),
|
||||
sample, max_block_size, settings);
|
||||
sample, max_block_size, rows_portion_size, settings);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -155,11 +155,12 @@ void registerInputFormatValues(FormatFactory & factory)
|
||||
const Block & sample,
|
||||
const Context & context,
|
||||
UInt64 max_block_size,
|
||||
UInt64 rows_portion_size,
|
||||
const FormatSettings & settings)
|
||||
{
|
||||
return std::make_shared<BlockInputStreamFromRowInputStream>(
|
||||
std::make_shared<ValuesRowInputStream>(buf, sample, context, settings),
|
||||
sample, max_block_size, settings);
|
||||
sample, max_block_size, rows_portion_size, settings);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -45,7 +45,7 @@ try
|
||||
FormatSettings format_settings;
|
||||
|
||||
RowInputStreamPtr row_input = std::make_shared<TabSeparatedRowInputStream>(in_buf, sample, false, false, format_settings);
|
||||
BlockInputStreamFromRowInputStream block_input(row_input, sample, DEFAULT_INSERT_BLOCK_SIZE, format_settings);
|
||||
BlockInputStreamFromRowInputStream block_input(row_input, sample, DEFAULT_INSERT_BLOCK_SIZE, 0, format_settings);
|
||||
RowOutputStreamPtr row_output = std::make_shared<TabSeparatedRowOutputStream>(out_buf, sample, false, false, format_settings);
|
||||
BlockOutputStreamFromRowOutputStream block_output(row_output, sample);
|
||||
|
||||
|
@ -42,7 +42,7 @@ try
|
||||
RowInputStreamPtr row_input = std::make_shared<TabSeparatedRowInputStream>(in_buf, sample, false, false, format_settings);
|
||||
RowOutputStreamPtr row_output = std::make_shared<TabSeparatedRowOutputStream>(out_buf, sample, false, false, format_settings);
|
||||
|
||||
BlockInputStreamFromRowInputStream block_input(row_input, sample, DEFAULT_INSERT_BLOCK_SIZE, format_settings);
|
||||
BlockInputStreamFromRowInputStream block_input(row_input, sample, DEFAULT_INSERT_BLOCK_SIZE, 0, format_settings);
|
||||
BlockOutputStreamFromRowOutputStream block_output(row_output, sample);
|
||||
|
||||
copyData(block_input, block_output);
|
||||
|
@ -60,12 +60,16 @@ if(USE_BASE64)
|
||||
target_include_directories(clickhouse_functions SYSTEM PRIVATE ${BASE64_INCLUDE_DIR})
|
||||
endif()
|
||||
|
||||
if (USE_XXHASH)
|
||||
if(USE_XXHASH)
|
||||
target_link_libraries(clickhouse_functions PRIVATE ${XXHASH_LIBRARY})
|
||||
target_include_directories(clickhouse_functions SYSTEM PRIVATE ${XXHASH_INCLUDE_DIR})
|
||||
endif()
|
||||
|
||||
if (USE_HYPERSCAN)
|
||||
target_link_libraries (clickhouse_functions PRIVATE ${HYPERSCAN_LIBRARY})
|
||||
target_include_directories (clickhouse_functions SYSTEM PRIVATE ${HYPERSCAN_INCLUDE_DIR})
|
||||
endif ()
|
||||
if(USE_HYPERSCAN)
|
||||
target_link_libraries(clickhouse_functions PRIVATE ${HYPERSCAN_LIBRARY})
|
||||
target_include_directories(clickhouse_functions SYSTEM PRIVATE ${HYPERSCAN_INCLUDE_DIR})
|
||||
endif()
|
||||
|
||||
if(USE_SIMDJSON)
|
||||
target_link_libraries(clickhouse_functions PRIVATE ${SIMDJSON_LIBRARY})
|
||||
endif()
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnFixedString.h>
|
||||
@ -13,6 +14,7 @@ namespace DB
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_COLUMN;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
}
|
||||
|
||||
const ColumnConst * checkAndGetColumnConstStringOrFixedString(const IColumn * column)
|
||||
@ -99,4 +101,21 @@ Block createBlockWithNestedColumns(const Block & block, const ColumnNumbers & ar
|
||||
return createBlockWithNestedColumnsImpl(block, args_set);
|
||||
}
|
||||
|
||||
void validateArgumentType(const IFunction & func, const DataTypes & arguments,
|
||||
size_t argument_index, bool (* validator_func)(const IDataType &),
|
||||
const char * expected_type_description)
|
||||
{
|
||||
if (arguments.size() <= argument_index)
|
||||
throw Exception("Incorrect number of arguments of function " + func.getName(),
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
const auto & argument = arguments[argument_index];
|
||||
if (validator_func(*argument) == false)
|
||||
throw Exception("Illegal type " + argument->getName() +
|
||||
" of " + std::to_string(argument_index) +
|
||||
" argument of function " + func.getName() +
|
||||
" expected " + expected_type_description,
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -12,6 +12,8 @@
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class IFunction;
|
||||
|
||||
/// Methods, that helps dispatching over real column types.
|
||||
|
||||
template <typename Type>
|
||||
@ -64,7 +66,6 @@ bool checkColumnConst(const IColumn * column)
|
||||
return checkAndGetColumnConst<Type>(column);
|
||||
}
|
||||
|
||||
|
||||
/// Returns non-nullptr if column is ColumnConst with ColumnString or ColumnFixedString inside.
|
||||
const ColumnConst * checkAndGetColumnConstStringOrFixedString(const IColumn * column);
|
||||
|
||||
@ -94,4 +95,10 @@ Block createBlockWithNestedColumns(const Block & block, const ColumnNumbers & ar
|
||||
/// Similar function as above. Additionally transform the result type if needed.
|
||||
Block createBlockWithNestedColumns(const Block & block, const ColumnNumbers & args, size_t result);
|
||||
|
||||
/// Checks argument type at specified index with predicate.
|
||||
/// throws if there is no argument at specified index or if predicate returns false.
|
||||
void validateArgumentType(const IFunction & func, const DataTypes & arguments,
|
||||
size_t argument_index, bool (* validator_func)(const IDataType &),
|
||||
const char * expected_type_description);
|
||||
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ using StoragePtr = std::shared_ptr<IStorage>;
|
||||
class Join;
|
||||
using JoinPtr = std::shared_ptr<Join>;
|
||||
|
||||
class FunctionJoinGet final : public IFunction, public std::enable_shared_from_this<FunctionJoinGet>
|
||||
class FunctionJoinGet final : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "joinGet";
|
||||
|
@ -406,7 +406,12 @@ private:
|
||||
{
|
||||
const ColumnAggregateFunction * columns[2];
|
||||
for (size_t i = 0; i < 2; ++i)
|
||||
columns[i] = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(arguments[i]).column.get());
|
||||
{
|
||||
if (auto argument_column_const = typeid_cast<const ColumnConst *>(block.getByPosition(arguments[i]).column.get()))
|
||||
columns[i] = typeid_cast<const ColumnAggregateFunction *>(argument_column_const->getDataColumnPtr().get());
|
||||
else
|
||||
columns[i] = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(arguments[i]).column.get());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_rows_count; ++i)
|
||||
{
|
||||
@ -511,7 +516,12 @@ private:
|
||||
{
|
||||
const ColumnAggregateFunction * columns[2];
|
||||
for (size_t i = 0; i < 2; ++i)
|
||||
columns[i] = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(arguments[i]).column.get());
|
||||
{
|
||||
if (auto argument_column_const = typeid_cast<const ColumnConst *>(block.getByPosition(arguments[i]).column.get()))
|
||||
columns[i] = typeid_cast<const ColumnAggregateFunction *>(argument_column_const->getDataColumnPtr().get());
|
||||
else
|
||||
columns[i] = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(arguments[i]).column.get());
|
||||
}
|
||||
|
||||
auto col_to = ColumnAggregateFunction::create(columns[0]->getAggregateFunction());
|
||||
|
||||
|
@ -28,6 +28,8 @@ void registerFunctionsCoding(FunctionFactory & factory)
|
||||
factory.registerFunction<FunctionBitmaskToArray>();
|
||||
factory.registerFunction<FunctionToIPv4>();
|
||||
factory.registerFunction<FunctionToIPv6>();
|
||||
factory.registerFunction<FunctionIPv6CIDRToRange>();
|
||||
factory.registerFunction<FunctionIPv4CIDRToRange>();
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -12,11 +12,13 @@
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <DataTypes/DataTypeUUID.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnFixedString.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnConst.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
|
||||
@ -1450,4 +1452,208 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class FunctionIPv6CIDRToRange : public IFunction
|
||||
{
|
||||
private:
|
||||
/// TODO Inefficient.
|
||||
/// NOTE IPv6 is stored in memory in big endian format that makes some difficulties.
|
||||
static void applyCIDRMask(const UInt8 * __restrict src, UInt8 * __restrict dst_lower, UInt8 * __restrict dst_upper, UInt8 bits_to_keep)
|
||||
{
|
||||
UInt8 mask[16]{};
|
||||
|
||||
UInt8 bytes_to_keep = bits_to_keep / 8;
|
||||
UInt8 bits_to_keep_in_last_byte = bits_to_keep % 8;
|
||||
|
||||
for (size_t i = 0; i < bits_to_keep / 8; ++i)
|
||||
mask[i] = 0xFFU;
|
||||
|
||||
if (bits_to_keep_in_last_byte)
|
||||
mask[bytes_to_keep] = 0xFFU << (8 - bits_to_keep_in_last_byte);
|
||||
|
||||
for (size_t i = 0; i < 16; ++i)
|
||||
{
|
||||
dst_lower[i] = src[i] & mask[i];
|
||||
dst_upper[i] = dst_lower[i] | ~mask[i];
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr auto name = "IPv6CIDRToRange";
|
||||
static FunctionPtr create(const Context &) { return std::make_shared<FunctionIPv6CIDRToRange>(); }
|
||||
|
||||
String getName() const override { return name; }
|
||||
size_t getNumberOfArguments() const override { return 2; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
const auto first_argument = checkAndGetDataType<DataTypeFixedString>(arguments[0].get());
|
||||
if (!first_argument || first_argument->getN() != IPV6_BINARY_LENGTH)
|
||||
throw Exception("Illegal type " + arguments[0]->getName() +
|
||||
" of first argument of function " + getName() +
|
||||
", expected FixedString(" + toString(IPV6_BINARY_LENGTH) + ")",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
const DataTypePtr & second_argument = arguments[1];
|
||||
if (!isUInt8(second_argument))
|
||||
throw Exception{"Illegal type " + second_argument->getName()
|
||||
+ " of second argument of function " + getName()
|
||||
+ ", expected numeric type.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
DataTypePtr element = DataTypeFactory::instance().get("IPv6");
|
||||
return std::make_shared<DataTypeTuple>(DataTypes{element, element});
|
||||
}
|
||||
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
|
||||
{
|
||||
const auto & col_type_name_ip = block.getByPosition(arguments[0]);
|
||||
const ColumnPtr & column_ip = col_type_name_ip.column;
|
||||
|
||||
const auto col_ip_in = checkAndGetColumn<ColumnFixedString>(column_ip.get());
|
||||
|
||||
if (!col_ip_in)
|
||||
throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
|
||||
+ " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
if (col_ip_in->getN() != IPV6_BINARY_LENGTH)
|
||||
throw Exception("Illegal type " + col_type_name_ip.type->getName() +
|
||||
" of column " + col_ip_in->getName() +
|
||||
" argument of function " + getName() +
|
||||
", expected FixedString(" + toString(IPV6_BINARY_LENGTH) + ")",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
const auto & col_type_name_cidr = block.getByPosition(arguments[1]);
|
||||
const ColumnPtr & column_cidr = col_type_name_cidr.column;
|
||||
|
||||
const auto col_const_cidr_in = checkAndGetColumnConst<ColumnUInt8>(column_cidr.get());
|
||||
const auto col_cidr_in = checkAndGetColumn<ColumnUInt8>(column_cidr.get());
|
||||
|
||||
if (!col_const_cidr_in && !col_cidr_in)
|
||||
throw Exception("Illegal column " + block.getByPosition(arguments[1]).column->getName()
|
||||
+ " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
const auto & vec_in = col_ip_in->getChars();
|
||||
|
||||
auto col_res_lower_range = ColumnFixedString::create(IPV6_BINARY_LENGTH);
|
||||
auto col_res_upper_range = ColumnFixedString::create(IPV6_BINARY_LENGTH);
|
||||
|
||||
ColumnString::Chars & vec_res_lower_range = col_res_lower_range->getChars();
|
||||
vec_res_lower_range.resize(input_rows_count * IPV6_BINARY_LENGTH);
|
||||
|
||||
ColumnString::Chars & vec_res_upper_range = col_res_upper_range->getChars();
|
||||
vec_res_upper_range.resize(input_rows_count * IPV6_BINARY_LENGTH);
|
||||
|
||||
for (size_t offset = 0; offset < input_rows_count; ++offset)
|
||||
{
|
||||
const size_t offset_ipv6 = offset * IPV6_BINARY_LENGTH;
|
||||
UInt8 cidr = col_const_cidr_in
|
||||
? col_const_cidr_in->getValue<UInt8>()
|
||||
: col_cidr_in->getData()[offset];
|
||||
|
||||
applyCIDRMask(&vec_in[offset_ipv6], &vec_res_lower_range[offset_ipv6], &vec_res_upper_range[offset_ipv6], cidr);
|
||||
}
|
||||
|
||||
block.getByPosition(result).column = ColumnTuple::create(Columns{std::move(col_res_lower_range), std::move(col_res_upper_range)});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class FunctionIPv4CIDRToRange : public IFunction
|
||||
{
|
||||
private:
|
||||
static inline std::pair<UInt32, UInt32> applyCIDRMask(UInt32 src, UInt8 bits_to_keep)
|
||||
{
|
||||
if (bits_to_keep >= 8 * sizeof(UInt32))
|
||||
return { src, src };
|
||||
if (bits_to_keep == 0)
|
||||
return { UInt32(0), UInt32(-1) };
|
||||
|
||||
UInt32 mask = UInt32(-1) << (8 * sizeof(UInt32) - bits_to_keep);
|
||||
UInt32 lower = src & mask;
|
||||
UInt32 upper = lower | ~mask;
|
||||
|
||||
return { lower, upper };
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr auto name = "IPv4CIDRToRange";
|
||||
static FunctionPtr create(const Context &) { return std::make_shared<FunctionIPv4CIDRToRange>(); }
|
||||
|
||||
String getName() const override { return name; }
|
||||
size_t getNumberOfArguments() const override { return 2; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
if (!WhichDataType(arguments[0]).isUInt32())
|
||||
throw Exception("Illegal type " + arguments[0]->getName() +
|
||||
" of first argument of function " + getName() +
|
||||
", expected UInt32",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
|
||||
const DataTypePtr & second_argument = arguments[1];
|
||||
if (!isUInt8(second_argument))
|
||||
throw Exception{"Illegal type " + second_argument->getName()
|
||||
+ " of second argument of function " + getName()
|
||||
+ ", expected numeric type.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
DataTypePtr element = DataTypeFactory::instance().get("IPv4");
|
||||
return std::make_shared<DataTypeTuple>(DataTypes{element, element});
|
||||
}
|
||||
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
|
||||
{
|
||||
const auto & col_type_name_ip = block.getByPosition(arguments[0]);
|
||||
const ColumnPtr & column_ip = col_type_name_ip.column;
|
||||
|
||||
const auto col_ip_in = checkAndGetColumn<ColumnUInt32>(column_ip.get());
|
||||
if (!col_ip_in)
|
||||
throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
|
||||
+ " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
const auto & col_type_name_cidr = block.getByPosition(arguments[1]);
|
||||
const ColumnPtr & column_cidr = col_type_name_cidr.column;
|
||||
|
||||
const auto col_const_cidr_in = checkAndGetColumnConst<ColumnUInt8>(column_cidr.get());
|
||||
const auto col_cidr_in = checkAndGetColumn<ColumnUInt8>(column_cidr.get());
|
||||
|
||||
if (!col_const_cidr_in && !col_cidr_in)
|
||||
throw Exception("Illegal column " + block.getByPosition(arguments[1]).column->getName()
|
||||
+ " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
const auto & vec_in = col_ip_in->getData();
|
||||
|
||||
auto col_res_lower_range = ColumnUInt32::create();
|
||||
auto col_res_upper_range = ColumnUInt32::create();
|
||||
|
||||
auto & vec_res_lower_range = col_res_lower_range->getData();
|
||||
vec_res_lower_range.resize(input_rows_count);
|
||||
|
||||
auto & vec_res_upper_range = col_res_upper_range->getData();
|
||||
vec_res_upper_range.resize(input_rows_count);
|
||||
|
||||
for (size_t i = 0; i < input_rows_count; ++i)
|
||||
{
|
||||
UInt8 cidr = col_const_cidr_in
|
||||
? col_const_cidr_in->getValue<UInt8>()
|
||||
: col_cidr_in->getData()[i];
|
||||
|
||||
std::tie(vec_res_lower_range[i], vec_res_upper_range[i]) = applyCIDRMask(vec_in[i], cidr);
|
||||
}
|
||||
|
||||
block.getByPosition(result).column = ColumnTuple::create(Columns{std::move(col_res_lower_range), std::move(col_res_upper_range)});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
@ -66,6 +66,10 @@ void registerFunctionsConversion(FunctionFactory & factory)
|
||||
factory.registerFunction<FunctionToDateOrZero>();
|
||||
factory.registerFunction<FunctionToDateTimeOrZero>();
|
||||
|
||||
factory.registerFunction<FunctionToDecimal32OrZero>();
|
||||
factory.registerFunction<FunctionToDecimal64OrZero>();
|
||||
factory.registerFunction<FunctionToDecimal128OrZero>();
|
||||
|
||||
factory.registerFunction<FunctionToUInt8OrNull>();
|
||||
factory.registerFunction<FunctionToUInt16OrNull>();
|
||||
factory.registerFunction<FunctionToUInt32OrNull>();
|
||||
@ -79,6 +83,10 @@ void registerFunctionsConversion(FunctionFactory & factory)
|
||||
factory.registerFunction<FunctionToDateOrNull>();
|
||||
factory.registerFunction<FunctionToDateTimeOrNull>();
|
||||
|
||||
factory.registerFunction<FunctionToDecimal32OrNull>();
|
||||
factory.registerFunction<FunctionToDecimal64OrNull>();
|
||||
factory.registerFunction<FunctionToDecimal128OrNull>();
|
||||
|
||||
factory.registerFunction<FunctionParseDateTimeBestEffort>();
|
||||
factory.registerFunction<FunctionParseDateTimeBestEffortOrZero>();
|
||||
factory.registerFunction<FunctionParseDateTimeBestEffortOrNull>();
|
||||
|
@ -543,11 +543,7 @@ struct ConvertThroughParsing
|
||||
|
||||
ReadBufferFromMemory read_buffer(&(*chars)[current_offset], string_size);
|
||||
|
||||
if constexpr (IsDataTypeDecimal<ToDataType>)
|
||||
{
|
||||
ToDataType::readText(vec_to[i], read_buffer, ToDataType::maxPrecision(), vec_to.getScale());
|
||||
}
|
||||
else if constexpr (exception_mode == ConvertFromStringExceptionMode::Throw)
|
||||
if constexpr (exception_mode == ConvertFromStringExceptionMode::Throw)
|
||||
{
|
||||
if constexpr (parsing_mode == ConvertFromStringParsingMode::BestEffort)
|
||||
{
|
||||
@ -557,7 +553,10 @@ struct ConvertThroughParsing
|
||||
}
|
||||
else
|
||||
{
|
||||
parseImpl<ToDataType>(vec_to[i], read_buffer, local_time_zone);
|
||||
if constexpr (IsDataTypeDecimal<ToDataType>)
|
||||
ToDataType::readText(vec_to[i], read_buffer, ToDataType::maxPrecision(), vec_to.getScale());
|
||||
else
|
||||
parseImpl<ToDataType>(vec_to[i], read_buffer, local_time_zone);
|
||||
}
|
||||
|
||||
if (!isAllRead(read_buffer))
|
||||
@ -575,7 +574,12 @@ struct ConvertThroughParsing
|
||||
}
|
||||
else
|
||||
{
|
||||
parsed = tryParseImpl<ToDataType>(vec_to[i], read_buffer, local_time_zone) && isAllRead(read_buffer);
|
||||
if constexpr (IsDataTypeDecimal<ToDataType>)
|
||||
parsed = ToDataType::tryReadText(vec_to[i], read_buffer, ToDataType::maxPrecision(), vec_to.getScale());
|
||||
else
|
||||
parsed = tryParseImpl<ToDataType>(vec_to[i], read_buffer, local_time_zone);
|
||||
|
||||
parsed = parsed && isAllRead(read_buffer);
|
||||
}
|
||||
|
||||
if (!parsed)
|
||||
@ -959,27 +963,37 @@ public:
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
{
|
||||
if (arguments.size() != 1 && arguments.size() != 2)
|
||||
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
|
||||
+ toString(arguments.size()) + ", should be 1 or 2. Second argument (time zone) is optional only make sense for DateTime.",
|
||||
if ((arguments.size() != 1 && arguments.size() != 2) || (to_decimal && arguments.size() != 2))
|
||||
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size()) +
|
||||
", should be 1 or 2. Second argument only make sense for DateTime (time zone, optional) and Decimal (scale).",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
if (!isStringOrFixedString(arguments[0].type))
|
||||
throw Exception("Illegal type " + arguments[0].type->getName() + " of first argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
/// Optional second argument with time zone is supported.
|
||||
|
||||
if (arguments.size() == 2)
|
||||
{
|
||||
if constexpr (!std::is_same_v<ToDataType, DataTypeDateTime>)
|
||||
if constexpr (std::is_same_v<ToDataType, DataTypeDateTime>)
|
||||
{
|
||||
if (!isString(arguments[1].type))
|
||||
throw Exception("Illegal type " + arguments[1].type->getName() + " of 2nd argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
else if constexpr (to_decimal)
|
||||
{
|
||||
if (!isInteger(arguments[1].type))
|
||||
throw Exception("Illegal type " + arguments[1].type->getName() + " of 2nd argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
if (!arguments[1].column)
|
||||
throw Exception("Second argument for function " + getName() + " must be constant", ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
|
||||
+ toString(arguments.size()) + ", should be 1. Second argument makes sense only when converting to DateTime.",
|
||||
+ toString(arguments.size()) + ", should be 1. Second argument makes sense only for DateTime and Decimal.",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
if (!isString(arguments[1].type))
|
||||
throw Exception("Illegal type " + arguments[1].type->getName() + " of 2nd argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
}
|
||||
|
||||
DataTypePtr res;
|
||||
@ -987,7 +1001,19 @@ public:
|
||||
if constexpr (std::is_same_v<ToDataType, DataTypeDateTime>)
|
||||
res = std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 1, 0));
|
||||
else if constexpr (to_decimal)
|
||||
throw Exception(getName() + " is only implemented for types String and Decimal", ErrorCodes::NOT_IMPLEMENTED);
|
||||
{
|
||||
UInt64 scale = extractToDecimalScale(arguments[1]);
|
||||
|
||||
if constexpr (std::is_same_v<ToDataType, DataTypeDecimal<Decimal32>>)
|
||||
res = createDecimal(9, scale);
|
||||
else if constexpr (std::is_same_v<ToDataType, DataTypeDecimal<Decimal64>>)
|
||||
res = createDecimal(18, scale);
|
||||
else if constexpr (std::is_same_v<ToDataType, DataTypeDecimal<Decimal128>>)
|
||||
res = createDecimal(38, scale);
|
||||
|
||||
if (!res)
|
||||
throw Exception("Someting wrong with toDecimalNNOrZero() or toDecimalNNOrNull()", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
else
|
||||
res = std::make_shared<ToDataType>();
|
||||
|
||||
@ -1001,13 +1027,45 @@ public:
|
||||
{
|
||||
const IDataType * from_type = block.getByPosition(arguments[0]).type.get();
|
||||
|
||||
if (checkAndGetDataType<DataTypeString>(from_type))
|
||||
ConvertThroughParsing<DataTypeString, ToDataType, Name, exception_mode, parsing_mode>::execute(
|
||||
block, arguments, result, input_rows_count);
|
||||
else if (checkAndGetDataType<DataTypeFixedString>(from_type))
|
||||
ConvertThroughParsing<DataTypeFixedString, ToDataType, Name, exception_mode, parsing_mode>::execute(
|
||||
block, arguments, result, input_rows_count);
|
||||
bool ok = true;
|
||||
if constexpr (to_decimal)
|
||||
{
|
||||
if (arguments.size() != 2)
|
||||
throw Exception{"Function " + getName() + " expects 2 arguments for Decimal.", ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
|
||||
|
||||
UInt32 scale = extractToDecimalScale(block.getByPosition(arguments[1]));
|
||||
|
||||
if (checkAndGetDataType<DataTypeString>(from_type))
|
||||
{
|
||||
ConvertThroughParsing<DataTypeString, ToDataType, Name, exception_mode, parsing_mode>::execute(
|
||||
block, arguments, result, input_rows_count, scale);
|
||||
}
|
||||
else if (checkAndGetDataType<DataTypeFixedString>(from_type))
|
||||
{
|
||||
ConvertThroughParsing<DataTypeFixedString, ToDataType, Name, exception_mode, parsing_mode>::execute(
|
||||
block, arguments, result, input_rows_count, scale);
|
||||
}
|
||||
else
|
||||
ok = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (checkAndGetDataType<DataTypeString>(from_type))
|
||||
{
|
||||
ConvertThroughParsing<DataTypeString, ToDataType, Name, exception_mode, parsing_mode>::execute(
|
||||
block, arguments, result, input_rows_count);
|
||||
}
|
||||
else if (checkAndGetDataType<DataTypeFixedString>(from_type))
|
||||
{
|
||||
ConvertThroughParsing<DataTypeFixedString, ToDataType, Name, exception_mode, parsing_mode>::execute(
|
||||
block, arguments, result, input_rows_count);
|
||||
}
|
||||
else
|
||||
ok = false;
|
||||
|
||||
}
|
||||
|
||||
if (!ok)
|
||||
throw Exception("Illegal type " + block.getByPosition(arguments[0]).type->getName() + " of argument of function " + getName()
|
||||
+ ". Only String or FixedString argument is accepted for try-conversion function."
|
||||
+ " For other arguments, use function without 'orZero' or 'orNull'.",
|
||||
@ -1358,6 +1416,9 @@ struct NameToFloat32OrZero { static constexpr auto name = "toFloat32OrZero"; };
|
||||
struct NameToFloat64OrZero { static constexpr auto name = "toFloat64OrZero"; };
|
||||
struct NameToDateOrZero { static constexpr auto name = "toDateOrZero"; };
|
||||
struct NameToDateTimeOrZero { static constexpr auto name = "toDateTimeOrZero"; };
|
||||
struct NameToDecimal32OrZero { static constexpr auto name = "toDecimal32OrZero"; };
|
||||
struct NameToDecimal64OrZero { static constexpr auto name = "toDecimal64OrZero"; };
|
||||
struct NameToDecimal128OrZero { static constexpr auto name = "toDecimal128OrZero"; };
|
||||
|
||||
using FunctionToUInt8OrZero = FunctionConvertFromString<DataTypeUInt8, NameToUInt8OrZero, ConvertFromStringExceptionMode::Zero>;
|
||||
using FunctionToUInt16OrZero = FunctionConvertFromString<DataTypeUInt16, NameToUInt16OrZero, ConvertFromStringExceptionMode::Zero>;
|
||||
@ -1371,6 +1432,9 @@ using FunctionToFloat32OrZero = FunctionConvertFromString<DataTypeFloat32, NameT
|
||||
using FunctionToFloat64OrZero = FunctionConvertFromString<DataTypeFloat64, NameToFloat64OrZero, ConvertFromStringExceptionMode::Zero>;
|
||||
using FunctionToDateOrZero = FunctionConvertFromString<DataTypeDate, NameToDateOrZero, ConvertFromStringExceptionMode::Zero>;
|
||||
using FunctionToDateTimeOrZero = FunctionConvertFromString<DataTypeDateTime, NameToDateTimeOrZero, ConvertFromStringExceptionMode::Zero>;
|
||||
using FunctionToDecimal32OrZero = FunctionConvertFromString<DataTypeDecimal<Decimal32>, NameToDecimal32OrZero, ConvertFromStringExceptionMode::Zero>;
|
||||
using FunctionToDecimal64OrZero = FunctionConvertFromString<DataTypeDecimal<Decimal64>, NameToDecimal64OrZero, ConvertFromStringExceptionMode::Zero>;
|
||||
using FunctionToDecimal128OrZero = FunctionConvertFromString<DataTypeDecimal<Decimal128>, NameToDecimal128OrZero, ConvertFromStringExceptionMode::Zero>;
|
||||
|
||||
struct NameToUInt8OrNull { static constexpr auto name = "toUInt8OrNull"; };
|
||||
struct NameToUInt16OrNull { static constexpr auto name = "toUInt16OrNull"; };
|
||||
@ -1384,6 +1448,9 @@ struct NameToFloat32OrNull { static constexpr auto name = "toFloat32OrNull"; };
|
||||
struct NameToFloat64OrNull { static constexpr auto name = "toFloat64OrNull"; };
|
||||
struct NameToDateOrNull { static constexpr auto name = "toDateOrNull"; };
|
||||
struct NameToDateTimeOrNull { static constexpr auto name = "toDateTimeOrNull"; };
|
||||
struct NameToDecimal32OrNull { static constexpr auto name = "toDecimal32OrNull"; };
|
||||
struct NameToDecimal64OrNull { static constexpr auto name = "toDecimal64OrNull"; };
|
||||
struct NameToDecimal128OrNull { static constexpr auto name = "toDecimal128OrNull"; };
|
||||
|
||||
using FunctionToUInt8OrNull = FunctionConvertFromString<DataTypeUInt8, NameToUInt8OrNull, ConvertFromStringExceptionMode::Null>;
|
||||
using FunctionToUInt16OrNull = FunctionConvertFromString<DataTypeUInt16, NameToUInt16OrNull, ConvertFromStringExceptionMode::Null>;
|
||||
@ -1397,6 +1464,9 @@ using FunctionToFloat32OrNull = FunctionConvertFromString<DataTypeFloat32, NameT
|
||||
using FunctionToFloat64OrNull = FunctionConvertFromString<DataTypeFloat64, NameToFloat64OrNull, ConvertFromStringExceptionMode::Null>;
|
||||
using FunctionToDateOrNull = FunctionConvertFromString<DataTypeDate, NameToDateOrNull, ConvertFromStringExceptionMode::Null>;
|
||||
using FunctionToDateTimeOrNull = FunctionConvertFromString<DataTypeDateTime, NameToDateTimeOrNull, ConvertFromStringExceptionMode::Null>;
|
||||
using FunctionToDecimal32OrNull = FunctionConvertFromString<DataTypeDecimal<Decimal32>, NameToDecimal32OrNull, ConvertFromStringExceptionMode::Null>;
|
||||
using FunctionToDecimal64OrNull = FunctionConvertFromString<DataTypeDecimal<Decimal64>, NameToDecimal64OrNull, ConvertFromStringExceptionMode::Null>;
|
||||
using FunctionToDecimal128OrNull = FunctionConvertFromString<DataTypeDecimal<Decimal128>, NameToDecimal128OrNull, ConvertFromStringExceptionMode::Null>;
|
||||
|
||||
struct NameParseDateTimeBestEffort { static constexpr auto name = "parseDateTimeBestEffort"; };
|
||||
struct NameParseDateTimeBestEffortOrZero { static constexpr auto name = "parseDateTimeBestEffortOrZero"; };
|
||||
|
@ -1,19 +1,24 @@
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/FunctionsGeo.h>
|
||||
#include <Functions/GeoUtils.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
|
||||
#include <boost/geometry.hpp>
|
||||
#include <boost/geometry/geometries/point_xy.hpp>
|
||||
#include <boost/geometry/geometries/polygon.hpp>
|
||||
|
||||
#include <Interpreters/ExpressionActions.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Common/ProfileEvents.h>
|
||||
#include <Columns/ColumnFixedString.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/ObjectPool.h>
|
||||
#include <Common/ProfileEvents.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Interpreters/ExpressionActions.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
@ -246,6 +251,186 @@ private:
|
||||
|
||||
};
|
||||
|
||||
const size_t GEOHASH_MAX_TEXT_LENGTH = 16;
|
||||
|
||||
// geohashEncode(lon float32/64, lat float32/64, length UInt8) => string
|
||||
class FunctionGeohashEncode : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "geohashEncode";
|
||||
static FunctionPtr create(const Context &) { return std::make_shared<FunctionGeohashEncode>(); }
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
bool isVariadic() const override { return true; }
|
||||
size_t getNumberOfArguments() const override { return 0; }
|
||||
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {2}; }
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
validateArgumentType(*this, arguments, 0, isFloat, "float");
|
||||
validateArgumentType(*this, arguments, 1, isFloat, "float");
|
||||
if (arguments.size() == 3)
|
||||
{
|
||||
validateArgumentType(*this, arguments, 2, isInteger, "integer");
|
||||
}
|
||||
if (arguments.size() > 3)
|
||||
{
|
||||
throw Exception("Too many arguments for function " + getName() +
|
||||
" expected at most 3",
|
||||
ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION);
|
||||
}
|
||||
|
||||
return std::make_shared<DataTypeString>();
|
||||
}
|
||||
|
||||
template <typename LonType, typename LatType>
|
||||
bool tryExecute(const IColumn * lon_column, const IColumn * lat_column, UInt64 precision_value, ColumnPtr & result)
|
||||
{
|
||||
const ColumnVector<LonType> * longitude = checkAndGetColumn<ColumnVector<LonType>>(lon_column);
|
||||
const ColumnVector<LatType> * latitude = checkAndGetColumn<ColumnVector<LatType>>(lat_column);
|
||||
if (!latitude || !longitude)
|
||||
return false;
|
||||
|
||||
auto col_str = ColumnString::create();
|
||||
ColumnString::Chars & out_vec = col_str->getChars();
|
||||
ColumnString::Offsets & out_offsets = col_str->getOffsets();
|
||||
|
||||
const size_t size = lat_column->size();
|
||||
|
||||
out_offsets.resize(size);
|
||||
out_vec.resize(size * (GEOHASH_MAX_TEXT_LENGTH + 1));
|
||||
|
||||
char * begin = reinterpret_cast<char *>(out_vec.data());
|
||||
char * pos = begin;
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
const Float64 longitude_value = longitude->getElement(i);
|
||||
const Float64 latitude_value = latitude->getElement(i);
|
||||
|
||||
const size_t encoded_size = GeoUtils::geohashEncode(longitude_value, latitude_value, precision_value, pos);
|
||||
|
||||
pos += encoded_size;
|
||||
*pos = '\0';
|
||||
out_offsets[i] = ++pos - begin;
|
||||
}
|
||||
out_vec.resize(pos - begin);
|
||||
|
||||
if (!out_offsets.empty() && out_offsets.back() != out_vec.size())
|
||||
throw Exception("Column size mismatch (internal logical error)", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
result = std::move(col_str);
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
|
||||
{
|
||||
const IColumn * longitude = block.getByPosition(arguments[0]).column.get();
|
||||
const IColumn * latitude = block.getByPosition(arguments[1]).column.get();
|
||||
|
||||
const UInt64 precision_value = std::min(GEOHASH_MAX_TEXT_LENGTH,
|
||||
arguments.size() == 3 ? block.getByPosition(arguments[2]).column->get64(0) : GEOHASH_MAX_TEXT_LENGTH);
|
||||
|
||||
ColumnPtr & res_column = block.getByPosition(result).column;
|
||||
|
||||
if (tryExecute<Float32, Float32>(longitude, latitude, precision_value, res_column) ||
|
||||
tryExecute<Float64, Float32>(longitude, latitude, precision_value, res_column) ||
|
||||
tryExecute<Float32, Float64>(longitude, latitude, precision_value, res_column) ||
|
||||
tryExecute<Float64, Float64>(longitude, latitude, precision_value, res_column))
|
||||
return;
|
||||
|
||||
const char sep[] = ", ";
|
||||
std::string arguments_description = "";
|
||||
for (size_t i = 0; i < arguments.size(); ++i)
|
||||
{
|
||||
arguments_description += block.getByPosition(arguments[i]).column->getName() + sep;
|
||||
}
|
||||
if (arguments_description.size() > sizeof(sep))
|
||||
{
|
||||
arguments_description.erase(arguments_description.size() - sizeof(sep) - 1);
|
||||
}
|
||||
|
||||
throw Exception("Unsupported argument types: " + arguments_description +
|
||||
+ " for function " + getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
};
|
||||
|
||||
// geohashDecode(string) => (lon float64, lat float64)
|
||||
class FunctionGeohashDecode : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "geohashDecode";
|
||||
static FunctionPtr create(const Context &) { return std::make_shared<FunctionGeohashDecode>(); }
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
size_t getNumberOfArguments() const override { return 1; }
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
validateArgumentType(*this, arguments, 0, isStringOrFixedString, "string or fixed string");
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
DataTypes{std::make_shared<DataTypeFloat64>(), std::make_shared<DataTypeFloat64>()},
|
||||
Strings{"longitude", "latitude"});
|
||||
}
|
||||
|
||||
template <typename ColumnTypeEncoded>
|
||||
bool tryExecute(const IColumn * encoded_column, ColumnPtr & result_column)
|
||||
{
|
||||
const auto * encoded = checkAndGetColumn<ColumnTypeEncoded>(encoded_column);
|
||||
if (!encoded)
|
||||
return false;
|
||||
|
||||
const size_t count = encoded->size();
|
||||
|
||||
auto latitude = ColumnFloat64::create(count);
|
||||
auto longitude = ColumnFloat64::create(count);
|
||||
|
||||
ColumnFloat64::Container & lon_data = longitude->getData();
|
||||
ColumnFloat64::Container & lat_data = latitude->getData();
|
||||
|
||||
for (size_t i = 0; i < count; ++i)
|
||||
{
|
||||
StringRef encoded_string = encoded->getDataAt(i);
|
||||
GeoUtils::geohashDecode(encoded_string.data, encoded_string.size, &lon_data[i], &lat_data[i]);
|
||||
}
|
||||
|
||||
MutableColumns result;
|
||||
result.emplace_back(std::move(longitude));
|
||||
result.emplace_back(std::move(latitude));
|
||||
result_column = ColumnTuple::create(std::move(result));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
|
||||
{
|
||||
const IColumn * encoded = block.getByPosition(arguments[0]).column.get();
|
||||
ColumnPtr & res_column = block.getByPosition(result).column;
|
||||
|
||||
if (tryExecute<ColumnString>(encoded, res_column) ||
|
||||
tryExecute<ColumnFixedString>(encoded, res_column))
|
||||
return;
|
||||
|
||||
throw Exception("Unsupported argument type:" + block.getByPosition(arguments[0]).column->getName()
|
||||
+ " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
using Point = boost::geometry::model::d2::point_xy<Type>;
|
||||
|
||||
@ -261,5 +446,7 @@ void registerFunctionsGeo(FunctionFactory & factory)
|
||||
factory.registerFunction<FunctionPointInEllipses>();
|
||||
|
||||
factory.registerFunction<FunctionPointInPolygon<PointInPolygonWithGrid, true>>();
|
||||
factory.registerFunction<FunctionGeohashEncode>();
|
||||
factory.registerFunction<FunctionGeohashDecode>();
|
||||
}
|
||||
}
|
||||
|
378
dbms/src/Functions/FunctionsJSON.cpp
Normal file
378
dbms/src/Functions/FunctionsJSON.cpp
Normal file
@ -0,0 +1,378 @@
|
||||
#include <Functions/FunctionsJSON.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Common/config.h>
|
||||
|
||||
#if USE_SIMDJSON
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
template <typename T>
|
||||
class JSONNullableImplBase
|
||||
{
|
||||
public:
|
||||
static DataTypePtr getType() { return std::make_shared<DataTypeNullable>(std::make_shared<T>()); }
|
||||
|
||||
static Field getDefault() { return {}; }
|
||||
};
|
||||
|
||||
class JSONHasImpl : public JSONNullableImplBase<DataTypeUInt8>
|
||||
{
|
||||
public:
|
||||
static constexpr auto name{"jsonHas"};
|
||||
|
||||
static Field getValue(ParsedJson::iterator &) { return {1}; }
|
||||
};
|
||||
|
||||
class JSONLengthImpl : public JSONNullableImplBase<DataTypeUInt64>
|
||||
{
|
||||
public:
|
||||
static constexpr auto name{"jsonLength"};
|
||||
|
||||
static Field getValue(ParsedJson::iterator & pjh)
|
||||
{
|
||||
if (!pjh.is_object_or_array())
|
||||
return getDefault();
|
||||
|
||||
size_t size = 0;
|
||||
|
||||
if (pjh.down())
|
||||
{
|
||||
size += 1;
|
||||
|
||||
while (pjh.next())
|
||||
size += 1;
|
||||
}
|
||||
|
||||
return {size};
|
||||
}
|
||||
};
|
||||
|
||||
class JSONTypeImpl : public JSONNullableImplBase<DataTypeString>
|
||||
{
|
||||
public:
|
||||
static constexpr auto name{"jsonType"};
|
||||
|
||||
static Field getValue(ParsedJson::iterator & pjh)
|
||||
{
|
||||
switch (pjh.get_type())
|
||||
{
|
||||
case '[':
|
||||
return "Array";
|
||||
case '{':
|
||||
return "Object";
|
||||
case '"':
|
||||
return "String";
|
||||
case 'l':
|
||||
return "Int64";
|
||||
case 'd':
|
||||
return "Float64";
|
||||
case 't':
|
||||
return "Bool";
|
||||
case 'f':
|
||||
return "Bool";
|
||||
case 'n':
|
||||
return "Null";
|
||||
default:
|
||||
return "Unknown";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class JSONExtractImpl
|
||||
{
|
||||
public:
|
||||
static constexpr auto name{"jsonExtract"};
|
||||
|
||||
static DataTypePtr getType(const DataTypePtr & type)
|
||||
{
|
||||
WhichDataType which{type};
|
||||
|
||||
if (which.isNativeUInt() || which.isNativeInt() || which.isFloat() || which.isEnum() || which.isDateOrDateTime()
|
||||
|| which.isStringOrFixedString() || which.isInterval())
|
||||
return std::make_shared<DataTypeNullable>(type);
|
||||
|
||||
if (which.isArray())
|
||||
{
|
||||
auto array_type = static_cast<const DataTypeArray *>(type.get());
|
||||
|
||||
return std::make_shared<DataTypeArray>(getType(array_type->getNestedType()));
|
||||
}
|
||||
|
||||
if (which.isTuple())
|
||||
{
|
||||
auto tuple_type = static_cast<const DataTypeTuple *>(type.get());
|
||||
|
||||
DataTypes types;
|
||||
types.reserve(tuple_type->getElements().size());
|
||||
|
||||
for (const DataTypePtr & element : tuple_type->getElements())
|
||||
{
|
||||
types.push_back(getType(element));
|
||||
}
|
||||
|
||||
return std::make_shared<DataTypeTuple>(std::move(types));
|
||||
}
|
||||
|
||||
throw Exception{"Unsupported return type schema: " + type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
|
||||
static Field getDefault(const DataTypePtr & type)
|
||||
{
|
||||
WhichDataType which{type};
|
||||
|
||||
if (which.isNativeUInt() || which.isNativeInt() || which.isFloat() || which.isEnum() || which.isDateOrDateTime()
|
||||
|| which.isStringOrFixedString() || which.isInterval())
|
||||
return {};
|
||||
|
||||
if (which.isArray())
|
||||
return {Array{}};
|
||||
|
||||
if (which.isTuple())
|
||||
{
|
||||
auto tuple_type = static_cast<const DataTypeTuple *>(type.get());
|
||||
|
||||
Tuple tuple;
|
||||
tuple.toUnderType().reserve(tuple_type->getElements().size());
|
||||
|
||||
for (const DataTypePtr & element : tuple_type->getElements())
|
||||
tuple.toUnderType().push_back(getDefault(element));
|
||||
|
||||
return {tuple};
|
||||
}
|
||||
|
||||
// should not reach
|
||||
throw Exception{"Unsupported return type schema: " + type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
|
||||
static Field getValue(ParsedJson::iterator & pjh, const DataTypePtr & type)
|
||||
{
|
||||
WhichDataType which{type};
|
||||
|
||||
if (which.isNativeUInt() || which.isNativeInt() || which.isEnum() || which.isDateOrDateTime() || which.isInterval())
|
||||
{
|
||||
if (pjh.is_integer())
|
||||
return {pjh.get_integer()};
|
||||
else
|
||||
return getDefault(type);
|
||||
}
|
||||
|
||||
if (which.isFloat())
|
||||
{
|
||||
if (pjh.is_integer())
|
||||
return {static_cast<double>(pjh.get_integer())};
|
||||
else if (pjh.is_double())
|
||||
return {pjh.get_double()};
|
||||
else
|
||||
return getDefault(type);
|
||||
}
|
||||
|
||||
if (which.isStringOrFixedString())
|
||||
{
|
||||
if (pjh.is_string())
|
||||
return {String{pjh.get_string()}};
|
||||
else
|
||||
return getDefault(type);
|
||||
}
|
||||
|
||||
if (which.isArray())
|
||||
{
|
||||
if (!pjh.is_object_or_array())
|
||||
return getDefault(type);
|
||||
|
||||
auto array_type = static_cast<const DataTypeArray *>(type.get());
|
||||
|
||||
Array array;
|
||||
|
||||
bool first = true;
|
||||
|
||||
while (first ? pjh.down() : pjh.next())
|
||||
{
|
||||
first = false;
|
||||
|
||||
ParsedJson::iterator pjh1{pjh};
|
||||
|
||||
array.push_back(getValue(pjh1, array_type->getNestedType()));
|
||||
}
|
||||
|
||||
return {array};
|
||||
}
|
||||
|
||||
if (which.isTuple())
|
||||
{
|
||||
if (!pjh.is_object_or_array())
|
||||
return getDefault(type);
|
||||
|
||||
auto tuple_type = static_cast<const DataTypeTuple *>(type.get());
|
||||
|
||||
Tuple tuple;
|
||||
tuple.toUnderType().reserve(tuple_type->getElements().size());
|
||||
|
||||
bool valid = true;
|
||||
bool first = true;
|
||||
|
||||
for (const DataTypePtr & element : tuple_type->getElements())
|
||||
{
|
||||
if (valid)
|
||||
{
|
||||
valid &= first ? pjh.down() : pjh.next();
|
||||
first = false;
|
||||
|
||||
ParsedJson::iterator pjh1{pjh};
|
||||
|
||||
tuple.toUnderType().push_back(getValue(pjh1, element));
|
||||
}
|
||||
else
|
||||
tuple.toUnderType().push_back(getDefault(element));
|
||||
}
|
||||
|
||||
return {tuple};
|
||||
}
|
||||
|
||||
// should not reach
|
||||
throw Exception{"Unsupported return type schema: " + type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
};
|
||||
|
||||
class JSONExtractUIntImpl : public JSONNullableImplBase<DataTypeUInt64>
|
||||
{
|
||||
public:
|
||||
static constexpr auto name{"jsonExtractUInt"};
|
||||
|
||||
static Field getValue(ParsedJson::iterator & pjh)
|
||||
{
|
||||
if (pjh.is_integer())
|
||||
return {pjh.get_integer()};
|
||||
else
|
||||
return getDefault();
|
||||
}
|
||||
};
|
||||
|
||||
class JSONExtractIntImpl : public JSONNullableImplBase<DataTypeInt64>
|
||||
{
|
||||
public:
|
||||
static constexpr auto name{"jsonExtractInt"};
|
||||
|
||||
static Field getValue(ParsedJson::iterator & pjh)
|
||||
{
|
||||
if (pjh.is_integer())
|
||||
return {pjh.get_integer()};
|
||||
else
|
||||
return getDefault();
|
||||
}
|
||||
};
|
||||
|
||||
class JSONExtractFloatImpl : public JSONNullableImplBase<DataTypeFloat64>
|
||||
{
|
||||
public:
|
||||
static constexpr auto name{"jsonExtractFloat"};
|
||||
|
||||
static Field getValue(ParsedJson::iterator & pjh)
|
||||
{
|
||||
if (pjh.is_double())
|
||||
return {pjh.get_double()};
|
||||
else
|
||||
return getDefault();
|
||||
}
|
||||
};
|
||||
|
||||
class JSONExtractBoolImpl : public JSONNullableImplBase<DataTypeUInt8>
|
||||
{
|
||||
public:
|
||||
static constexpr auto name{"jsonExtractBool"};
|
||||
|
||||
static Field getValue(ParsedJson::iterator & pjh)
|
||||
{
|
||||
if (pjh.get_type() == 't')
|
||||
return {1};
|
||||
else if (pjh.get_type() == 'f')
|
||||
return {0};
|
||||
else
|
||||
return getDefault();
|
||||
}
|
||||
};
|
||||
|
||||
// class JSONExtractRawImpl: public JSONNullableImplBase<DataTypeString>
|
||||
// {
|
||||
// public:
|
||||
// static constexpr auto name {"jsonExtractRaw"};
|
||||
|
||||
// static Field getValue(ParsedJson::iterator & pjh)
|
||||
// {
|
||||
// //
|
||||
// }
|
||||
// };
|
||||
|
||||
class JSONExtractStringImpl : public JSONNullableImplBase<DataTypeString>
|
||||
{
|
||||
public:
|
||||
static constexpr auto name{"jsonExtractString"};
|
||||
|
||||
static Field getValue(ParsedJson::iterator & pjh)
|
||||
{
|
||||
if (pjh.is_string())
|
||||
return {String{pjh.get_string()}};
|
||||
else
|
||||
return getDefault();
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
#else
|
||||
namespace DB
|
||||
{
|
||||
struct JSONHasImpl { static constexpr auto name{"jsonHas"}; };
|
||||
struct JSONLengthImpl { static constexpr auto name{"jsonLength"}; };
|
||||
struct JSONTypeImpl { static constexpr auto name{"jsonType"}; };
|
||||
struct JSONExtractImpl { static constexpr auto name{"jsonExtract"}; };
|
||||
struct JSONExtractUIntImpl { static constexpr auto name{"jsonExtractUInt"}; };
|
||||
struct JSONExtractIntImpl { static constexpr auto name{"jsonExtractInt"}; };
|
||||
struct JSONExtractFloatImpl { static constexpr auto name{"jsonExtractFloat"}; };
|
||||
struct JSONExtractBoolImpl { static constexpr auto name{"jsonExtractBool"}; };
|
||||
//struct JSONExtractRawImpl { static constexpr auto name {"jsonExtractRaw"}; };
|
||||
struct JSONExtractStringImpl { static constexpr auto name{"jsonExtractString"}; };
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
void registerFunctionsJSON(FunctionFactory & factory)
|
||||
{
|
||||
#if USE_SIMDJSON
|
||||
if (__builtin_cpu_supports("avx2"))
|
||||
{
|
||||
factory.registerFunction<FunctionJSONBase<JSONHasImpl, false>>();
|
||||
factory.registerFunction<FunctionJSONBase<JSONLengthImpl, false>>();
|
||||
factory.registerFunction<FunctionJSONBase<JSONTypeImpl, false>>();
|
||||
factory.registerFunction<FunctionJSONBase<JSONExtractImpl, true>>();
|
||||
factory.registerFunction<FunctionJSONBase<JSONExtractUIntImpl, false>>();
|
||||
factory.registerFunction<FunctionJSONBase<JSONExtractIntImpl, false>>();
|
||||
factory.registerFunction<FunctionJSONBase<JSONExtractFloatImpl, false>>();
|
||||
factory.registerFunction<FunctionJSONBase<JSONExtractBoolImpl, false>>();
|
||||
// factory.registerFunction<FunctionJSONBase<
|
||||
// JSONExtractRawImpl,
|
||||
// false
|
||||
// >>();
|
||||
factory.registerFunction<FunctionJSONBase<JSONExtractStringImpl, false>>();
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
factory.registerFunction<FunctionJSONDummy<JSONHasImpl>>();
|
||||
factory.registerFunction<FunctionJSONDummy<JSONLengthImpl>>();
|
||||
factory.registerFunction<FunctionJSONDummy<JSONTypeImpl>>();
|
||||
factory.registerFunction<FunctionJSONDummy<JSONExtractImpl>>();
|
||||
factory.registerFunction<FunctionJSONDummy<JSONExtractUIntImpl>>();
|
||||
factory.registerFunction<FunctionJSONDummy<JSONExtractIntImpl>>();
|
||||
factory.registerFunction<FunctionJSONDummy<JSONExtractFloatImpl>>();
|
||||
factory.registerFunction<FunctionJSONDummy<JSONExtractBoolImpl>>();
|
||||
//factory.registerFunction<FunctionJSONDummy<JSONExtractRawImpl>>();
|
||||
factory.registerFunction<FunctionJSONDummy<JSONExtractStringImpl>>();
|
||||
}
|
||||
|
||||
}
|
243
dbms/src/Functions/FunctionsJSON.h
Normal file
243
dbms/src/Functions/FunctionsJSON.h
Normal file
@ -0,0 +1,243 @@
|
||||
#pragma once
|
||||
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Common/config.h>
|
||||
|
||||
#if USE_SIMDJSON
|
||||
|
||||
#include <Columns/ColumnConst.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <DataTypes/DataTypeFactory.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <ext/range.h>
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
#pragma clang diagnostic ignored "-Wnewline-eof"
|
||||
#endif
|
||||
|
||||
#include <simdjson/jsonparser.h>
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int CANNOT_ALLOCATE_MEMORY;
|
||||
extern const int ILLEGAL_COLUMN;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
}
|
||||
|
||||
template <typename Impl, bool ExtraArg>
|
||||
class FunctionJSONBase : public IFunction
|
||||
{
|
||||
private:
|
||||
enum class Action
|
||||
{
|
||||
key = 1,
|
||||
index = 2,
|
||||
};
|
||||
|
||||
mutable std::vector<Action> actions;
|
||||
mutable DataTypePtr virtual_type;
|
||||
|
||||
bool tryMove(ParsedJson::iterator & pjh, Action action, const Field & accessor)
|
||||
{
|
||||
switch (action)
|
||||
{
|
||||
case Action::key:
|
||||
if (!pjh.is_object() || !pjh.move_to_key(accessor.get<String>().data()))
|
||||
return false;
|
||||
|
||||
break;
|
||||
case Action::index:
|
||||
if (!pjh.is_object_or_array() || !pjh.down())
|
||||
return false;
|
||||
|
||||
int steps = accessor.get<Int64>();
|
||||
|
||||
if (steps > 0)
|
||||
steps -= 1;
|
||||
else if (steps < 0)
|
||||
{
|
||||
steps += 1;
|
||||
|
||||
ParsedJson::iterator pjh1{pjh};
|
||||
|
||||
while (pjh1.next())
|
||||
steps += 1;
|
||||
}
|
||||
else
|
||||
return false;
|
||||
|
||||
for (const auto i : ext::range(0, steps))
|
||||
{
|
||||
(void)i;
|
||||
|
||||
if (!pjh.next())
|
||||
return false;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr auto name = Impl::name;
|
||||
|
||||
static FunctionPtr create(const Context &) { return std::make_shared<FunctionJSONBase>(); }
|
||||
|
||||
String getName() const override { return Impl::name; }
|
||||
|
||||
bool isVariadic() const override { return true; }
|
||||
|
||||
size_t getNumberOfArguments() const override { return 0; }
|
||||
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
{
|
||||
if constexpr (ExtraArg)
|
||||
{
|
||||
if (arguments.size() < 2)
|
||||
throw Exception{"Function " + getName() + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
|
||||
|
||||
auto col_type_const = typeid_cast<const ColumnConst *>(arguments[1].column.get());
|
||||
|
||||
if (!col_type_const)
|
||||
throw Exception{"Illegal non-const column " + arguments[1].column->getName() + " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN};
|
||||
|
||||
virtual_type = DataTypeFactory::instance().get(col_type_const->getValue<String>());
|
||||
}
|
||||
else
|
||||
{
|
||||
if (arguments.size() < 1)
|
||||
throw Exception{"Function " + getName() + " requires at least one arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
|
||||
}
|
||||
|
||||
if (!isString(arguments[0].type))
|
||||
throw Exception{"Illegal type " + arguments[0].type->getName() + " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
actions.reserve(arguments.size() - 1 - ExtraArg);
|
||||
|
||||
for (const auto i : ext::range(1 + ExtraArg, arguments.size()))
|
||||
{
|
||||
if (isString(arguments[i].type))
|
||||
actions.push_back(Action::key);
|
||||
else if (isInteger(arguments[i].type))
|
||||
actions.push_back(Action::index);
|
||||
else
|
||||
throw Exception{"Illegal type " + arguments[i].type->getName() + " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
|
||||
if constexpr (ExtraArg)
|
||||
return Impl::getType(virtual_type);
|
||||
else
|
||||
return Impl::getType();
|
||||
}
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result_pos, size_t input_rows_count) override
|
||||
{
|
||||
MutableColumnPtr to{block.getByPosition(result_pos).type->createColumn()};
|
||||
to->reserve(input_rows_count);
|
||||
|
||||
const ColumnPtr & arg_json = block.getByPosition(arguments[0]).column;
|
||||
|
||||
auto col_json_const = typeid_cast<const ColumnConst *>(arg_json.get());
|
||||
|
||||
auto col_json_string
|
||||
= typeid_cast<const ColumnString *>(col_json_const ? col_json_const->getDataColumnPtr().get() : arg_json.get());
|
||||
|
||||
if (!col_json_string)
|
||||
throw Exception{"Illegal column " + arg_json->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN};
|
||||
|
||||
const ColumnString::Chars & chars = col_json_string->getChars();
|
||||
const ColumnString::Offsets & offsets = col_json_string->getOffsets();
|
||||
|
||||
size_t max_size = 1;
|
||||
|
||||
for (const auto i : ext::range(0, input_rows_count))
|
||||
if (max_size < offsets[i] - offsets[i - 1] - 1)
|
||||
max_size = offsets[i] - offsets[i - 1] - 1;
|
||||
|
||||
ParsedJson pj;
|
||||
if (!pj.allocateCapacity(max_size))
|
||||
throw Exception{"Can not allocate memory for " + std::to_string(max_size) + " units when parsing JSON",
|
||||
ErrorCodes::CANNOT_ALLOCATE_MEMORY};
|
||||
|
||||
for (const auto i : ext::range(0, input_rows_count))
|
||||
{
|
||||
bool ok = json_parse(&chars[offsets[i - 1]], offsets[i] - offsets[i - 1] - 1, pj) == 0;
|
||||
|
||||
ParsedJson::iterator pjh{pj};
|
||||
|
||||
for (const auto j : ext::range(0, actions.size()))
|
||||
{
|
||||
if (!ok)
|
||||
break;
|
||||
|
||||
ok = tryMove(pjh, actions[j], (*block.getByPosition(arguments[j + 1 + ExtraArg]).column)[i]);
|
||||
}
|
||||
|
||||
if (ok)
|
||||
{
|
||||
if constexpr (ExtraArg)
|
||||
to->insert(Impl::getValue(pjh, virtual_type));
|
||||
else
|
||||
to->insert(Impl::getValue(pjh));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr (ExtraArg)
|
||||
to->insert(Impl::getDefault(virtual_type));
|
||||
else
|
||||
to->insert(Impl::getDefault());
|
||||
}
|
||||
}
|
||||
|
||||
block.getByPosition(result_pos).column = std::move(to);
|
||||
}
|
||||
};
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
template <typename Impl>
|
||||
class FunctionJSONDummy : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = Impl::name;
|
||||
static FunctionPtr create(const Context &) { return std::make_shared<FunctionJSONDummy>(); }
|
||||
|
||||
String getName() const override { return Impl::name; }
|
||||
bool isVariadic() const override { return true; }
|
||||
size_t getNumberOfArguments() const override { return 0; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override
|
||||
{
|
||||
throw Exception{"Function " + getName() + " is not supported without AVX2", ErrorCodes::NOT_IMPLEMENTED};
|
||||
}
|
||||
|
||||
void executeImpl(Block &, const ColumnNumbers &, size_t, size_t) override
|
||||
{
|
||||
throw Exception{"Function " + getName() + " is not supported without AVX2", ErrorCodes::NOT_IMPLEMENTED};
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -449,44 +449,27 @@ struct NameMultiSearchFirstPositionCaseInsensitiveUTF8
|
||||
using FunctionPosition = FunctionsStringSearch<PositionImpl<PositionCaseSensitiveASCII>, NamePosition>;
|
||||
using FunctionPositionUTF8 = FunctionsStringSearch<PositionImpl<PositionCaseSensitiveUTF8>, NamePositionUTF8>;
|
||||
using FunctionPositionCaseInsensitive = FunctionsStringSearch<PositionImpl<PositionCaseInsensitiveASCII>, NamePositionCaseInsensitive>;
|
||||
using FunctionPositionCaseInsensitiveUTF8
|
||||
= FunctionsStringSearch<PositionImpl<PositionCaseInsensitiveUTF8>, NamePositionCaseInsensitiveUTF8>;
|
||||
using FunctionPositionCaseInsensitiveUTF8 = FunctionsStringSearch<PositionImpl<PositionCaseInsensitiveUTF8>, NamePositionCaseInsensitiveUTF8>;
|
||||
|
||||
using FunctionMultiSearchAllPositions
|
||||
= FunctionsMultiStringPosition<MultiSearchAllPositionsImpl<PositionCaseSensitiveASCII>, NameMultiSearchAllPositions>;
|
||||
using FunctionMultiSearchAllPositionsUTF8
|
||||
= FunctionsMultiStringPosition<MultiSearchAllPositionsImpl<PositionCaseSensitiveUTF8>, NameMultiSearchAllPositionsUTF8>;
|
||||
using FunctionMultiSearchAllPositionsCaseInsensitive
|
||||
= FunctionsMultiStringPosition<MultiSearchAllPositionsImpl<PositionCaseInsensitiveASCII>, NameMultiSearchAllPositionsCaseInsensitive>;
|
||||
using FunctionMultiSearchAllPositionsCaseInsensitiveUTF8 = FunctionsMultiStringPosition<
|
||||
MultiSearchAllPositionsImpl<PositionCaseInsensitiveUTF8>,
|
||||
NameMultiSearchAllPositionsCaseInsensitiveUTF8>;
|
||||
using FunctionMultiSearchAllPositions = FunctionsMultiStringPosition<MultiSearchAllPositionsImpl<PositionCaseSensitiveASCII>, NameMultiSearchAllPositions>;
|
||||
using FunctionMultiSearchAllPositionsUTF8 = FunctionsMultiStringPosition<MultiSearchAllPositionsImpl<PositionCaseSensitiveUTF8>, NameMultiSearchAllPositionsUTF8>;
|
||||
using FunctionMultiSearchAllPositionsCaseInsensitive = FunctionsMultiStringPosition<MultiSearchAllPositionsImpl<PositionCaseInsensitiveASCII>, NameMultiSearchAllPositionsCaseInsensitive>;
|
||||
using FunctionMultiSearchAllPositionsCaseInsensitiveUTF8 = FunctionsMultiStringPosition<MultiSearchAllPositionsImpl<PositionCaseInsensitiveUTF8>, NameMultiSearchAllPositionsCaseInsensitiveUTF8>;
|
||||
|
||||
using FunctionMultiSearch = FunctionsMultiStringSearch<MultiSearchImpl<PositionCaseSensitiveASCII>, NameMultiSearchAny>;
|
||||
using FunctionMultiSearchUTF8 = FunctionsMultiStringSearch<MultiSearchImpl<PositionCaseSensitiveUTF8>, NameMultiSearchAnyUTF8>;
|
||||
using FunctionMultiSearchCaseInsensitive
|
||||
= FunctionsMultiStringSearch<MultiSearchImpl<PositionCaseInsensitiveASCII>, NameMultiSearchAnyCaseInsensitive>;
|
||||
using FunctionMultiSearchCaseInsensitiveUTF8
|
||||
= FunctionsMultiStringSearch<MultiSearchImpl<PositionCaseInsensitiveUTF8>, NameMultiSearchAnyCaseInsensitiveUTF8>;
|
||||
using FunctionMultiSearchCaseInsensitive = FunctionsMultiStringSearch<MultiSearchImpl<PositionCaseInsensitiveASCII>, NameMultiSearchAnyCaseInsensitive>;
|
||||
using FunctionMultiSearchCaseInsensitiveUTF8 = FunctionsMultiStringSearch<MultiSearchImpl<PositionCaseInsensitiveUTF8>, NameMultiSearchAnyCaseInsensitiveUTF8>;
|
||||
|
||||
using FunctionMultiSearchFirstIndex
|
||||
= FunctionsMultiStringSearch<MultiSearchFirstIndexImpl<PositionCaseSensitiveASCII>, NameMultiSearchFirstIndex>;
|
||||
using FunctionMultiSearchFirstIndexUTF8
|
||||
= FunctionsMultiStringSearch<MultiSearchFirstIndexImpl<PositionCaseSensitiveUTF8>, NameMultiSearchFirstIndexUTF8>;
|
||||
using FunctionMultiSearchFirstIndexCaseInsensitive
|
||||
= FunctionsMultiStringSearch<MultiSearchFirstIndexImpl<PositionCaseInsensitiveASCII>, NameMultiSearchFirstIndexCaseInsensitive>;
|
||||
using FunctionMultiSearchFirstIndexCaseInsensitiveUTF8
|
||||
= FunctionsMultiStringSearch<MultiSearchFirstIndexImpl<PositionCaseInsensitiveUTF8>, NameMultiSearchFirstIndexCaseInsensitiveUTF8>;
|
||||
using FunctionMultiSearchFirstIndex = FunctionsMultiStringSearch<MultiSearchFirstIndexImpl<PositionCaseSensitiveASCII>, NameMultiSearchFirstIndex>;
|
||||
using FunctionMultiSearchFirstIndexUTF8 = FunctionsMultiStringSearch<MultiSearchFirstIndexImpl<PositionCaseSensitiveUTF8>, NameMultiSearchFirstIndexUTF8>;
|
||||
using FunctionMultiSearchFirstIndexCaseInsensitive = FunctionsMultiStringSearch<MultiSearchFirstIndexImpl<PositionCaseInsensitiveASCII>, NameMultiSearchFirstIndexCaseInsensitive>;
|
||||
using FunctionMultiSearchFirstIndexCaseInsensitiveUTF8 = FunctionsMultiStringSearch<MultiSearchFirstIndexImpl<PositionCaseInsensitiveUTF8>, NameMultiSearchFirstIndexCaseInsensitiveUTF8>;
|
||||
|
||||
using FunctionMultiSearchFirstPosition
|
||||
= FunctionsMultiStringSearch<MultiSearchFirstPositionImpl<PositionCaseSensitiveASCII>, NameMultiSearchFirstPosition>;
|
||||
using FunctionMultiSearchFirstPositionUTF8
|
||||
= FunctionsMultiStringSearch<MultiSearchFirstPositionImpl<PositionCaseSensitiveUTF8>, NameMultiSearchFirstPositionUTF8>;
|
||||
using FunctionMultiSearchFirstPositionCaseInsensitive
|
||||
= FunctionsMultiStringSearch<MultiSearchFirstPositionImpl<PositionCaseInsensitiveASCII>, NameMultiSearchFirstPositionCaseInsensitive>;
|
||||
using FunctionMultiSearchFirstPositionCaseInsensitiveUTF8 = FunctionsMultiStringSearch<
|
||||
MultiSearchFirstPositionImpl<PositionCaseInsensitiveUTF8>,
|
||||
NameMultiSearchFirstPositionCaseInsensitiveUTF8>;
|
||||
using FunctionMultiSearchFirstPosition = FunctionsMultiStringSearch<MultiSearchFirstPositionImpl<PositionCaseSensitiveASCII>, NameMultiSearchFirstPosition>;
|
||||
using FunctionMultiSearchFirstPositionUTF8 = FunctionsMultiStringSearch<MultiSearchFirstPositionImpl<PositionCaseSensitiveUTF8>, NameMultiSearchFirstPositionUTF8>;
|
||||
using FunctionMultiSearchFirstPositionCaseInsensitive = FunctionsMultiStringSearch<MultiSearchFirstPositionImpl<PositionCaseInsensitiveASCII>, NameMultiSearchFirstPositionCaseInsensitive>;
|
||||
using FunctionMultiSearchFirstPositionCaseInsensitiveUTF8 = FunctionsMultiStringSearch<MultiSearchFirstPositionImpl<PositionCaseInsensitiveUTF8>, NameMultiSearchFirstPositionCaseInsensitiveUTF8>;
|
||||
|
||||
|
||||
void registerFunctionsStringSearch(FunctionFactory & factory)
|
||||
|
@ -164,43 +164,46 @@ struct NgramDistanceImpl
|
||||
return num;
|
||||
}
|
||||
|
||||
template <bool SaveNgrams>
|
||||
static ALWAYS_INLINE inline size_t calculateNeedleStats(
|
||||
const char * data,
|
||||
const size_t size,
|
||||
NgramStats & ngram_stats,
|
||||
[[maybe_unused]] UInt16 * ngram_storage,
|
||||
size_t (*read_code_points)(CodePoint *, const char *&, const char *),
|
||||
UInt16 (*hash_functor)(const CodePoint *))
|
||||
{
|
||||
// To prevent size_t overflow below.
|
||||
if (size < N)
|
||||
return 0;
|
||||
|
||||
const char * start = data;
|
||||
const char * end = data + size;
|
||||
CodePoint cp[simultaneously_codepoints_num] = {};
|
||||
|
||||
/// read_code_points returns the position of cp where it stopped reading codepoints.
|
||||
size_t found = read_code_points(cp, start, end);
|
||||
/// We need to start for the first time here, because first N - 1 codepoints mean nothing.
|
||||
size_t i = N - 1;
|
||||
/// Initialize with this value because for the first time `found` does not initialize first N - 1 codepoints.
|
||||
size_t len = -N + 1;
|
||||
size_t len = 0;
|
||||
do
|
||||
{
|
||||
len += found - N + 1;
|
||||
for (; i + N <= found; ++i)
|
||||
++ngram_stats[hash_functor(cp + i)];
|
||||
{
|
||||
++len;
|
||||
UInt16 hash = hash_functor(cp + i);
|
||||
if constexpr (SaveNgrams)
|
||||
*ngram_storage++ = hash;
|
||||
++ngram_stats[hash];
|
||||
}
|
||||
i = 0;
|
||||
} while (start < end && (found = read_code_points(cp, start, end)));
|
||||
|
||||
return len;
|
||||
}
|
||||
|
||||
template <bool ReuseStats>
|
||||
static ALWAYS_INLINE inline UInt64 calculateHaystackStatsAndMetric(
|
||||
const char * data,
|
||||
const size_t size,
|
||||
NgramStats & ngram_stats,
|
||||
size_t & distance,
|
||||
[[maybe_unused]] UInt16 * ngram_storage,
|
||||
size_t (*read_code_points)(CodePoint *, const char *&, const char *),
|
||||
UInt16 (*hash_functor)(const CodePoint *))
|
||||
{
|
||||
@ -209,18 +212,6 @@ struct NgramDistanceImpl
|
||||
const char * end = data + size;
|
||||
CodePoint cp[simultaneously_codepoints_num] = {};
|
||||
|
||||
/// allocation tricks, most strings are relatively small
|
||||
static constexpr size_t small_buffer_size = 256;
|
||||
std::unique_ptr<UInt16[]> big_buffer;
|
||||
UInt16 small_buffer[small_buffer_size];
|
||||
UInt16 * ngram_storage = small_buffer;
|
||||
|
||||
if (size > small_buffer_size)
|
||||
{
|
||||
ngram_storage = new UInt16[size];
|
||||
big_buffer.reset(ngram_storage);
|
||||
}
|
||||
|
||||
/// read_code_points returns the position of cp where it stopped reading codepoints.
|
||||
size_t found = read_code_points(cp, start, end);
|
||||
/// We need to start for the first time here, because first N - 1 codepoints mean nothing.
|
||||
@ -235,21 +226,25 @@ struct NgramDistanceImpl
|
||||
--distance;
|
||||
else
|
||||
++distance;
|
||||
|
||||
ngram_storage[ngram_cnt++] = hash;
|
||||
if constexpr (ReuseStats)
|
||||
ngram_storage[ngram_cnt] = hash;
|
||||
++ngram_cnt;
|
||||
--ngram_stats[hash];
|
||||
}
|
||||
iter = 0;
|
||||
} while (start < end && (found = read_code_points(cp, start, end)));
|
||||
|
||||
/// Return the state of hash map to its initial.
|
||||
for (size_t i = 0; i < ngram_cnt; ++i)
|
||||
++ngram_stats[ngram_storage[i]];
|
||||
if constexpr (ReuseStats)
|
||||
{
|
||||
for (size_t i = 0; i < ngram_cnt; ++i)
|
||||
++ngram_stats[ngram_storage[i]];
|
||||
}
|
||||
return ngram_cnt;
|
||||
}
|
||||
|
||||
template <class Callback, class... Args>
|
||||
static inline size_t dispatchSearcher(Callback callback, Args &&... args)
|
||||
static inline auto dispatchSearcher(Callback callback, Args &&... args)
|
||||
{
|
||||
if constexpr (!UTF8)
|
||||
return callback(std::forward<Args>(args)..., readASCIICodePoints, ASCIIHash);
|
||||
@ -259,8 +254,7 @@ struct NgramDistanceImpl
|
||||
|
||||
static void constant_constant(std::string data, std::string needle, Float32 & res)
|
||||
{
|
||||
NgramStats common_stats;
|
||||
memset(common_stats, 0, sizeof(common_stats));
|
||||
NgramStats common_stats = {};
|
||||
|
||||
/// We use unsafe versions of getting ngrams, so I decided to use padded strings.
|
||||
const size_t needle_size = needle.size();
|
||||
@ -268,11 +262,11 @@ struct NgramDistanceImpl
|
||||
needle.resize(needle_size + default_padding);
|
||||
data.resize(data_size + default_padding);
|
||||
|
||||
size_t second_size = dispatchSearcher(calculateNeedleStats, needle.data(), needle_size, common_stats);
|
||||
size_t second_size = dispatchSearcher(calculateNeedleStats<false>, needle.data(), needle_size, common_stats, nullptr);
|
||||
size_t distance = second_size;
|
||||
if (data_size <= max_string_size)
|
||||
{
|
||||
size_t first_size = dispatchSearcher(calculateHaystackStatsAndMetric, data.data(), data_size, common_stats, distance);
|
||||
size_t first_size = dispatchSearcher(calculateHaystackStatsAndMetric<false>, data.data(), data_size, common_stats, distance, nullptr);
|
||||
res = distance * 1.f / std::max(first_size + second_size, size_t(1));
|
||||
}
|
||||
else
|
||||
@ -281,18 +275,89 @@ struct NgramDistanceImpl
|
||||
}
|
||||
}
|
||||
|
||||
static void vector_vector(
|
||||
const ColumnString::Chars & haystack_data,
|
||||
const ColumnString::Offsets & haystack_offsets,
|
||||
const ColumnString::Chars & needle_data,
|
||||
const ColumnString::Offsets & needle_offsets,
|
||||
PaddedPODArray<Float32> & res)
|
||||
{
|
||||
const size_t haystack_offsets_size = haystack_offsets.size();
|
||||
size_t prev_haystack_offset = 0;
|
||||
size_t prev_needle_offset = 0;
|
||||
|
||||
NgramStats common_stats = {};
|
||||
|
||||
/// The main motivation is to not allocate more on stack because we have already allocated a lot (128Kb).
|
||||
/// And we can reuse these storages in one thread because we care only about what was written to first places.
|
||||
std::unique_ptr<UInt16[]> needle_ngram_storage(new UInt16[max_string_size]);
|
||||
std::unique_ptr<UInt16[]> haystack_ngram_storage(new UInt16[max_string_size]);
|
||||
|
||||
for (size_t i = 0; i < haystack_offsets_size; ++i)
|
||||
{
|
||||
const char * haystack = reinterpret_cast<const char *>(&haystack_data[prev_haystack_offset]);
|
||||
const size_t haystack_size = haystack_offsets[i] - prev_haystack_offset - 1;
|
||||
const char * needle = reinterpret_cast<const char *>(&needle_data[prev_needle_offset]);
|
||||
const size_t needle_size = needle_offsets[i] - prev_needle_offset - 1;
|
||||
|
||||
if (needle_size <= max_string_size && haystack_size <= max_string_size)
|
||||
{
|
||||
/// Get needle stats.
|
||||
const size_t needle_stats_size = dispatchSearcher(
|
||||
calculateNeedleStats<true>,
|
||||
needle,
|
||||
needle_size,
|
||||
common_stats,
|
||||
needle_ngram_storage.get());
|
||||
|
||||
size_t distance = needle_stats_size;
|
||||
|
||||
/// Combine with haystack stats, return to initial needle stats.
|
||||
const size_t haystack_stats_size = dispatchSearcher(
|
||||
calculateHaystackStatsAndMetric<true>,
|
||||
haystack,
|
||||
haystack_size,
|
||||
common_stats,
|
||||
distance,
|
||||
haystack_ngram_storage.get());
|
||||
|
||||
/// Return to zero array stats.
|
||||
for (size_t j = 0; j < needle_stats_size; ++j)
|
||||
--common_stats[needle_ngram_storage[j]];
|
||||
|
||||
/// For now, common stats is a zero array.
|
||||
res[i] = distance * 1.f / std::max(haystack_stats_size + needle_stats_size, size_t(1));
|
||||
}
|
||||
else
|
||||
{
|
||||
/// Strings are too big, we are assuming they are not the same. This is done because of limiting number
|
||||
/// of bigrams added and not allocating too much memory.
|
||||
res[i] = 1.f;
|
||||
}
|
||||
|
||||
prev_needle_offset = needle_offsets[i];
|
||||
prev_haystack_offset = haystack_offsets[i];
|
||||
}
|
||||
}
|
||||
|
||||
static void vector_constant(
|
||||
const ColumnString::Chars & data, const ColumnString::Offsets & offsets, std::string needle, PaddedPODArray<Float32> & res)
|
||||
const ColumnString::Chars & data,
|
||||
const ColumnString::Offsets & offsets,
|
||||
std::string needle,
|
||||
PaddedPODArray<Float32> & res)
|
||||
{
|
||||
/// zeroing our map
|
||||
NgramStats common_stats;
|
||||
memset(common_stats, 0, sizeof(common_stats));
|
||||
NgramStats common_stats = {};
|
||||
|
||||
/// The main motivation is to not allocate more on stack because we have already allocated a lot (128Kb).
|
||||
/// And we can reuse these storages in one thread because we care only about what was written to first places.
|
||||
std::unique_ptr<UInt16[]> ngram_storage(new UInt16[max_string_size]);
|
||||
|
||||
/// We use unsafe versions of getting ngrams, so I decided to use padded_data even in needle case.
|
||||
const size_t needle_size = needle.size();
|
||||
needle.resize(needle_size + default_padding);
|
||||
|
||||
const size_t needle_stats_size = dispatchSearcher(calculateNeedleStats, needle.data(), needle_size, common_stats);
|
||||
const size_t needle_stats_size = dispatchSearcher(calculateNeedleStats<false>, needle.data(), needle_size, common_stats, nullptr);
|
||||
|
||||
size_t distance = needle_stats_size;
|
||||
size_t prev_offset = 0;
|
||||
@ -303,7 +368,11 @@ struct NgramDistanceImpl
|
||||
if (haystack_size <= max_string_size)
|
||||
{
|
||||
size_t haystack_stats_size = dispatchSearcher(
|
||||
calculateHaystackStatsAndMetric, reinterpret_cast<const char *>(haystack), haystack_size, common_stats, distance);
|
||||
calculateHaystackStatsAndMetric<true>,
|
||||
reinterpret_cast<const char *>(haystack),
|
||||
haystack_size, common_stats,
|
||||
distance,
|
||||
ngram_storage.get());
|
||||
res[i] = distance * 1.f / std::max(haystack_stats_size + needle_stats_size, size_t(1));
|
||||
}
|
||||
else
|
||||
@ -339,11 +408,9 @@ struct NameNgramDistanceUTF8CaseInsensitive
|
||||
};
|
||||
|
||||
using FunctionNgramDistance = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, false>, NameNgramDistance>;
|
||||
using FunctionNgramDistanceCaseInsensitive
|
||||
= FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, true>, NameNgramDistanceCaseInsensitive>;
|
||||
using FunctionNgramDistanceCaseInsensitive = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, true>, NameNgramDistanceCaseInsensitive>;
|
||||
using FunctionNgramDistanceUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, false>, NameNgramDistanceUTF8>;
|
||||
using FunctionNgramDistanceCaseInsensitiveUTF8
|
||||
= FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, true>, NameNgramDistanceUTF8CaseInsensitive>;
|
||||
using FunctionNgramDistanceCaseInsensitiveUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, true>, NameNgramDistanceUTF8CaseInsensitive>;
|
||||
|
||||
void registerFunctionsStringSimilarity(FunctionFactory & factory)
|
||||
{
|
||||
|
@ -62,10 +62,7 @@ public:
|
||||
const ColumnConst * col_haystack_const = typeid_cast<const ColumnConst *>(&*column_haystack);
|
||||
const ColumnConst * col_needle_const = typeid_cast<const ColumnConst *>(&*column_needle);
|
||||
|
||||
if (!col_needle_const)
|
||||
throw Exception("Second argument of function " + getName() + " must be constant string.", ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
if (col_haystack_const)
|
||||
if (col_haystack_const && col_needle_const)
|
||||
{
|
||||
ResultType res{};
|
||||
const String & needle = col_needle_const->getValue<String>();
|
||||
@ -88,8 +85,9 @@ public:
|
||||
vec_res.resize(column_haystack->size());
|
||||
|
||||
const ColumnString * col_haystack_vector = checkAndGetColumn<ColumnString>(&*column_haystack);
|
||||
const ColumnString * col_needle_vector = checkAndGetColumn<ColumnString>(&*column_needle);
|
||||
|
||||
if (col_haystack_vector)
|
||||
if (col_haystack_vector && col_needle_const)
|
||||
{
|
||||
const String & needle = col_needle_const->getValue<String>();
|
||||
if (needle.size() > Impl::max_string_size)
|
||||
@ -101,6 +99,27 @@ public:
|
||||
}
|
||||
Impl::vector_constant(col_haystack_vector->getChars(), col_haystack_vector->getOffsets(), needle, vec_res);
|
||||
}
|
||||
else if (col_haystack_vector && col_needle_vector)
|
||||
{
|
||||
Impl::vector_vector(
|
||||
col_haystack_vector->getChars(),
|
||||
col_haystack_vector->getOffsets(),
|
||||
col_needle_vector->getChars(),
|
||||
col_needle_vector->getOffsets(),
|
||||
vec_res);
|
||||
}
|
||||
else if (col_haystack_const && col_needle_vector)
|
||||
{
|
||||
const String & needle = col_haystack_const->getValue<String>();
|
||||
if (needle.size() > Impl::max_string_size)
|
||||
{
|
||||
throw Exception(
|
||||
"String size of needle is too big for function " + getName() + ". Should be at most "
|
||||
+ std::to_string(Impl::max_string_size),
|
||||
ErrorCodes::TOO_LARGE_STRING_SIZE);
|
||||
}
|
||||
Impl::vector_constant(col_needle_vector->getChars(), col_needle_vector->getOffsets(), needle, vec_res);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception(
|
||||
|
@ -4,7 +4,7 @@
|
||||
#include <Functions/GatherUtils/ArraySourceVisitor.h>
|
||||
#include <Functions/GatherUtils/ArraySinkVisitor.h>
|
||||
#include <Functions/GatherUtils/ValueSourceVisitor.h>
|
||||
#include <Core/TypeListNumber.h>
|
||||
|
||||
|
||||
namespace DB::GatherUtils
|
||||
{
|
||||
|
232
dbms/src/Functions/GeoUtils.cpp
Normal file
232
dbms/src/Functions/GeoUtils.cpp
Normal file
@ -0,0 +1,232 @@
|
||||
#include <Core/Types.h>
|
||||
#include <Functions/GeoUtils.h>
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
using namespace DB;
|
||||
|
||||
const char geohash_base32_encode_lookup_table[32] = {
|
||||
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
|
||||
'b', 'c', 'd', 'e', 'f', 'g', 'h', 'j', 'k', 'm',
|
||||
'n', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
|
||||
'y', 'z',
|
||||
};
|
||||
|
||||
// TODO: this could be halved by excluding 128-255 range.
|
||||
const UInt8 geohash_base32_decode_lookup_table[256] = {
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 10, 11, 12, 13, 14, 15, 16, 0xFF, 17, 18, 0xFF, 19, 20, 0xFF,
|
||||
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
};
|
||||
|
||||
const size_t BITS_PER_SYMBOL = 5;
|
||||
const size_t MAX_PRECISION = 12;
|
||||
const size_t MAX_BITS = MAX_PRECISION * BITS_PER_SYMBOL * 1.5;
|
||||
|
||||
using Encoded = std::array<UInt8, MAX_BITS>;
|
||||
|
||||
enum CoordType
|
||||
{
|
||||
LATITUDE,
|
||||
LONGITUDE,
|
||||
};
|
||||
|
||||
inline UInt8 singleCoordBitsPrecision(UInt8 precision, CoordType type)
|
||||
{
|
||||
// Single coordinate occupies only half of the total bits.
|
||||
const UInt8 bits = (precision * BITS_PER_SYMBOL) / 2;
|
||||
if (precision & 0x1 && type == LONGITUDE)
|
||||
{
|
||||
return bits + 1;
|
||||
}
|
||||
|
||||
return bits;
|
||||
}
|
||||
|
||||
inline Encoded encodeCoordinate(Float64 coord, Float64 min, Float64 max, UInt8 bits)
|
||||
{
|
||||
Encoded result;
|
||||
result.fill(0);
|
||||
|
||||
for (int i = 0; i < bits; ++i)
|
||||
{
|
||||
Float64 mid = (max + min) / 2;
|
||||
if (coord >= mid)
|
||||
{
|
||||
result[i] = 1;
|
||||
min = mid;
|
||||
}
|
||||
else
|
||||
{
|
||||
result[i] = 0;
|
||||
max = mid;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
inline Float64 decodeCoordinate(const Encoded & coord, Float64 min, Float64 max, UInt8 bits)
|
||||
{
|
||||
Float64 mid = (max + min) / 2;
|
||||
for (int i = 0; i < bits; ++i)
|
||||
{
|
||||
const auto c = coord[i];
|
||||
if (c == 1)
|
||||
{
|
||||
min = mid;
|
||||
}
|
||||
else
|
||||
{
|
||||
max = mid;
|
||||
}
|
||||
|
||||
mid = (max + min) / 2;
|
||||
}
|
||||
|
||||
return mid;
|
||||
}
|
||||
|
||||
inline Encoded merge(const Encoded & encodedLon, const Encoded & encodedLat, UInt8 precision)
|
||||
{
|
||||
Encoded result;
|
||||
result.fill(0);
|
||||
|
||||
const auto bits = (precision * BITS_PER_SYMBOL) / 2;
|
||||
UInt8 i = 0;
|
||||
for (; i < bits; ++i)
|
||||
{
|
||||
result[i * 2 + 0] = encodedLon[i];
|
||||
result[i * 2 + 1] = encodedLat[i];
|
||||
}
|
||||
// in case of even precision, add last bit of longitude
|
||||
if (precision & 0x1)
|
||||
{
|
||||
result[i * 2] = encodedLon[i];
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
inline std::tuple<Encoded, Encoded> split(const Encoded & combined, UInt8 precision)
|
||||
{
|
||||
Encoded lat, lon;
|
||||
lat.fill(0);
|
||||
lon.fill(0);
|
||||
|
||||
UInt8 i = 0;
|
||||
for (; i < precision * BITS_PER_SYMBOL - 1; i += 2)
|
||||
{
|
||||
// longitude is even bits
|
||||
lon[i/2] = combined[i];
|
||||
lat[i/2] = combined[i + 1];
|
||||
}
|
||||
// precision is even, read the last bit as lat.
|
||||
if (precision & 0x1)
|
||||
{
|
||||
lon[i/2] = combined[precision * BITS_PER_SYMBOL - 1];
|
||||
}
|
||||
|
||||
return std::tie(lon, lat);
|
||||
}
|
||||
|
||||
inline void base32Encode(const Encoded & binary, UInt8 precision, char * out)
|
||||
{
|
||||
extern const char geohash_base32_encode_lookup_table[32];
|
||||
|
||||
for (UInt8 i = 0; i < precision * BITS_PER_SYMBOL; i += 5)
|
||||
{
|
||||
UInt8 v = binary[i];
|
||||
v <<= 1;
|
||||
v |= binary[i + 1];
|
||||
v <<= 1;
|
||||
v |= binary[i + 2];
|
||||
v <<= 1;
|
||||
v |= binary[i + 3];
|
||||
v <<= 1;
|
||||
v |= binary[i + 4];
|
||||
|
||||
assert(v < 32);
|
||||
|
||||
*out = geohash_base32_encode_lookup_table[v];
|
||||
++out;
|
||||
}
|
||||
}
|
||||
|
||||
inline Encoded base32Decode(const char * encoded_string, size_t encoded_length)
|
||||
{
|
||||
extern const UInt8 geohash_base32_decode_lookup_table[256];
|
||||
|
||||
Encoded result;
|
||||
|
||||
for (size_t i = 0; i < encoded_length; ++i)
|
||||
{
|
||||
const UInt8 c = static_cast<UInt8>(encoded_string[i]);
|
||||
const UInt8 decoded = geohash_base32_decode_lookup_table[c] & 0x1F;
|
||||
result[i * 5 + 4] = (decoded >> 0) & 0x01;
|
||||
result[i * 5 + 3] = (decoded >> 1) & 0x01;
|
||||
result[i * 5 + 2] = (decoded >> 2) & 0x01;
|
||||
result[i * 5 + 1] = (decoded >> 3) & 0x01;
|
||||
result[i * 5 + 0] = (decoded >> 4) & 0x01;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace GeoUtils
|
||||
{
|
||||
|
||||
size_t geohashEncode(Float64 longitude, Float64 latitude, UInt8 precision, char *& out)
|
||||
{
|
||||
if (precision == 0 || precision > MAX_PRECISION)
|
||||
{
|
||||
precision = MAX_PRECISION;
|
||||
}
|
||||
|
||||
const Encoded combined = merge(
|
||||
encodeCoordinate(longitude, -180, 180, singleCoordBitsPrecision(precision, LONGITUDE)),
|
||||
encodeCoordinate(latitude, -90, 90, singleCoordBitsPrecision(precision, LATITUDE)),
|
||||
precision);
|
||||
|
||||
base32Encode(combined, precision, out);
|
||||
|
||||
return precision;
|
||||
}
|
||||
|
||||
void geohashDecode(const char * encoded_string, size_t encoded_len, Float64 * longitude, Float64 * latitude)
|
||||
{
|
||||
const UInt8 precision = std::min(encoded_len, MAX_PRECISION);
|
||||
if (precision == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
Encoded lat_encoded, lon_encoded;
|
||||
std::tie(lon_encoded, lat_encoded) = split(base32Decode(encoded_string, precision), precision);
|
||||
|
||||
*longitude = decodeCoordinate(lon_encoded, -180, 180, singleCoordBitsPrecision(precision, LONGITUDE));
|
||||
*latitude = decodeCoordinate(lat_encoded, -90, 90, singleCoordBitsPrecision(precision, LATITUDE));
|
||||
}
|
||||
|
||||
} // namespace GeoUtils
|
||||
|
||||
} // namespace DB
|
@ -699,6 +699,10 @@ std::string serialize(Polygon && polygon)
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t geohashEncode(Float64 longitude, Float64 latitude, UInt8 precision, char *& out);
|
||||
|
||||
void geohashDecode(const char * encoded_string, size_t encoded_len, Float64 * longitude, Float64 * latitude);
|
||||
|
||||
|
||||
} /// GeoUtils
|
||||
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <Common/ObjectPool.h>
|
||||
#include <Common/OptimizedRegularExpression.h>
|
||||
#include <Common/ProfileEvents.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <common/StringRef.h>
|
||||
|
||||
|
||||
@ -135,9 +136,10 @@ namespace MultiRegexps
|
||||
for (const StringRef ref : str_patterns)
|
||||
{
|
||||
ptrns.push_back(ref.data);
|
||||
flags.push_back(HS_FLAG_DOTALL | HS_FLAG_ALLOWEMPTY | HS_FLAG_SINGLEMATCH);
|
||||
flags.push_back(HS_FLAG_DOTALL | HS_FLAG_ALLOWEMPTY | HS_FLAG_SINGLEMATCH | HS_FLAG_UTF8);
|
||||
if constexpr (CompileForEditDistance)
|
||||
{
|
||||
flags.back() &= ~HS_FLAG_UTF8;
|
||||
ext_exprs.emplace_back();
|
||||
ext_exprs.back().flags = HS_EXT_FLAG_EDIT_DISTANCE;
|
||||
ext_exprs.back().edit_distance = edit_distance.value();
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user