mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-13 18:02:24 +00:00
Merge branch 'master' into database_atomic
This commit is contained in:
commit
4d23c5e4d4
@ -198,11 +198,11 @@ if(WITH_COVERAGE AND COMPILER_GCC)
|
||||
endif()
|
||||
|
||||
set (CMAKE_BUILD_COLOR_MAKEFILE ON)
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${COMPILER_FLAGS} ${PLATFORM_EXTRA_CXX_FLAG} -fno-omit-frame-pointer ${COMMON_WARNING_FLAGS} ${CXX_WARNING_FLAGS}")
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${COMPILER_FLAGS} ${PLATFORM_EXTRA_CXX_FLAG} ${COMMON_WARNING_FLAGS} ${CXX_WARNING_FLAGS}")
|
||||
set (CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -O3 ${CMAKE_CXX_FLAGS_ADD}")
|
||||
set (CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb3 -fno-inline ${CMAKE_CXX_FLAGS_ADD}")
|
||||
|
||||
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${COMPILER_FLAGS} -fno-omit-frame-pointer ${COMMON_WARNING_FLAGS} ${CMAKE_C_FLAGS_ADD}")
|
||||
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${COMPILER_FLAGS} ${COMMON_WARNING_FLAGS} ${CMAKE_C_FLAGS_ADD}")
|
||||
set (CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO} -O3 ${CMAKE_C_FLAGS_ADD}")
|
||||
set (CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O0 -g3 -ggdb3 -fno-inline ${CMAKE_C_FLAGS_ADD}")
|
||||
|
||||
@ -382,16 +382,12 @@ add_subdirectory (contrib EXCLUDE_FROM_ALL)
|
||||
|
||||
macro (add_executable target)
|
||||
# invoke built-in add_executable
|
||||
_add_executable (${ARGV})
|
||||
# explicitly acquire and interpose malloc symbols by clickhouse_malloc
|
||||
_add_executable (${ARGV} $<TARGET_OBJECTS:clickhouse_malloc>)
|
||||
get_target_property (type ${target} TYPE)
|
||||
if (${type} STREQUAL EXECUTABLE)
|
||||
file (RELATIVE_PATH dir ${CMAKE_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
if (${dir} MATCHES "^dbms")
|
||||
# Only interpose operator::new/delete for dbms executables (MemoryTracker stuff)
|
||||
target_link_libraries (${target} PRIVATE clickhouse_new_delete ${MALLOC_LIBRARIES})
|
||||
else ()
|
||||
target_link_libraries (${target} PRIVATE ${MALLOC_LIBRARIES})
|
||||
endif ()
|
||||
# operator::new/delete for executables (MemoryTracker stuff)
|
||||
target_link_libraries (${target} PRIVATE clickhouse_new_delete ${MALLOC_LIBRARIES})
|
||||
endif()
|
||||
endmacro()
|
||||
|
||||
|
@ -14,5 +14,4 @@ ClickHouse is an open-source column-oriented database management system that all
|
||||
|
||||
## Upcoming Events
|
||||
|
||||
* [ClickHouse Meetup in San Francisco](https://www.eventbrite.com/e/clickhouse-december-meetup-registration-78642047481) on December 3.
|
||||
* [ClickHouse Meetup in Moscow](https://yandex.ru/promo/clickhouse/moscow-december-2019) on December 11.
|
||||
|
@ -1,5 +1,8 @@
|
||||
find_library (TERMCAP_LIBRARY termcap)
|
||||
find_library (TERMCAP_LIBRARY tinfo)
|
||||
if (NOT TERMCAP_LIBRARY)
|
||||
find_library (TERMCAP_LIBRARY tinfo)
|
||||
find_library (TERMCAP_LIBRARY ncurses)
|
||||
endif()
|
||||
if (NOT TERMCAP_LIBRARY)
|
||||
find_library (TERMCAP_LIBRARY termcap)
|
||||
endif()
|
||||
message (STATUS "Using termcap: ${TERMCAP_LIBRARY}")
|
||||
|
@ -20,16 +20,38 @@ else ()
|
||||
message (WARNING "You are using an unsupported compiler. Compilation has only been tested with Clang 6+ and GCC 7+.")
|
||||
endif ()
|
||||
|
||||
STRING(REGEX MATCHALL "[0-9]+" COMPILER_VERSION_LIST ${CMAKE_CXX_COMPILER_VERSION})
|
||||
LIST(GET COMPILER_VERSION_LIST 0 COMPILER_VERSION_MAJOR)
|
||||
|
||||
option (LINKER_NAME "Linker name or full path")
|
||||
if (COMPILER_GCC)
|
||||
find_program (LLD_PATH NAMES "ld.lld")
|
||||
find_program (GOLD_PATH NAMES "ld.gold")
|
||||
else ()
|
||||
find_program (LLD_PATH NAMES "ld.lld-${COMPILER_VERSION_MAJOR}" "lld-${COMPILER_VERSION_MAJOR}" "ld.lld" "lld")
|
||||
find_program (GOLD_PATH NAMES "ld.gold" "gold")
|
||||
endif ()
|
||||
|
||||
find_program (LLD_PATH NAMES "ld.lld" "lld")
|
||||
find_program (GOLD_PATH NAMES "ld.gold" "gold")
|
||||
|
||||
# We prefer LLD linker over Gold or BFD.
|
||||
if (NOT LINKER_NAME)
|
||||
if (LLD_PATH)
|
||||
set (LINKER_NAME "lld")
|
||||
elseif (GOLD_PATH)
|
||||
set (LINKER_NAME "gold")
|
||||
if (COMPILER_GCC)
|
||||
# GCC driver requires one of supported linker names like "lld".
|
||||
set (LINKER_NAME "lld")
|
||||
else ()
|
||||
# Clang driver simply allows full linker path.
|
||||
set (LINKER_NAME ${LLD_PATH})
|
||||
endif ()
|
||||
endif ()
|
||||
endif ()
|
||||
|
||||
if (NOT LINKER_NAME)
|
||||
if (GOLD_PATH)
|
||||
if (COMPILER_GCC)
|
||||
set (LINKER_NAME "gold")
|
||||
else ()
|
||||
set (LINKER_NAME ${GOLD_PATH})
|
||||
endif ()
|
||||
endif ()
|
||||
endif ()
|
||||
|
||||
|
1
contrib/CMakeLists.txt
vendored
1
contrib/CMakeLists.txt
vendored
@ -52,6 +52,7 @@ if (USE_INTERNAL_BTRIE_LIBRARY)
|
||||
endif ()
|
||||
|
||||
if (USE_INTERNAL_ZLIB_LIBRARY)
|
||||
unset (BUILD_SHARED_LIBS CACHE)
|
||||
set (ZLIB_ENABLE_TESTS 0 CACHE INTERNAL "")
|
||||
set (SKIP_INSTALL_ALL 1 CACHE INTERNAL "")
|
||||
set (ZLIB_COMPAT 1 CACHE INTERNAL "") # also enables WITH_GZFILEOP
|
||||
|
2
contrib/poco
vendored
2
contrib/poco
vendored
@ -1 +1 @@
|
||||
Subproject commit 2b273bfe9db89429b2040c024484dee0197e48c7
|
||||
Subproject commit d478f62bd93c9cd14eb343756ef73a4ae622ddf5
|
2
contrib/zlib-ng
vendored
2
contrib/zlib-ng
vendored
@ -1 +1 @@
|
||||
Subproject commit cff0f500d9399d7cd3b9461a693d211e4b86fcc9
|
||||
Subproject commit bba56a73be249514acfbc7d49aa2a68994dad8ab
|
@ -100,7 +100,7 @@ set(dbms_sources)
|
||||
add_headers_and_sources(clickhouse_common_io src/Common)
|
||||
add_headers_and_sources(clickhouse_common_io src/Common/HashTable)
|
||||
add_headers_and_sources(clickhouse_common_io src/IO)
|
||||
list (REMOVE_ITEM clickhouse_common_io_sources src/Common/new_delete.cpp)
|
||||
list (REMOVE_ITEM clickhouse_common_io_sources src/Common/malloc.cpp src/Common/new_delete.cpp)
|
||||
|
||||
if(USE_RDKAFKA)
|
||||
add_headers_and_sources(dbms src/Storages/Kafka)
|
||||
@ -140,6 +140,9 @@ endif ()
|
||||
|
||||
add_library(clickhouse_common_io ${clickhouse_common_io_headers} ${clickhouse_common_io_sources})
|
||||
|
||||
add_library (clickhouse_malloc OBJECT src/Common/malloc.cpp)
|
||||
set_source_files_properties(src/Common/malloc.cpp PROPERTIES COMPILE_FLAGS "-fno-builtin")
|
||||
|
||||
add_library (clickhouse_new_delete STATIC src/Common/new_delete.cpp)
|
||||
target_link_libraries (clickhouse_new_delete PRIVATE clickhouse_common_io)
|
||||
|
||||
@ -376,6 +379,10 @@ if (USE_POCO_MONGODB)
|
||||
dbms_target_link_libraries (PRIVATE ${Poco_MongoDB_LIBRARY})
|
||||
endif()
|
||||
|
||||
if (USE_POCO_REDIS)
|
||||
dbms_target_link_libraries (PRIVATE ${Poco_Redis_LIBRARY})
|
||||
endif()
|
||||
|
||||
if (USE_POCO_NETSSL)
|
||||
target_link_libraries (clickhouse_common_io PRIVATE ${Poco_NetSSL_LIBRARY} ${Poco_Crypto_LIBRARY})
|
||||
dbms_target_link_libraries (PRIVATE ${Poco_NetSSL_LIBRARY} ${Poco_Crypto_LIBRARY})
|
||||
@ -428,6 +435,8 @@ if (USE_JEMALLOC)
|
||||
|
||||
if(NOT MAKE_STATIC_LIBRARIES AND ${JEMALLOC_LIBRARIES} MATCHES "${CMAKE_STATIC_LIBRARY_SUFFIX}$")
|
||||
# mallctl in dbms/src/Interpreters/AsynchronousMetrics.cpp
|
||||
# Actually we link JEMALLOC to almost all libraries.
|
||||
# This is just hotfix for some uninvestigated problem.
|
||||
target_link_libraries(clickhouse_interpreters PRIVATE ${JEMALLOC_LIBRARIES})
|
||||
endif()
|
||||
endif ()
|
||||
|
@ -1,11 +1,11 @@
|
||||
# This strings autochanged from release_lib.sh:
|
||||
set(VERSION_REVISION 54429)
|
||||
set(VERSION_REVISION 54430)
|
||||
set(VERSION_MAJOR 19)
|
||||
set(VERSION_MINOR 18)
|
||||
set(VERSION_MINOR 19)
|
||||
set(VERSION_PATCH 1)
|
||||
set(VERSION_GITHASH 4e68211879480b637683ae66dbcc89a2714682af)
|
||||
set(VERSION_DESCRIBE v19.18.1.1-prestable)
|
||||
set(VERSION_STRING 19.18.1.1)
|
||||
set(VERSION_GITHASH 8bd9709d1dec3366e35d2efeab213435857f67a9)
|
||||
set(VERSION_DESCRIBE v19.19.1.1-prestable)
|
||||
set(VERSION_STRING 19.19.1.1)
|
||||
# end of autochange
|
||||
|
||||
set(VERSION_EXTRA "" CACHE STRING "")
|
||||
|
@ -34,7 +34,6 @@
|
||||
#include <IO/WriteBufferFromTemporaryFile.h>
|
||||
#include <DataStreams/IBlockInputStream.h>
|
||||
#include <Interpreters/executeQuery.h>
|
||||
#include <Interpreters/Quota.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Poco/Net/HTTPStream.h>
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBufferFromPocoSocket.h>
|
||||
#include <Storages/IStorage.h>
|
||||
#include <boost/algorithm/string/replace.hpp>
|
||||
|
||||
#if USE_POCO_NETSSL
|
||||
#include <Poco/Net/SecureStreamSocket.h>
|
||||
@ -220,7 +221,8 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl
|
||||
{
|
||||
// For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used.
|
||||
auto user = connection_context.getUser(user_name);
|
||||
if (user->authentication.getType() != DB::Authentication::DOUBLE_SHA1_PASSWORD)
|
||||
const DB::Authentication::Type user_auth_type = user->authentication.getType();
|
||||
if (user_auth_type != DB::Authentication::DOUBLE_SHA1_PASSWORD && user_auth_type != DB::Authentication::PLAINTEXT_PASSWORD && user_auth_type != DB::Authentication::NO_PASSWORD)
|
||||
{
|
||||
authPluginSSL();
|
||||
}
|
||||
@ -267,29 +269,49 @@ void MySQLHandler::comPing()
|
||||
packet_sender->sendPacket(OK_Packet(0x0, client_capability_flags, 0, 0, 0), true);
|
||||
}
|
||||
|
||||
static bool isFederatedServerSetupCommand(const String & query);
|
||||
|
||||
void MySQLHandler::comQuery(ReadBuffer & payload)
|
||||
{
|
||||
bool with_output = false;
|
||||
std::function<void(const String &)> set_content_type = [&with_output](const String &) -> void {
|
||||
with_output = true;
|
||||
};
|
||||
String query = String(payload.position(), payload.buffer().end());
|
||||
|
||||
const String query("select ''");
|
||||
ReadBufferFromString empty_select(query);
|
||||
|
||||
bool should_replace = false;
|
||||
// Translate query from MySQL to ClickHouse.
|
||||
// This is a temporary workaround until ClickHouse supports the syntax "@@var_name".
|
||||
if (std::string(payload.position(), payload.buffer().end()) == "select @@version_comment limit 1") // MariaDB client starts session with that query
|
||||
// This is a workaround in order to support adding ClickHouse to MySQL using federated server.
|
||||
// As Clickhouse doesn't support these statements, we just send OK packet in response.
|
||||
if (isFederatedServerSetupCommand(query))
|
||||
{
|
||||
should_replace = true;
|
||||
}
|
||||
|
||||
Context query_context = connection_context;
|
||||
executeQuery(should_replace ? empty_select : payload, *out, true, query_context, set_content_type, nullptr);
|
||||
|
||||
if (!with_output)
|
||||
packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool with_output = false;
|
||||
std::function<void(const String &)> set_content_type = [&with_output](const String &) -> void {
|
||||
with_output = true;
|
||||
};
|
||||
|
||||
String replacement_query = "select ''";
|
||||
bool should_replace = false;
|
||||
|
||||
// Translate query from MySQL to ClickHouse.
|
||||
// This is a temporary workaround until ClickHouse supports the syntax "@@var_name".
|
||||
if (query == "select @@version_comment limit 1") // MariaDB client starts session with that query
|
||||
{
|
||||
should_replace = true;
|
||||
}
|
||||
// This is a workaround in order to support adding ClickHouse to MySQL using federated server.
|
||||
if (0 == strncasecmp("SHOW TABLE STATUS LIKE", query.c_str(), 22))
|
||||
{
|
||||
should_replace = true;
|
||||
replacement_query = boost::replace_all_copy(query, "SHOW TABLE STATUS LIKE ", show_table_status_replacement_query);
|
||||
}
|
||||
|
||||
ReadBufferFromString replacement(replacement_query);
|
||||
|
||||
Context query_context = connection_context;
|
||||
executeQuery(should_replace ? replacement : payload, *out, true, query_context, set_content_type, nullptr);
|
||||
|
||||
if (!with_output)
|
||||
packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true);
|
||||
}
|
||||
}
|
||||
|
||||
void MySQLHandler::authPluginSSL()
|
||||
@ -335,4 +357,33 @@ void MySQLHandlerSSL::finishHandshakeSSL(size_t packet_size, char * buf, size_t
|
||||
|
||||
#endif
|
||||
|
||||
static bool isFederatedServerSetupCommand(const String & query)
|
||||
{
|
||||
return 0 == strncasecmp("SET NAMES", query.c_str(), 9) || 0 == strncasecmp("SET character_set_results", query.c_str(), 25)
|
||||
|| 0 == strncasecmp("SET FOREIGN_KEY_CHECKS", query.c_str(), 22) || 0 == strncasecmp("SET AUTOCOMMIT", query.c_str(), 14)
|
||||
|| 0 == strncasecmp("SET SESSION TRANSACTION ISOLATION LEVEL", query.c_str(), 39);
|
||||
}
|
||||
|
||||
const String MySQLHandler::show_table_status_replacement_query("SELECT"
|
||||
" name AS Name,"
|
||||
" engine AS Engine,"
|
||||
" '10' AS Version,"
|
||||
" 'Dynamic' AS Row_format,"
|
||||
" 0 AS Rows,"
|
||||
" 0 AS Avg_row_length,"
|
||||
" 0 AS Data_length,"
|
||||
" 0 AS Max_data_length,"
|
||||
" 0 AS Index_length,"
|
||||
" 0 AS Data_free,"
|
||||
" 'NULL' AS Auto_increment,"
|
||||
" metadata_modification_time AS Create_time,"
|
||||
" metadata_modification_time AS Update_time,"
|
||||
" metadata_modification_time AS Check_time,"
|
||||
" 'utf8_bin' AS Collation,"
|
||||
" 'NULL' AS Checksum,"
|
||||
" '' AS Create_options,"
|
||||
" '' AS Comment"
|
||||
" FROM system.tables"
|
||||
" WHERE name LIKE ");
|
||||
|
||||
}
|
||||
|
@ -11,7 +11,6 @@
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Handler for MySQL wire protocol connections. Allows to connect to ClickHouse using MySQL client.
|
||||
class MySQLHandler : public Poco::Net::TCPServerConnection
|
||||
{
|
||||
@ -59,6 +58,9 @@ protected:
|
||||
std::shared_ptr<WriteBuffer> out;
|
||||
|
||||
bool secure_connection = false;
|
||||
|
||||
private:
|
||||
static const String show_table_status_replacement_query;
|
||||
};
|
||||
|
||||
#if USE_SSL && USE_POCO_NETSSL
|
||||
|
@ -243,6 +243,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
}
|
||||
#endif
|
||||
|
||||
global_context->setRemoteHostFilter(config());
|
||||
|
||||
std::string path = getCanonicalPath(config().getString("path", DBMS_DEFAULT_PATH));
|
||||
std::string default_database = config().getString("default_database", "default");
|
||||
|
||||
|
@ -19,7 +19,6 @@
|
||||
#include <DataStreams/NativeBlockInputStream.h>
|
||||
#include <DataStreams/NativeBlockOutputStream.h>
|
||||
#include <Interpreters/executeQuery.h>
|
||||
#include <Interpreters/Quota.h>
|
||||
#include <Interpreters/TablesStatus.h>
|
||||
#include <Interpreters/InternalTextLogsQueue.h>
|
||||
#include <Storages/StorageMemory.h>
|
||||
@ -201,6 +200,8 @@ void TCPHandler::runImpl()
|
||||
/// So, the stream has been marked as cancelled and we can't read from it anymore.
|
||||
state.block_in.reset();
|
||||
state.maybe_compressed_in.reset(); /// For more accurate accounting by MemoryTracker.
|
||||
|
||||
state.temporary_tables_read = true;
|
||||
});
|
||||
|
||||
/// Send structure of columns to client for function input()
|
||||
@ -340,6 +341,18 @@ void TCPHandler::runImpl()
|
||||
LOG_WARNING(log, "Client has gone away.");
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (exception && !state.temporary_tables_read)
|
||||
query_context->initializeExternalTablesIfSet();
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
network_error = true;
|
||||
LOG_WARNING(log, "Can't read external tables after query failure.");
|
||||
}
|
||||
|
||||
|
||||
try
|
||||
{
|
||||
query_scope.reset();
|
||||
|
@ -63,6 +63,8 @@ struct QueryState
|
||||
bool sent_all_data = false;
|
||||
/// Request requires data from the client (INSERT, but not INSERT SELECT).
|
||||
bool need_receive_data_for_insert = false;
|
||||
/// Temporary tables read
|
||||
bool temporary_tables_read = false;
|
||||
|
||||
/// Request requires data from client for function input()
|
||||
bool need_receive_data_for_input = false;
|
||||
|
@ -3,6 +3,25 @@
|
||||
NOTE: User and query level settings are set up in "users.xml" file.
|
||||
-->
|
||||
<yandex>
|
||||
<!-- The list of hosts allowed to use in URL-related storage engines and table functions.
|
||||
If this section is not present in configuration, all hosts are allowed.
|
||||
-->
|
||||
<remote_url_allow_hosts>
|
||||
<!-- Host should be specified exactly as in URL. The name is checked before DNS resolution.
|
||||
Example: "yandex.ru", "yandex.ru." and "www.yandex.ru" are different hosts.
|
||||
If port is explicitly specified in URL, the host:port is checked as a whole.
|
||||
If host specified here without port, any port with this host allowed.
|
||||
"yandex.ru" -> "yandex.ru:443", "yandex.ru:80" etc. is allowed, but "yandex.ru:80" -> only "yandex.ru:80" is allowed.
|
||||
If the host is specified as IP address, it is checked as specified in URL. Example: "[2a02:6b8:a::a]".
|
||||
If there are redirects and support for redirects is enabled, every redirect (the Location field) is checked.
|
||||
-->
|
||||
|
||||
<!-- Regular expression can be specified. RE2 engine is used for regexps.
|
||||
Regexps are not aligned: don't forget to add ^ and $. Also don't forget to escape dot (.) metacharacter
|
||||
(forgetting to do so is a common source of error).
|
||||
-->
|
||||
</remote_url_allow_hosts>
|
||||
|
||||
<logger>
|
||||
<!-- Possible levels: https://github.com/pocoproject/poco/blob/develop/Foundation/include/Poco/Logger.h#L105 -->
|
||||
<level>trace</level>
|
||||
@ -15,7 +34,6 @@
|
||||
<!--display_name>production</display_name--> <!-- It is the name that will be shown in the client -->
|
||||
<http_port>8123</http_port>
|
||||
<tcp_port>9000</tcp_port>
|
||||
|
||||
<!-- For HTTPS and SSL over native protocol. -->
|
||||
<!--
|
||||
<https_port>8443</https_port>
|
||||
|
52
dbms/src/Access/AccessControlManager.cpp
Normal file
52
dbms/src/Access/AccessControlManager.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
#include <Access/AccessControlManager.h>
|
||||
#include <Access/MultipleAccessStorage.h>
|
||||
#include <Access/MemoryAccessStorage.h>
|
||||
#include <Access/UsersConfigAccessStorage.h>
|
||||
#include <Access/QuotaContextFactory.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace
|
||||
{
|
||||
std::vector<std::unique_ptr<IAccessStorage>> createStorages()
|
||||
{
|
||||
std::vector<std::unique_ptr<IAccessStorage>> list;
|
||||
list.emplace_back(std::make_unique<MemoryAccessStorage>());
|
||||
list.emplace_back(std::make_unique<UsersConfigAccessStorage>());
|
||||
return list;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
AccessControlManager::AccessControlManager()
|
||||
: MultipleAccessStorage(createStorages()),
|
||||
quota_context_factory(std::make_unique<QuotaContextFactory>(*this))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
AccessControlManager::~AccessControlManager()
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguration & users_config)
|
||||
{
|
||||
auto & users_config_access_storage = dynamic_cast<UsersConfigAccessStorage &>(getStorageByIndex(1));
|
||||
users_config_access_storage.loadFromConfig(users_config);
|
||||
}
|
||||
|
||||
|
||||
std::shared_ptr<QuotaContext> AccessControlManager::createQuotaContext(
|
||||
const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key)
|
||||
{
|
||||
return quota_context_factory->createContext(user_name, address, custom_quota_key);
|
||||
}
|
||||
|
||||
|
||||
std::vector<QuotaUsageInfo> AccessControlManager::getQuotaUsageInfo() const
|
||||
{
|
||||
return quota_context_factory->getUsageInfo();
|
||||
}
|
||||
}
|
45
dbms/src/Access/AccessControlManager.h
Normal file
45
dbms/src/Access/AccessControlManager.h
Normal file
@ -0,0 +1,45 @@
|
||||
#pragma once
|
||||
|
||||
#include <Access/MultipleAccessStorage.h>
|
||||
#include <Poco/AutoPtr.h>
|
||||
#include <memory>
|
||||
|
||||
|
||||
namespace Poco
|
||||
{
|
||||
namespace Net
|
||||
{
|
||||
class IPAddress;
|
||||
}
|
||||
namespace Util
|
||||
{
|
||||
class AbstractConfiguration;
|
||||
}
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
class QuotaContext;
|
||||
class QuotaContextFactory;
|
||||
struct QuotaUsageInfo;
|
||||
|
||||
|
||||
/// Manages access control entities.
|
||||
class AccessControlManager : public MultipleAccessStorage
|
||||
{
|
||||
public:
|
||||
AccessControlManager();
|
||||
~AccessControlManager();
|
||||
|
||||
void loadFromConfig(const Poco::Util::AbstractConfiguration & users_config);
|
||||
|
||||
std::shared_ptr<QuotaContext>
|
||||
createQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key);
|
||||
|
||||
std::vector<QuotaUsageInfo> getQuotaUsageInfo() const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<QuotaContextFactory> quota_context_factory;
|
||||
};
|
||||
|
||||
}
|
@ -160,6 +160,35 @@ void Authentication::setPasswordHashBinary(const Digest & hash)
|
||||
}
|
||||
|
||||
|
||||
Digest Authentication::getPasswordDoubleSHA1() const
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case NO_PASSWORD:
|
||||
{
|
||||
Poco::SHA1Engine engine;
|
||||
return engine.digest();
|
||||
}
|
||||
|
||||
case PLAINTEXT_PASSWORD:
|
||||
{
|
||||
Poco::SHA1Engine engine;
|
||||
engine.update(getPassword());
|
||||
const Digest & first_sha1 = engine.digest();
|
||||
engine.update(first_sha1.data(), first_sha1.size());
|
||||
return engine.digest();
|
||||
}
|
||||
|
||||
case SHA256_PASSWORD:
|
||||
throw Exception("Cannot get password double SHA1 for user with 'SHA256_PASSWORD' authentication.", ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
case DOUBLE_SHA1_PASSWORD:
|
||||
return password_hash;
|
||||
}
|
||||
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
|
||||
bool Authentication::isCorrectPassword(const String & password_) const
|
||||
{
|
||||
switch (type)
|
||||
@ -168,7 +197,14 @@ bool Authentication::isCorrectPassword(const String & password_) const
|
||||
return true;
|
||||
|
||||
case PLAINTEXT_PASSWORD:
|
||||
return password_ == StringRef{reinterpret_cast<const char *>(password_hash.data()), password_hash.size()};
|
||||
{
|
||||
if (password_ == StringRef{reinterpret_cast<const char *>(password_hash.data()), password_hash.size()})
|
||||
return true;
|
||||
|
||||
// For compatibility with MySQL clients which support only native authentication plugin, SHA1 can be passed instead of password.
|
||||
auto password_sha1 = encodeSHA1(password_hash);
|
||||
return password_ == StringRef{reinterpret_cast<const char *>(password_sha1.data()), password_sha1.size()};
|
||||
}
|
||||
|
||||
case SHA256_PASSWORD:
|
||||
return encodeSHA256(password_) == password_hash;
|
||||
|
@ -49,6 +49,10 @@ public:
|
||||
void setPasswordHashBinary(const Digest & hash);
|
||||
const Digest & getPasswordHashBinary() const { return password_hash; }
|
||||
|
||||
/// Returns SHA1(SHA1(password)) used by MySQL compatibility server for authentication.
|
||||
/// Allowed to use for Type::NO_PASSWORD, Type::PLAINTEXT_PASSWORD, Type::DOUBLE_SHA1_PASSWORD.
|
||||
Digest getPasswordDoubleSHA1() const;
|
||||
|
||||
/// Checks if the provided password is correct. Returns false if not.
|
||||
bool isCorrectPassword(const String & password) const;
|
||||
|
||||
|
19
dbms/src/Access/IAccessEntity.cpp
Normal file
19
dbms/src/Access/IAccessEntity.cpp
Normal file
@ -0,0 +1,19 @@
|
||||
#include <Access/IAccessEntity.h>
|
||||
#include <Access/Quota.h>
|
||||
#include <common/demangle.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
String IAccessEntity::getTypeName(std::type_index type)
|
||||
{
|
||||
if (type == typeid(Quota))
|
||||
return "Quota";
|
||||
return demangle(type.name());
|
||||
}
|
||||
|
||||
bool IAccessEntity::equal(const IAccessEntity & other) const
|
||||
{
|
||||
return (full_name == other.full_name) && (getType() == other.getType());
|
||||
}
|
||||
}
|
49
dbms/src/Access/IAccessEntity.h
Normal file
49
dbms/src/Access/IAccessEntity.h
Normal file
@ -0,0 +1,49 @@
|
||||
#pragma once
|
||||
|
||||
#include <Core/Types.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <memory>
|
||||
#include <typeindex>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
/// Access entity is a set of data which have a name and a type. Access entity control something related to the access control.
|
||||
/// Entities can be stored to a file or another storage, see IAccessStorage.
|
||||
struct IAccessEntity
|
||||
{
|
||||
IAccessEntity() = default;
|
||||
IAccessEntity(const IAccessEntity &) = default;
|
||||
virtual ~IAccessEntity() = default;
|
||||
virtual std::shared_ptr<IAccessEntity> clone() const = 0;
|
||||
|
||||
std::type_index getType() const { return typeid(*this); }
|
||||
static String getTypeName(std::type_index type);
|
||||
const String getTypeName() const { return getTypeName(getType()); }
|
||||
|
||||
template <typename EntityType>
|
||||
bool isTypeOf() const { return isTypeOf(typeid(EntityType)); }
|
||||
bool isTypeOf(std::type_index type) const { return type == getType(); }
|
||||
|
||||
virtual void setName(const String & name_) { full_name = name_; }
|
||||
virtual String getName() const { return full_name; }
|
||||
String getFullName() const { return full_name; }
|
||||
|
||||
friend bool operator ==(const IAccessEntity & lhs, const IAccessEntity & rhs) { return lhs.equal(rhs); }
|
||||
friend bool operator !=(const IAccessEntity & lhs, const IAccessEntity & rhs) { return !(lhs == rhs); }
|
||||
|
||||
protected:
|
||||
String full_name;
|
||||
|
||||
virtual bool equal(const IAccessEntity & other) const;
|
||||
|
||||
/// Helper function to define clone() in the derived classes.
|
||||
template <typename EntityType>
|
||||
std::shared_ptr<IAccessEntity> cloneImpl() const
|
||||
{
|
||||
return std::make_shared<EntityType>(typeid_cast<const EntityType &>(*this));
|
||||
}
|
||||
};
|
||||
|
||||
using AccessEntityPtr = std::shared_ptr<const IAccessEntity>;
|
||||
}
|
450
dbms/src/Access/IAccessStorage.cpp
Normal file
450
dbms/src/Access/IAccessStorage.cpp
Normal file
@ -0,0 +1,450 @@
|
||||
#include <Access/IAccessStorage.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/quoteString.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Poco/UUIDGenerator.h>
|
||||
#include <Poco/Logger.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_CAST;
|
||||
extern const int ACCESS_ENTITY_NOT_FOUND;
|
||||
extern const int ACCESS_ENTITY_ALREADY_EXISTS;
|
||||
extern const int ACCESS_ENTITY_FOUND_DUPLICATES;
|
||||
extern const int ACCESS_ENTITY_STORAGE_READONLY;
|
||||
}
|
||||
|
||||
|
||||
std::vector<UUID> IAccessStorage::findAll(std::type_index type) const
|
||||
{
|
||||
return findAllImpl(type);
|
||||
}
|
||||
|
||||
|
||||
std::optional<UUID> IAccessStorage::find(std::type_index type, const String & name) const
|
||||
{
|
||||
return findImpl(type, name);
|
||||
}
|
||||
|
||||
|
||||
std::vector<UUID> IAccessStorage::find(std::type_index type, const Strings & names) const
|
||||
{
|
||||
std::vector<UUID> ids;
|
||||
ids.reserve(names.size());
|
||||
for (const String & name : names)
|
||||
{
|
||||
auto id = findImpl(type, name);
|
||||
if (id)
|
||||
ids.push_back(*id);
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
|
||||
UUID IAccessStorage::getID(std::type_index type, const String & name) const
|
||||
{
|
||||
auto id = findImpl(type, name);
|
||||
if (id)
|
||||
return *id;
|
||||
throwNotFound(type, name);
|
||||
}
|
||||
|
||||
|
||||
std::vector<UUID> IAccessStorage::getIDs(std::type_index type, const Strings & names) const
|
||||
{
|
||||
std::vector<UUID> ids;
|
||||
ids.reserve(names.size());
|
||||
for (const String & name : names)
|
||||
ids.push_back(getID(type, name));
|
||||
return ids;
|
||||
}
|
||||
|
||||
|
||||
bool IAccessStorage::exists(const UUID & id) const
|
||||
{
|
||||
return existsImpl(id);
|
||||
}
|
||||
|
||||
|
||||
|
||||
AccessEntityPtr IAccessStorage::tryReadBase(const UUID & id) const
|
||||
{
|
||||
try
|
||||
{
|
||||
return readImpl(id);
|
||||
}
|
||||
catch (Exception &)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
String IAccessStorage::readName(const UUID & id) const
|
||||
{
|
||||
return readNameImpl(id);
|
||||
}
|
||||
|
||||
|
||||
std::optional<String> IAccessStorage::tryReadName(const UUID & id) const
|
||||
{
|
||||
try
|
||||
{
|
||||
return readNameImpl(id);
|
||||
}
|
||||
catch (Exception &)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
UUID IAccessStorage::insert(const AccessEntityPtr & entity)
|
||||
{
|
||||
return insertImpl(entity, false);
|
||||
}
|
||||
|
||||
|
||||
std::vector<UUID> IAccessStorage::insert(const std::vector<AccessEntityPtr> & multiple_entities)
|
||||
{
|
||||
std::vector<UUID> ids;
|
||||
ids.reserve(multiple_entities.size());
|
||||
String error_message;
|
||||
for (const auto & entity : multiple_entities)
|
||||
{
|
||||
try
|
||||
{
|
||||
ids.push_back(insertImpl(entity, false));
|
||||
}
|
||||
catch (Exception & e)
|
||||
{
|
||||
if (e.code() != ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS)
|
||||
throw;
|
||||
error_message += (error_message.empty() ? "" : ". ") + e.message();
|
||||
}
|
||||
}
|
||||
if (!error_message.empty())
|
||||
throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
|
||||
return ids;
|
||||
}
|
||||
|
||||
|
||||
std::optional<UUID> IAccessStorage::tryInsert(const AccessEntityPtr & entity)
|
||||
{
|
||||
try
|
||||
{
|
||||
return insertImpl(entity, false);
|
||||
}
|
||||
catch (Exception &)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::vector<UUID> IAccessStorage::tryInsert(const std::vector<AccessEntityPtr> & multiple_entities)
|
||||
{
|
||||
std::vector<UUID> ids;
|
||||
ids.reserve(multiple_entities.size());
|
||||
for (const auto & entity : multiple_entities)
|
||||
{
|
||||
try
|
||||
{
|
||||
ids.push_back(insertImpl(entity, false));
|
||||
}
|
||||
catch (Exception &)
|
||||
{
|
||||
}
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
|
||||
UUID IAccessStorage::insertOrReplace(const AccessEntityPtr & entity)
|
||||
{
|
||||
return insertImpl(entity, true);
|
||||
}
|
||||
|
||||
|
||||
std::vector<UUID> IAccessStorage::insertOrReplace(const std::vector<AccessEntityPtr> & multiple_entities)
|
||||
{
|
||||
std::vector<UUID> ids;
|
||||
ids.reserve(multiple_entities.size());
|
||||
for (const auto & entity : multiple_entities)
|
||||
ids.push_back(insertImpl(entity, true));
|
||||
return ids;
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::remove(const UUID & id)
|
||||
{
|
||||
removeImpl(id);
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::remove(const std::vector<UUID> & ids)
|
||||
{
|
||||
String error_message;
|
||||
for (const auto & id : ids)
|
||||
{
|
||||
try
|
||||
{
|
||||
removeImpl(id);
|
||||
}
|
||||
catch (Exception & e)
|
||||
{
|
||||
if (e.code() != ErrorCodes::ACCESS_ENTITY_NOT_FOUND)
|
||||
throw;
|
||||
error_message += (error_message.empty() ? "" : ". ") + e.message();
|
||||
}
|
||||
}
|
||||
if (!error_message.empty())
|
||||
throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
|
||||
}
|
||||
|
||||
|
||||
bool IAccessStorage::tryRemove(const UUID & id)
|
||||
{
|
||||
try
|
||||
{
|
||||
removeImpl(id);
|
||||
return true;
|
||||
}
|
||||
catch (Exception &)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::vector<UUID> IAccessStorage::tryRemove(const std::vector<UUID> & ids)
|
||||
{
|
||||
std::vector<UUID> removed;
|
||||
removed.reserve(ids.size());
|
||||
for (const auto & id : ids)
|
||||
{
|
||||
try
|
||||
{
|
||||
removeImpl(id);
|
||||
removed.push_back(id);
|
||||
}
|
||||
catch (Exception &)
|
||||
{
|
||||
}
|
||||
}
|
||||
return removed;
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::update(const UUID & id, const UpdateFunc & update_func)
|
||||
{
|
||||
updateImpl(id, update_func);
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::update(const std::vector<UUID> & ids, const UpdateFunc & update_func)
|
||||
{
|
||||
String error_message;
|
||||
for (const auto & id : ids)
|
||||
{
|
||||
try
|
||||
{
|
||||
updateImpl(id, update_func);
|
||||
}
|
||||
catch (Exception & e)
|
||||
{
|
||||
if (e.code() != ErrorCodes::ACCESS_ENTITY_NOT_FOUND)
|
||||
throw;
|
||||
error_message += (error_message.empty() ? "" : ". ") + e.message();
|
||||
}
|
||||
}
|
||||
if (!error_message.empty())
|
||||
throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
|
||||
}
|
||||
|
||||
|
||||
bool IAccessStorage::tryUpdate(const UUID & id, const UpdateFunc & update_func)
|
||||
{
|
||||
try
|
||||
{
|
||||
updateImpl(id, update_func);
|
||||
return true;
|
||||
}
|
||||
catch (Exception &)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::vector<UUID> IAccessStorage::tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func)
|
||||
{
|
||||
std::vector<UUID> updated;
|
||||
updated.reserve(ids.size());
|
||||
for (const auto & id : ids)
|
||||
{
|
||||
try
|
||||
{
|
||||
updateImpl(id, update_func);
|
||||
updated.push_back(id);
|
||||
}
|
||||
catch (Exception &)
|
||||
{
|
||||
}
|
||||
}
|
||||
return updated;
|
||||
}
|
||||
|
||||
|
||||
IAccessStorage::SubscriptionPtr IAccessStorage::subscribeForChanges(std::type_index type, const OnChangedHandler & handler) const
|
||||
{
|
||||
return subscribeForChangesImpl(type, handler);
|
||||
}
|
||||
|
||||
|
||||
IAccessStorage::SubscriptionPtr IAccessStorage::subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const
|
||||
{
|
||||
return subscribeForChangesImpl(id, handler);
|
||||
}
|
||||
|
||||
|
||||
IAccessStorage::SubscriptionPtr IAccessStorage::subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const
|
||||
{
|
||||
if (ids.empty())
|
||||
return nullptr;
|
||||
if (ids.size() == 1)
|
||||
return subscribeForChangesImpl(ids[0], handler);
|
||||
|
||||
std::vector<SubscriptionPtr> subscriptions;
|
||||
subscriptions.reserve(ids.size());
|
||||
for (const auto & id : ids)
|
||||
{
|
||||
auto subscription = subscribeForChangesImpl(id, handler);
|
||||
if (subscription)
|
||||
subscriptions.push_back(std::move(subscription));
|
||||
}
|
||||
|
||||
class SubscriptionImpl : public Subscription
|
||||
{
|
||||
public:
|
||||
SubscriptionImpl(std::vector<SubscriptionPtr> subscriptions_)
|
||||
: subscriptions(std::move(subscriptions_)) {}
|
||||
private:
|
||||
std::vector<SubscriptionPtr> subscriptions;
|
||||
};
|
||||
|
||||
return std::make_unique<SubscriptionImpl>(std::move(subscriptions));
|
||||
}
|
||||
|
||||
|
||||
bool IAccessStorage::hasSubscription(std::type_index type) const
|
||||
{
|
||||
return hasSubscriptionImpl(type);
|
||||
}
|
||||
|
||||
|
||||
bool IAccessStorage::hasSubscription(const UUID & id) const
|
||||
{
|
||||
return hasSubscriptionImpl(id);
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::notify(const Notifications & notifications)
|
||||
{
|
||||
for (const auto & [fn, id, new_entity] : notifications)
|
||||
fn(id, new_entity);
|
||||
}
|
||||
|
||||
|
||||
UUID IAccessStorage::generateRandomID()
|
||||
{
|
||||
static Poco::UUIDGenerator generator;
|
||||
UUID id;
|
||||
generator.createRandom().copyTo(reinterpret_cast<char *>(&id));
|
||||
return id;
|
||||
}
|
||||
|
||||
|
||||
Poco::Logger * IAccessStorage::getLogger() const
|
||||
{
|
||||
Poco::Logger * ptr = log.load();
|
||||
if (!ptr)
|
||||
log.store(ptr = &Poco::Logger::get("Access(" + storage_name + ")"), std::memory_order_relaxed);
|
||||
return ptr;
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::throwNotFound(const UUID & id) const
|
||||
{
|
||||
throw Exception("ID {" + toString(id) + "} not found in " + getStorageName(), ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::throwNotFound(std::type_index type, const String & name) const
|
||||
{
|
||||
throw Exception(
|
||||
getTypeName(type) + " " + backQuote(name) + " not found in " + getStorageName(), ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::throwBadCast(const UUID & id, std::type_index type, const String & name, std::type_index required_type) const
|
||||
{
|
||||
throw Exception(
|
||||
"ID {" + toString(id) + "}: " + getTypeName(type) + backQuote(name) + " expected to be of type " + getTypeName(required_type),
|
||||
ErrorCodes::BAD_CAST);
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::throwIDCollisionCannotInsert(const UUID & id, std::type_index type, const String & name, std::type_index existing_type, const String & existing_name) const
|
||||
{
|
||||
throw Exception(
|
||||
getTypeName(type) + " " + backQuote(name) + ": cannot insert because the ID {" + toString(id) + "} is already used by "
|
||||
+ getTypeName(existing_type) + " " + backQuote(existing_name) + " in " + getStorageName(),
|
||||
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::throwNameCollisionCannotInsert(std::type_index type, const String & name) const
|
||||
{
|
||||
throw Exception(
|
||||
getTypeName(type) + " " + backQuote(name) + ": cannot insert because " + getTypeName(type) + " " + backQuote(name)
|
||||
+ " already exists in " + getStorageName(),
|
||||
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::throwNameCollisionCannotRename(std::type_index type, const String & old_name, const String & new_name) const
|
||||
{
|
||||
throw Exception(
|
||||
getTypeName(type) + " " + backQuote(old_name) + ": cannot rename to " + backQuote(new_name) + " because " + getTypeName(type) + " "
|
||||
+ backQuote(new_name) + " already exists in " + getStorageName(),
|
||||
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::throwReadonlyCannotInsert(std::type_index type, const String & name) const
|
||||
{
|
||||
throw Exception(
|
||||
"Cannot insert " + getTypeName(type) + " " + backQuote(name) + " to " + getStorageName() + " because this storage is readonly",
|
||||
ErrorCodes::ACCESS_ENTITY_STORAGE_READONLY);
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::throwReadonlyCannotUpdate(std::type_index type, const String & name) const
|
||||
{
|
||||
throw Exception(
|
||||
"Cannot update " + getTypeName(type) + " " + backQuote(name) + " in " + getStorageName() + " because this storage is readonly",
|
||||
ErrorCodes::ACCESS_ENTITY_STORAGE_READONLY);
|
||||
}
|
||||
|
||||
|
||||
void IAccessStorage::throwReadonlyCannotRemove(std::type_index type, const String & name) const
|
||||
{
|
||||
throw Exception(
|
||||
"Cannot remove " + getTypeName(type) + " " + backQuote(name) + " from " + getStorageName() + " because this storage is readonly",
|
||||
ErrorCodes::ACCESS_ENTITY_STORAGE_READONLY);
|
||||
}
|
||||
}
|
209
dbms/src/Access/IAccessStorage.h
Normal file
209
dbms/src/Access/IAccessStorage.h
Normal file
@ -0,0 +1,209 @@
|
||||
#pragma once
|
||||
|
||||
#include <Access/IAccessEntity.h>
|
||||
#include <Core/Types.h>
|
||||
#include <Core/UUID.h>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
|
||||
|
||||
namespace Poco { class Logger; }
|
||||
|
||||
namespace DB
|
||||
{
|
||||
/// Contains entities, i.e. instances of classes derived from IAccessEntity.
|
||||
/// The implementations of this class MUST be thread-safe.
|
||||
class IAccessStorage
|
||||
{
|
||||
public:
|
||||
IAccessStorage(const String & storage_name_) : storage_name(storage_name_) {}
|
||||
virtual ~IAccessStorage() {}
|
||||
|
||||
/// Returns the name of this storage.
|
||||
const String & getStorageName() const { return storage_name; }
|
||||
|
||||
/// Returns the identifiers of all the entities of a specified type contained in the storage.
|
||||
std::vector<UUID> findAll(std::type_index type) const;
|
||||
|
||||
template <typename EntityType>
|
||||
std::vector<UUID> findAll() const { return findAll(typeid(EntityType)); }
|
||||
|
||||
/// Searchs for an entity with specified type and name. Returns std::nullopt if not found.
|
||||
std::optional<UUID> find(std::type_index type, const String & name) const;
|
||||
|
||||
template <typename EntityType>
|
||||
std::optional<UUID> find(const String & name) const { return find(typeid(EntityType), name); }
|
||||
|
||||
std::vector<UUID> find(std::type_index type, const Strings & names) const;
|
||||
|
||||
template <typename EntityType>
|
||||
std::vector<UUID> find(const Strings & names) const { return find(typeid(EntityType), names); }
|
||||
|
||||
/// Searchs for an entity with specified name and type. Throws an exception if not found.
|
||||
UUID getID(std::type_index type, const String & name) const;
|
||||
|
||||
template <typename EntityType>
|
||||
UUID getID(const String & name) const { return getID(typeid(EntityType), name); }
|
||||
|
||||
std::vector<UUID> getIDs(std::type_index type, const Strings & names) const;
|
||||
|
||||
template <typename EntityType>
|
||||
std::vector<UUID> getIDs(const Strings & names) const { return getIDs(typeid(EntityType), names); }
|
||||
|
||||
/// Returns whether there is an entity with such identifier in the storage.
|
||||
bool exists(const UUID & id) const;
|
||||
|
||||
/// Reads an entity. Throws an exception if not found.
|
||||
template <typename EntityType = IAccessEntity>
|
||||
std::shared_ptr<const EntityType> read(const UUID & id) const;
|
||||
|
||||
template <typename EntityType = IAccessEntity>
|
||||
std::shared_ptr<const EntityType> read(const String & name) const;
|
||||
|
||||
/// Reads an entity. Returns nullptr if not found.
|
||||
template <typename EntityType = IAccessEntity>
|
||||
std::shared_ptr<const EntityType> tryRead(const UUID & id) const;
|
||||
|
||||
template <typename EntityType = IAccessEntity>
|
||||
std::shared_ptr<const EntityType> tryRead(const String & name) const;
|
||||
|
||||
/// Reads only name of an entity.
|
||||
String readName(const UUID & id) const;
|
||||
std::optional<String> tryReadName(const UUID & id) const;
|
||||
|
||||
/// Inserts an entity to the storage. Returns ID of a new entry in the storage.
|
||||
/// Throws an exception if the specified name already exists.
|
||||
UUID insert(const AccessEntityPtr & entity);
|
||||
std::vector<UUID> insert(const std::vector<AccessEntityPtr> & multiple_entities);
|
||||
|
||||
/// Inserts an entity to the storage. Returns ID of a new entry in the storage.
|
||||
std::optional<UUID> tryInsert(const AccessEntityPtr & entity);
|
||||
std::vector<UUID> tryInsert(const std::vector<AccessEntityPtr> & multiple_entities);
|
||||
|
||||
/// Inserts an entity to the storage. Return ID of a new entry in the storage.
|
||||
/// Replaces an existing entry in the storage if the specified name already exists.
|
||||
UUID insertOrReplace(const AccessEntityPtr & entity);
|
||||
std::vector<UUID> insertOrReplace(const std::vector<AccessEntityPtr> & multiple_entities);
|
||||
|
||||
/// Removes an entity from the storage. Throws an exception if couldn't remove.
|
||||
void remove(const UUID & id);
|
||||
void remove(const std::vector<UUID> & ids);
|
||||
|
||||
/// Removes an entity from the storage. Returns false if couldn't remove.
|
||||
bool tryRemove(const UUID & id);
|
||||
|
||||
/// Removes multiple entities from the storage. Returns the list of successfully dropped.
|
||||
std::vector<UUID> tryRemove(const std::vector<UUID> & ids);
|
||||
|
||||
using UpdateFunc = std::function<AccessEntityPtr(const AccessEntityPtr &)>;
|
||||
|
||||
/// Updates an entity stored in the storage. Throws an exception if couldn't update.
|
||||
void update(const UUID & id, const UpdateFunc & update_func);
|
||||
void update(const std::vector<UUID> & ids, const UpdateFunc & update_func);
|
||||
|
||||
/// Updates an entity stored in the storage. Returns false if couldn't update.
|
||||
bool tryUpdate(const UUID & id, const UpdateFunc & update_func);
|
||||
|
||||
/// Updates multiple entities in the storage. Returns the list of successfully updated.
|
||||
std::vector<UUID> tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func);
|
||||
|
||||
class Subscription
|
||||
{
|
||||
public:
|
||||
virtual ~Subscription() {}
|
||||
};
|
||||
|
||||
using SubscriptionPtr = std::unique_ptr<Subscription>;
|
||||
using OnChangedHandler = std::function<void(const UUID & /* id */, const AccessEntityPtr & /* new or changed entity, null if removed */)>;
|
||||
|
||||
/// Subscribes for all changes.
|
||||
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
|
||||
SubscriptionPtr subscribeForChanges(std::type_index type, const OnChangedHandler & handler) const;
|
||||
|
||||
template <typename EntityType>
|
||||
SubscriptionPtr subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(typeid(EntityType), handler); }
|
||||
|
||||
/// Subscribes for changes of a specific entry.
|
||||
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
|
||||
SubscriptionPtr subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const;
|
||||
SubscriptionPtr subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const;
|
||||
|
||||
bool hasSubscription(std::type_index type) const;
|
||||
bool hasSubscription(const UUID & id) const;
|
||||
|
||||
protected:
|
||||
virtual std::optional<UUID> findImpl(std::type_index type, const String & name) const = 0;
|
||||
virtual std::vector<UUID> findAllImpl(std::type_index type) const = 0;
|
||||
virtual bool existsImpl(const UUID & id) const = 0;
|
||||
virtual AccessEntityPtr readImpl(const UUID & id) const = 0;
|
||||
virtual String readNameImpl(const UUID & id) const = 0;
|
||||
virtual UUID insertImpl(const AccessEntityPtr & entity, bool replace_if_exists) = 0;
|
||||
virtual void removeImpl(const UUID & id) = 0;
|
||||
virtual void updateImpl(const UUID & id, const UpdateFunc & update_func) = 0;
|
||||
virtual SubscriptionPtr subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const = 0;
|
||||
virtual SubscriptionPtr subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const = 0;
|
||||
virtual bool hasSubscriptionImpl(const UUID & id) const = 0;
|
||||
virtual bool hasSubscriptionImpl(std::type_index type) const = 0;
|
||||
|
||||
static UUID generateRandomID();
|
||||
Poco::Logger * getLogger() const;
|
||||
static String getTypeName(std::type_index type) { return IAccessEntity::getTypeName(type); }
|
||||
[[noreturn]] void throwNotFound(const UUID & id) const;
|
||||
[[noreturn]] void throwNotFound(std::type_index type, const String & name) const;
|
||||
[[noreturn]] void throwBadCast(const UUID & id, std::type_index type, const String & name, std::type_index required_type) const;
|
||||
[[noreturn]] void throwIDCollisionCannotInsert(const UUID & id, std::type_index type, const String & name, std::type_index existing_type, const String & existing_name) const;
|
||||
[[noreturn]] void throwNameCollisionCannotInsert(std::type_index type, const String & name) const;
|
||||
[[noreturn]] void throwNameCollisionCannotRename(std::type_index type, const String & old_name, const String & new_name) const;
|
||||
[[noreturn]] void throwReadonlyCannotInsert(std::type_index type, const String & name) const;
|
||||
[[noreturn]] void throwReadonlyCannotUpdate(std::type_index type, const String & name) const;
|
||||
[[noreturn]] void throwReadonlyCannotRemove(std::type_index type, const String & name) const;
|
||||
|
||||
using Notification = std::tuple<OnChangedHandler, UUID, AccessEntityPtr>;
|
||||
using Notifications = std::vector<Notification>;
|
||||
static void notify(const Notifications & notifications);
|
||||
|
||||
private:
|
||||
AccessEntityPtr tryReadBase(const UUID & id) const;
|
||||
|
||||
const String storage_name;
|
||||
mutable std::atomic<Poco::Logger *> log = nullptr;
|
||||
};
|
||||
|
||||
|
||||
template <typename EntityType>
|
||||
std::shared_ptr<const EntityType> IAccessStorage::read(const UUID & id) const
|
||||
{
|
||||
auto entity = readImpl(id);
|
||||
auto ptr = typeid_cast<std::shared_ptr<const EntityType>>(entity);
|
||||
if (ptr)
|
||||
return ptr;
|
||||
throwBadCast(id, entity->getType(), entity->getFullName(), typeid(EntityType));
|
||||
}
|
||||
|
||||
|
||||
template <typename EntityType>
|
||||
std::shared_ptr<const EntityType> IAccessStorage::read(const String & name) const
|
||||
{
|
||||
return read<EntityType>(getID<EntityType>(name));
|
||||
}
|
||||
|
||||
|
||||
template <typename EntityType>
|
||||
std::shared_ptr<const EntityType> IAccessStorage::tryRead(const UUID & id) const
|
||||
{
|
||||
auto entity = tryReadBase(id);
|
||||
if (!entity)
|
||||
return nullptr;
|
||||
return typeid_cast<std::shared_ptr<const EntityType>>(entity);
|
||||
}
|
||||
|
||||
|
||||
template <typename EntityType>
|
||||
std::shared_ptr<const EntityType> IAccessStorage::tryRead(const String & name) const
|
||||
{
|
||||
auto id = find<EntityType>(name);
|
||||
return id ? tryRead<EntityType>(*id) : nullptr;
|
||||
}
|
||||
}
|
358
dbms/src/Access/MemoryAccessStorage.cpp
Normal file
358
dbms/src/Access/MemoryAccessStorage.cpp
Normal file
@ -0,0 +1,358 @@
|
||||
#include <Access/MemoryAccessStorage.h>
|
||||
#include <ext/scope_guard.h>
|
||||
#include <unordered_set>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
MemoryAccessStorage::MemoryAccessStorage(const String & storage_name_)
|
||||
: IAccessStorage(storage_name_), shared_ptr_to_this{std::make_shared<const MemoryAccessStorage *>(this)}
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
MemoryAccessStorage::~MemoryAccessStorage() {}
|
||||
|
||||
|
||||
std::optional<UUID> MemoryAccessStorage::findImpl(std::type_index type, const String & name) const
|
||||
{
|
||||
std::lock_guard lock{mutex};
|
||||
auto it = names.find({name, type});
|
||||
if (it == names.end())
|
||||
return {};
|
||||
|
||||
Entry & entry = *(it->second);
|
||||
return entry.id;
|
||||
}
|
||||
|
||||
|
||||
std::vector<UUID> MemoryAccessStorage::findAllImpl(std::type_index type) const
|
||||
{
|
||||
std::lock_guard lock{mutex};
|
||||
std::vector<UUID> result;
|
||||
result.reserve(entries.size());
|
||||
for (const auto & [id, entry] : entries)
|
||||
if (entry.entity->isTypeOf(type))
|
||||
result.emplace_back(id);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
bool MemoryAccessStorage::existsImpl(const UUID & id) const
|
||||
{
|
||||
std::lock_guard lock{mutex};
|
||||
return entries.count(id);
|
||||
}
|
||||
|
||||
|
||||
AccessEntityPtr MemoryAccessStorage::readImpl(const UUID & id) const
|
||||
{
|
||||
std::lock_guard lock{mutex};
|
||||
auto it = entries.find(id);
|
||||
if (it == entries.end())
|
||||
throwNotFound(id);
|
||||
const Entry & entry = it->second;
|
||||
return entry.entity;
|
||||
}
|
||||
|
||||
|
||||
String MemoryAccessStorage::readNameImpl(const UUID & id) const
|
||||
{
|
||||
return readImpl(id)->getFullName();
|
||||
}
|
||||
|
||||
|
||||
UUID MemoryAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool replace_if_exists)
|
||||
{
|
||||
Notifications notifications;
|
||||
SCOPE_EXIT({ notify(notifications); });
|
||||
|
||||
UUID id = generateRandomID();
|
||||
std::lock_guard lock{mutex};
|
||||
insertNoLock(generateRandomID(), new_entity, replace_if_exists, notifications);
|
||||
return id;
|
||||
}
|
||||
|
||||
|
||||
void MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, Notifications & notifications)
|
||||
{
|
||||
const String & name = new_entity->getFullName();
|
||||
std::type_index type = new_entity->getType();
|
||||
|
||||
/// Check that we can insert.
|
||||
auto it = entries.find(id);
|
||||
if (it != entries.end())
|
||||
{
|
||||
const auto & existing_entry = it->second;
|
||||
throwIDCollisionCannotInsert(id, type, name, existing_entry.entity->getType(), existing_entry.entity->getFullName());
|
||||
}
|
||||
|
||||
auto it2 = names.find({name, type});
|
||||
if (it2 != names.end())
|
||||
{
|
||||
const auto & existing_entry = *(it2->second);
|
||||
if (replace_if_exists)
|
||||
removeNoLock(existing_entry.id, notifications);
|
||||
else
|
||||
throwNameCollisionCannotInsert(type, name);
|
||||
}
|
||||
|
||||
/// Do insertion.
|
||||
auto & entry = entries[id];
|
||||
entry.id = id;
|
||||
entry.entity = new_entity;
|
||||
names[std::pair{name, type}] = &entry;
|
||||
prepareNotifications(entry, false, notifications);
|
||||
}
|
||||
|
||||
|
||||
void MemoryAccessStorage::removeImpl(const UUID & id)
|
||||
{
|
||||
Notifications notifications;
|
||||
SCOPE_EXIT({ notify(notifications); });
|
||||
|
||||
std::lock_guard lock{mutex};
|
||||
removeNoLock(id, notifications);
|
||||
}
|
||||
|
||||
|
||||
void MemoryAccessStorage::removeNoLock(const UUID & id, Notifications & notifications)
|
||||
{
|
||||
auto it = entries.find(id);
|
||||
if (it == entries.end())
|
||||
throwNotFound(id);
|
||||
|
||||
Entry & entry = it->second;
|
||||
const String & name = entry.entity->getFullName();
|
||||
std::type_index type = entry.entity->getType();
|
||||
|
||||
prepareNotifications(entry, true, notifications);
|
||||
|
||||
/// Do removing.
|
||||
names.erase({name, type});
|
||||
entries.erase(it);
|
||||
}
|
||||
|
||||
|
||||
void MemoryAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func)
|
||||
{
|
||||
Notifications notifications;
|
||||
SCOPE_EXIT({ notify(notifications); });
|
||||
|
||||
std::lock_guard lock{mutex};
|
||||
updateNoLock(id, update_func, notifications);
|
||||
}
|
||||
|
||||
|
||||
void MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, Notifications & notifications)
|
||||
{
|
||||
auto it = entries.find(id);
|
||||
if (it == entries.end())
|
||||
throwNotFound(id);
|
||||
|
||||
Entry & entry = it->second;
|
||||
auto old_entity = entry.entity;
|
||||
auto new_entity = update_func(old_entity);
|
||||
|
||||
if (*new_entity == *old_entity)
|
||||
return;
|
||||
|
||||
entry.entity = new_entity;
|
||||
|
||||
if (new_entity->getFullName() != old_entity->getFullName())
|
||||
{
|
||||
auto it2 = names.find({new_entity->getFullName(), new_entity->getType()});
|
||||
if (it2 != names.end())
|
||||
throwNameCollisionCannotRename(old_entity->getType(), old_entity->getFullName(), new_entity->getFullName());
|
||||
|
||||
names.erase({old_entity->getFullName(), old_entity->getType()});
|
||||
names[std::pair{new_entity->getFullName(), new_entity->getType()}] = &entry;
|
||||
}
|
||||
|
||||
prepareNotifications(entry, false, notifications);
|
||||
}
|
||||
|
||||
|
||||
void MemoryAccessStorage::setAll(const std::vector<AccessEntityPtr> & all_entities)
|
||||
{
|
||||
std::vector<std::pair<UUID, AccessEntityPtr>> entities_with_ids;
|
||||
entities_with_ids.reserve(all_entities.size());
|
||||
for (const auto & entity : all_entities)
|
||||
entities_with_ids.emplace_back(generateRandomID(), entity);
|
||||
setAll(entities_with_ids);
|
||||
}
|
||||
|
||||
|
||||
void MemoryAccessStorage::setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities)
|
||||
{
|
||||
Notifications notifications;
|
||||
SCOPE_EXIT({ notify(notifications); });
|
||||
|
||||
std::lock_guard lock{mutex};
|
||||
setAllNoLock(all_entities, notifications);
|
||||
}
|
||||
|
||||
|
||||
void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications)
|
||||
{
|
||||
/// Get list of the currently used IDs. Later we will remove those of them which are not used anymore.
|
||||
std::unordered_set<UUID> not_used_ids;
|
||||
for (const auto & id_and_entry : entries)
|
||||
not_used_ids.emplace(id_and_entry.first);
|
||||
|
||||
/// Remove conflicting entities.
|
||||
for (const auto & [id, entity] : all_entities)
|
||||
{
|
||||
auto it = entries.find(id);
|
||||
if (it != entries.end())
|
||||
{
|
||||
not_used_ids.erase(id); /// ID is used.
|
||||
Entry & entry = it->second;
|
||||
if (entry.entity->getType() != entity->getType())
|
||||
{
|
||||
removeNoLock(id, notifications);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
auto it2 = names.find({entity->getFullName(), entity->getType()});
|
||||
if (it2 != names.end())
|
||||
{
|
||||
Entry & entry = *(it2->second);
|
||||
if (entry.id != id)
|
||||
removeNoLock(id, notifications);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove entities which are not used anymore.
|
||||
for (const auto & id : not_used_ids)
|
||||
removeNoLock(id, notifications);
|
||||
|
||||
/// Insert or update entities.
|
||||
for (const auto & [id, entity] : all_entities)
|
||||
{
|
||||
auto it = entries.find(id);
|
||||
if (it != entries.end())
|
||||
{
|
||||
if (*(it->second.entity) != *entity)
|
||||
{
|
||||
const AccessEntityPtr & changed_entity = entity;
|
||||
updateNoLock(id, [&changed_entity](const AccessEntityPtr &) { return changed_entity; }, notifications);
|
||||
}
|
||||
}
|
||||
else
|
||||
insertNoLock(id, entity, false, notifications);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void MemoryAccessStorage::prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const
|
||||
{
|
||||
for (const auto & handler : entry.handlers_by_id)
|
||||
notifications.push_back({handler, entry.id, remove ? nullptr : entry.entity});
|
||||
|
||||
auto range = handlers_by_type.equal_range(entry.entity->getType());
|
||||
for (auto it = range.first; it != range.second; ++it)
|
||||
notifications.push_back({it->second, entry.id, remove ? nullptr : entry.entity});
|
||||
}
|
||||
|
||||
|
||||
IAccessStorage::SubscriptionPtr MemoryAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const
|
||||
{
|
||||
class SubscriptionImpl : public Subscription
|
||||
{
|
||||
public:
|
||||
SubscriptionImpl(
|
||||
const MemoryAccessStorage & storage_,
|
||||
std::type_index type_,
|
||||
const OnChangedHandler & handler_)
|
||||
: storage_weak(storage_.shared_ptr_to_this)
|
||||
{
|
||||
std::lock_guard lock{storage_.mutex};
|
||||
handler_it = storage_.handlers_by_type.emplace(type_, handler_);
|
||||
}
|
||||
|
||||
~SubscriptionImpl() override
|
||||
{
|
||||
auto storage = storage_weak.lock();
|
||||
if (storage)
|
||||
{
|
||||
std::lock_guard lock{(*storage)->mutex};
|
||||
(*storage)->handlers_by_type.erase(handler_it);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::weak_ptr<const MemoryAccessStorage *> storage_weak;
|
||||
std::unordered_multimap<std::type_index, OnChangedHandler>::iterator handler_it;
|
||||
};
|
||||
|
||||
return std::make_unique<SubscriptionImpl>(*this, type, handler);
|
||||
}
|
||||
|
||||
|
||||
IAccessStorage::SubscriptionPtr MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
|
||||
{
|
||||
class SubscriptionImpl : public Subscription
|
||||
{
|
||||
public:
|
||||
SubscriptionImpl(
|
||||
const MemoryAccessStorage & storage_,
|
||||
const UUID & id_,
|
||||
const OnChangedHandler & handler_)
|
||||
: storage_weak(storage_.shared_ptr_to_this),
|
||||
id(id_)
|
||||
{
|
||||
std::lock_guard lock{storage_.mutex};
|
||||
auto it = storage_.entries.find(id);
|
||||
if (it == storage_.entries.end())
|
||||
{
|
||||
storage_weak.reset();
|
||||
return;
|
||||
}
|
||||
const Entry & entry = it->second;
|
||||
handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler_);
|
||||
}
|
||||
|
||||
~SubscriptionImpl() override
|
||||
{
|
||||
auto storage = storage_weak.lock();
|
||||
if (storage)
|
||||
{
|
||||
std::lock_guard lock{(*storage)->mutex};
|
||||
auto it = (*storage)->entries.find(id);
|
||||
if (it != (*storage)->entries.end())
|
||||
{
|
||||
const Entry & entry = it->second;
|
||||
entry.handlers_by_id.erase(handler_it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::weak_ptr<const MemoryAccessStorage *> storage_weak;
|
||||
UUID id;
|
||||
std::list<OnChangedHandler>::iterator handler_it;
|
||||
};
|
||||
|
||||
return std::make_unique<SubscriptionImpl>(*this, id, handler);
|
||||
}
|
||||
|
||||
|
||||
bool MemoryAccessStorage::hasSubscriptionImpl(const UUID & id) const
|
||||
{
|
||||
auto it = entries.find(id);
|
||||
if (it != entries.end())
|
||||
{
|
||||
const Entry & entry = it->second;
|
||||
return !entry.handlers_by_id.empty();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
bool MemoryAccessStorage::hasSubscriptionImpl(std::type_index type) const
|
||||
{
|
||||
auto range = handlers_by_type.equal_range(type);
|
||||
return range.first != range.second;
|
||||
}
|
||||
}
|
65
dbms/src/Access/MemoryAccessStorage.h
Normal file
65
dbms/src/Access/MemoryAccessStorage.h
Normal file
@ -0,0 +1,65 @@
|
||||
#pragma once
|
||||
|
||||
#include <Access/IAccessStorage.h>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
/// Implementation of IAccessStorage which keeps all data in memory.
|
||||
class MemoryAccessStorage : public IAccessStorage
|
||||
{
|
||||
public:
|
||||
MemoryAccessStorage(const String & storage_name_ = "memory");
|
||||
~MemoryAccessStorage() override;
|
||||
|
||||
/// Sets all entities at once.
|
||||
void setAll(const std::vector<AccessEntityPtr> & all_entities);
|
||||
void setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities);
|
||||
|
||||
private:
|
||||
std::optional<UUID> findImpl(std::type_index type, const String & name) const override;
|
||||
std::vector<UUID> findAllImpl(std::type_index type) const override;
|
||||
bool existsImpl(const UUID & id) const override;
|
||||
AccessEntityPtr readImpl(const UUID & id) const override;
|
||||
String readNameImpl(const UUID & id) const override;
|
||||
UUID insertImpl(const AccessEntityPtr & entity, bool replace_if_exists) override;
|
||||
void removeImpl(const UUID & id) override;
|
||||
void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
|
||||
SubscriptionPtr subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
|
||||
SubscriptionPtr subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override;
|
||||
bool hasSubscriptionImpl(const UUID & id) const override;
|
||||
bool hasSubscriptionImpl(std::type_index type) const override;
|
||||
|
||||
struct Entry
|
||||
{
|
||||
UUID id;
|
||||
AccessEntityPtr entity;
|
||||
mutable std::list<OnChangedHandler> handlers_by_id;
|
||||
};
|
||||
|
||||
void insertNoLock(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, Notifications & notifications);
|
||||
void removeNoLock(const UUID & id, Notifications & notifications);
|
||||
void updateNoLock(const UUID & id, const UpdateFunc & update_func, Notifications & notifications);
|
||||
void setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications);
|
||||
void prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const;
|
||||
|
||||
using NameTypePair = std::pair<String, std::type_index>;
|
||||
struct Hash
|
||||
{
|
||||
size_t operator()(const NameTypePair & key) const
|
||||
{
|
||||
return std::hash<String>{}(key.first) - std::hash<std::type_index>{}(key.second);
|
||||
}
|
||||
};
|
||||
|
||||
mutable std::mutex mutex;
|
||||
std::unordered_map<UUID, Entry> entries; /// We want to search entries both by ID and by the pair of name and type.
|
||||
std::unordered_map<NameTypePair, Entry *, Hash> names; /// and by the pair of name and type.
|
||||
mutable std::unordered_multimap<std::type_index, OnChangedHandler> handlers_by_type;
|
||||
std::shared_ptr<const MemoryAccessStorage *> shared_ptr_to_this; /// We need weak pointers to `this` to implement subscriptions.
|
||||
};
|
||||
}
|
246
dbms/src/Access/MultipleAccessStorage.cpp
Normal file
246
dbms/src/Access/MultipleAccessStorage.cpp
Normal file
@ -0,0 +1,246 @@
|
||||
#include <Access/MultipleAccessStorage.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/quoteString.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ACCESS_ENTITY_NOT_FOUND;
|
||||
extern const int ACCESS_ENTITY_FOUND_DUPLICATES;
|
||||
}
|
||||
|
||||
|
||||
namespace
|
||||
{
|
||||
template <typename StoragePtrT>
|
||||
String joinStorageNames(const std::vector<StoragePtrT> & storages)
|
||||
{
|
||||
String result;
|
||||
for (const auto & storage : storages)
|
||||
{
|
||||
if (!result.empty())
|
||||
result += ", ";
|
||||
result += storage->getStorageName();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
MultipleAccessStorage::MultipleAccessStorage(
|
||||
std::vector<std::unique_ptr<Storage>> nested_storages_, size_t index_of_nested_storage_for_insertion_)
|
||||
: IAccessStorage(joinStorageNames(nested_storages_))
|
||||
, nested_storages(std::move(nested_storages_))
|
||||
, nested_storage_for_insertion(nested_storages[index_of_nested_storage_for_insertion_].get())
|
||||
, ids_cache(512 /* cache size */)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
MultipleAccessStorage::~MultipleAccessStorage()
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
std::vector<UUID> MultipleAccessStorage::findMultiple(std::type_index type, const String & name) const
|
||||
{
|
||||
std::vector<UUID> ids;
|
||||
for (const auto & nested_storage : nested_storages)
|
||||
{
|
||||
auto id = nested_storage->find(type, name);
|
||||
if (id)
|
||||
{
|
||||
std::lock_guard lock{ids_cache_mutex};
|
||||
ids_cache.set(*id, std::make_shared<Storage *>(nested_storage.get()));
|
||||
ids.push_back(*id);
|
||||
}
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
|
||||
std::optional<UUID> MultipleAccessStorage::findImpl(std::type_index type, const String & name) const
|
||||
{
|
||||
auto ids = findMultiple(type, name);
|
||||
if (ids.empty())
|
||||
return {};
|
||||
if (ids.size() == 1)
|
||||
return ids[0];
|
||||
|
||||
std::vector<const Storage *> storages_with_duplicates;
|
||||
for (const auto & id : ids)
|
||||
{
|
||||
auto * storage = findStorage(id);
|
||||
if (storage)
|
||||
storages_with_duplicates.push_back(storage);
|
||||
}
|
||||
|
||||
throw Exception(
|
||||
"Found " + getTypeName(type) + " " + backQuote(name) + " in " + std::to_string(ids.size())
|
||||
+ " storages: " + joinStorageNames(storages_with_duplicates),
|
||||
ErrorCodes::ACCESS_ENTITY_FOUND_DUPLICATES);
|
||||
}
|
||||
|
||||
|
||||
std::vector<UUID> MultipleAccessStorage::findAllImpl(std::type_index type) const
|
||||
{
|
||||
std::vector<UUID> all_ids;
|
||||
for (const auto & nested_storage : nested_storages)
|
||||
{
|
||||
auto ids = nested_storage->findAll(type);
|
||||
all_ids.insert(all_ids.end(), std::make_move_iterator(ids.begin()), std::make_move_iterator(ids.end()));
|
||||
}
|
||||
return all_ids;
|
||||
}
|
||||
|
||||
|
||||
bool MultipleAccessStorage::existsImpl(const UUID & id) const
|
||||
{
|
||||
return findStorage(id) != nullptr;
|
||||
}
|
||||
|
||||
|
||||
IAccessStorage * MultipleAccessStorage::findStorage(const UUID & id)
|
||||
{
|
||||
{
|
||||
std::lock_guard lock{ids_cache_mutex};
|
||||
auto from_cache = ids_cache.get(id);
|
||||
if (from_cache)
|
||||
{
|
||||
auto * storage = *from_cache;
|
||||
if (storage->exists(id))
|
||||
return storage;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & nested_storage : nested_storages)
|
||||
{
|
||||
if (nested_storage->exists(id))
|
||||
{
|
||||
std::lock_guard lock{ids_cache_mutex};
|
||||
ids_cache.set(id, std::make_shared<Storage *>(nested_storage.get()));
|
||||
return nested_storage.get();
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
const IAccessStorage * MultipleAccessStorage::findStorage(const UUID & id) const
|
||||
{
|
||||
return const_cast<MultipleAccessStorage *>(this)->findStorage(id);
|
||||
}
|
||||
|
||||
|
||||
IAccessStorage & MultipleAccessStorage::getStorage(const UUID & id)
|
||||
{
|
||||
auto * storage = findStorage(id);
|
||||
if (storage)
|
||||
return *storage;
|
||||
throwNotFound(id);
|
||||
}
|
||||
|
||||
|
||||
const IAccessStorage & MultipleAccessStorage::getStorage(const UUID & id) const
|
||||
{
|
||||
return const_cast<MultipleAccessStorage *>(this)->getStorage(id);
|
||||
}
|
||||
|
||||
|
||||
AccessEntityPtr MultipleAccessStorage::readImpl(const UUID & id) const
|
||||
{
|
||||
return getStorage(id).read(id);
|
||||
}
|
||||
|
||||
|
||||
String MultipleAccessStorage::readNameImpl(const UUID & id) const
|
||||
{
|
||||
return getStorage(id).readName(id);
|
||||
}
|
||||
|
||||
|
||||
UUID MultipleAccessStorage::insertImpl(const AccessEntityPtr & entity, bool replace_if_exists)
|
||||
{
|
||||
auto id = replace_if_exists ? nested_storage_for_insertion->insertOrReplace(entity) : nested_storage_for_insertion->insert(entity);
|
||||
|
||||
std::lock_guard lock{ids_cache_mutex};
|
||||
ids_cache.set(id, std::make_shared<Storage *>(nested_storage_for_insertion));
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
|
||||
void MultipleAccessStorage::removeImpl(const UUID & id)
|
||||
{
|
||||
getStorage(id).remove(id);
|
||||
}
|
||||
|
||||
|
||||
void MultipleAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func)
|
||||
{
|
||||
getStorage(id).update(id, update_func);
|
||||
}
|
||||
|
||||
|
||||
IAccessStorage::SubscriptionPtr MultipleAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
|
||||
{
|
||||
auto storage = findStorage(id);
|
||||
if (!storage)
|
||||
return nullptr;
|
||||
return storage->subscribeForChanges(id, handler);
|
||||
}
|
||||
|
||||
|
||||
IAccessStorage::SubscriptionPtr MultipleAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const
|
||||
{
|
||||
std::vector<SubscriptionPtr> subscriptions;
|
||||
for (const auto & nested_storage : nested_storages)
|
||||
{
|
||||
auto subscription = nested_storage->subscribeForChanges(type, handler);
|
||||
if (subscription)
|
||||
subscriptions.emplace_back(std::move(subscription));
|
||||
}
|
||||
|
||||
if (subscriptions.empty())
|
||||
return nullptr;
|
||||
|
||||
if (subscriptions.size() == 1)
|
||||
return std::move(subscriptions[0]);
|
||||
|
||||
class SubscriptionImpl : public Subscription
|
||||
{
|
||||
public:
|
||||
SubscriptionImpl(std::vector<SubscriptionPtr> subscriptions_)
|
||||
: subscriptions(std::move(subscriptions_)) {}
|
||||
private:
|
||||
std::vector<SubscriptionPtr> subscriptions;
|
||||
};
|
||||
|
||||
return std::make_unique<SubscriptionImpl>(std::move(subscriptions));
|
||||
}
|
||||
|
||||
|
||||
bool MultipleAccessStorage::hasSubscriptionImpl(const UUID & id) const
|
||||
{
|
||||
for (const auto & nested_storage : nested_storages)
|
||||
{
|
||||
if (nested_storage->hasSubscription(id))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
bool MultipleAccessStorage::hasSubscriptionImpl(std::type_index type) const
|
||||
{
|
||||
for (const auto & nested_storage : nested_storages)
|
||||
{
|
||||
if (nested_storage->hasSubscription(type))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
53
dbms/src/Access/MultipleAccessStorage.h
Normal file
53
dbms/src/Access/MultipleAccessStorage.h
Normal file
@ -0,0 +1,53 @@
|
||||
#pragma once
|
||||
|
||||
#include <Access/IAccessStorage.h>
|
||||
#include <Common/LRUCache.h>
|
||||
#include <mutex>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
/// Implementation of IAccessStorage which contains multiple nested storages.
|
||||
class MultipleAccessStorage : public IAccessStorage
|
||||
{
|
||||
public:
|
||||
using Storage = IAccessStorage;
|
||||
|
||||
MultipleAccessStorage(std::vector<std::unique_ptr<Storage>> nested_storages_, size_t index_of_nested_storage_for_insertion_ = 0);
|
||||
~MultipleAccessStorage() override;
|
||||
|
||||
std::vector<UUID> findMultiple(std::type_index type, const String & name) const;
|
||||
|
||||
template <typename EntityType>
|
||||
std::vector<UUID> findMultiple(const String & name) const { return findMultiple(EntityType::TYPE, name); }
|
||||
|
||||
const Storage * findStorage(const UUID & id) const;
|
||||
Storage * findStorage(const UUID & id);
|
||||
const Storage & getStorage(const UUID & id) const;
|
||||
Storage & getStorage(const UUID & id);
|
||||
|
||||
Storage & getStorageByIndex(size_t i) { return *(nested_storages[i]); }
|
||||
const Storage & getStorageByIndex(size_t i) const { return *(nested_storages[i]); }
|
||||
|
||||
protected:
|
||||
std::optional<UUID> findImpl(std::type_index type, const String & name) const override;
|
||||
std::vector<UUID> findAllImpl(std::type_index type) const override;
|
||||
bool existsImpl(const UUID & id) const override;
|
||||
AccessEntityPtr readImpl(const UUID & id) const override;
|
||||
String readNameImpl(const UUID &id) const override;
|
||||
UUID insertImpl(const AccessEntityPtr & entity, bool replace_if_exists) override;
|
||||
void removeImpl(const UUID & id) override;
|
||||
void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
|
||||
SubscriptionPtr subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
|
||||
SubscriptionPtr subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override;
|
||||
bool hasSubscriptionImpl(const UUID & id) const override;
|
||||
bool hasSubscriptionImpl(std::type_index type) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<Storage>> nested_storages;
|
||||
IAccessStorage * nested_storage_for_insertion;
|
||||
mutable LRUCache<UUID, Storage *> ids_cache;
|
||||
mutable std::mutex ids_cache_mutex;
|
||||
};
|
||||
|
||||
}
|
46
dbms/src/Access/Quota.cpp
Normal file
46
dbms/src/Access/Quota.cpp
Normal file
@ -0,0 +1,46 @@
|
||||
#include <Access/Quota.h>
|
||||
#include <boost/range/algorithm/equal.hpp>
|
||||
#include <boost/range/algorithm/fill.hpp>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
Quota::Limits::Limits()
|
||||
{
|
||||
boost::range::fill(max, 0);
|
||||
}
|
||||
|
||||
|
||||
bool operator ==(const Quota::Limits & lhs, const Quota::Limits & rhs)
|
||||
{
|
||||
return boost::range::equal(lhs.max, rhs.max) && (lhs.duration == rhs.duration)
|
||||
&& (lhs.randomize_interval == rhs.randomize_interval);
|
||||
}
|
||||
|
||||
|
||||
bool Quota::equal(const IAccessEntity & other) const
|
||||
{
|
||||
if (!IAccessEntity::equal(other))
|
||||
return false;
|
||||
const auto & other_quota = typeid_cast<const Quota &>(other);
|
||||
return (all_limits == other_quota.all_limits) && (key_type == other_quota.key_type) && (roles == other_quota.roles)
|
||||
&& (all_roles == other_quota.all_roles) && (except_roles == other_quota.except_roles);
|
||||
}
|
||||
|
||||
|
||||
const char * Quota::resourceTypeToColumnName(ResourceType resource_type)
|
||||
{
|
||||
switch (resource_type)
|
||||
{
|
||||
case Quota::QUERIES: return "queries";
|
||||
case Quota::ERRORS: return "errors";
|
||||
case Quota::RESULT_ROWS: return "result_rows";
|
||||
case Quota::RESULT_BYTES: return "result_bytes";
|
||||
case Quota::READ_ROWS: return "read_rows";
|
||||
case Quota::READ_BYTES: return "read_bytes";
|
||||
case Quota::EXECUTION_TIME: return "execution_time";
|
||||
}
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
|
141
dbms/src/Access/Quota.h
Normal file
141
dbms/src/Access/Quota.h
Normal file
@ -0,0 +1,141 @@
|
||||
#pragma once
|
||||
|
||||
#include <Access/IAccessEntity.h>
|
||||
#include <chrono>
|
||||
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
/** Quota for resources consumption for specific interval.
|
||||
* Used to limit resource usage by user.
|
||||
* Quota is applied "softly" - could be slightly exceed, because it is checked usually only on each block of processed data.
|
||||
* Accumulated values are not persisted and are lost on server restart.
|
||||
* Quota is local to server,
|
||||
* but for distributed queries, accumulated values for read rows and bytes
|
||||
* are collected from all participating servers and accumulated locally.
|
||||
*/
|
||||
struct Quota : public IAccessEntity
|
||||
{
|
||||
enum ResourceType
|
||||
{
|
||||
QUERIES, /// Number of queries.
|
||||
ERRORS, /// Number of queries with exceptions.
|
||||
RESULT_ROWS, /// Number of rows returned as result.
|
||||
RESULT_BYTES, /// Number of bytes returned as result.
|
||||
READ_ROWS, /// Number of rows read from tables.
|
||||
READ_BYTES, /// Number of bytes read from tables.
|
||||
EXECUTION_TIME, /// Total amount of query execution time in nanoseconds.
|
||||
};
|
||||
static constexpr size_t MAX_RESOURCE_TYPE = 7;
|
||||
|
||||
using ResourceAmount = UInt64;
|
||||
static constexpr ResourceAmount UNLIMITED = 0; /// 0 means unlimited.
|
||||
|
||||
/// Amount of resources available to consume for each duration.
|
||||
struct Limits
|
||||
{
|
||||
ResourceAmount max[MAX_RESOURCE_TYPE];
|
||||
std::chrono::seconds duration = std::chrono::seconds::zero();
|
||||
|
||||
/// Intervals can be randomized (to avoid DoS if intervals for many users end at one time).
|
||||
bool randomize_interval = false;
|
||||
|
||||
Limits();
|
||||
friend bool operator ==(const Limits & lhs, const Limits & rhs);
|
||||
friend bool operator !=(const Limits & lhs, const Limits & rhs) { return !(lhs == rhs); }
|
||||
};
|
||||
|
||||
std::vector<Limits> all_limits;
|
||||
|
||||
/// Key to share quota consumption.
|
||||
/// Users with the same key share the same amount of resource.
|
||||
enum class KeyType
|
||||
{
|
||||
NONE, /// All users share the same quota.
|
||||
USER_NAME, /// Connections with the same user name share the same quota.
|
||||
IP_ADDRESS, /// Connections from the same IP share the same quota.
|
||||
CLIENT_KEY, /// Client should explicitly supply a key to use.
|
||||
CLIENT_KEY_OR_USER_NAME, /// Same as CLIENT_KEY, but use USER_NAME if the client doesn't supply a key.
|
||||
CLIENT_KEY_OR_IP_ADDRESS, /// Same as CLIENT_KEY, but use IP_ADDRESS if the client doesn't supply a key.
|
||||
};
|
||||
static constexpr size_t MAX_KEY_TYPE = 6;
|
||||
KeyType key_type = KeyType::NONE;
|
||||
|
||||
/// Which roles or users should use this quota.
|
||||
Strings roles;
|
||||
bool all_roles = false;
|
||||
Strings except_roles;
|
||||
|
||||
bool equal(const IAccessEntity & other) const override;
|
||||
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<Quota>(); }
|
||||
|
||||
static const char * getNameOfResourceType(ResourceType resource_type);
|
||||
static const char * resourceTypeToKeyword(ResourceType resource_type);
|
||||
static const char * resourceTypeToColumnName(ResourceType resource_type);
|
||||
static const char * getNameOfKeyType(KeyType key_type);
|
||||
static double executionTimeToSeconds(ResourceAmount ns);
|
||||
static ResourceAmount secondsToExecutionTime(double s);
|
||||
};
|
||||
|
||||
|
||||
inline const char * Quota::getNameOfResourceType(ResourceType resource_type)
|
||||
{
|
||||
switch (resource_type)
|
||||
{
|
||||
case Quota::QUERIES: return "queries";
|
||||
case Quota::ERRORS: return "errors";
|
||||
case Quota::RESULT_ROWS: return "result rows";
|
||||
case Quota::RESULT_BYTES: return "result bytes";
|
||||
case Quota::READ_ROWS: return "read rows";
|
||||
case Quota::READ_BYTES: return "read bytes";
|
||||
case Quota::EXECUTION_TIME: return "execution time";
|
||||
}
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
|
||||
inline const char * Quota::resourceTypeToKeyword(ResourceType resource_type)
|
||||
{
|
||||
switch (resource_type)
|
||||
{
|
||||
case Quota::QUERIES: return "QUERIES";
|
||||
case Quota::ERRORS: return "ERRORS";
|
||||
case Quota::RESULT_ROWS: return "RESULT ROWS";
|
||||
case Quota::RESULT_BYTES: return "RESULT BYTES";
|
||||
case Quota::READ_ROWS: return "READ ROWS";
|
||||
case Quota::READ_BYTES: return "READ BYTES";
|
||||
case Quota::EXECUTION_TIME: return "EXECUTION TIME";
|
||||
}
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
|
||||
inline const char * Quota::getNameOfKeyType(KeyType key_type)
|
||||
{
|
||||
switch (key_type)
|
||||
{
|
||||
case KeyType::NONE: return "none";
|
||||
case KeyType::USER_NAME: return "user name";
|
||||
case KeyType::IP_ADDRESS: return "ip address";
|
||||
case KeyType::CLIENT_KEY: return "client key";
|
||||
case KeyType::CLIENT_KEY_OR_USER_NAME: return "client key or user name";
|
||||
case KeyType::CLIENT_KEY_OR_IP_ADDRESS: return "client key or ip address";
|
||||
}
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
|
||||
inline double Quota::executionTimeToSeconds(ResourceAmount ns)
|
||||
{
|
||||
return std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::nanoseconds{ns}).count();
|
||||
}
|
||||
|
||||
inline Quota::ResourceAmount Quota::secondsToExecutionTime(double s)
|
||||
{
|
||||
return std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::duration<double>(s)).count();
|
||||
}
|
||||
|
||||
|
||||
using QuotaPtr = std::shared_ptr<const Quota>;
|
||||
}
|
264
dbms/src/Access/QuotaContext.cpp
Normal file
264
dbms/src/Access/QuotaContext.cpp
Normal file
@ -0,0 +1,264 @@
|
||||
#include <Access/QuotaContext.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/quoteString.h>
|
||||
#include <ext/chrono_io.h>
|
||||
#include <ext/range.h>
|
||||
#include <boost/range/algorithm/fill.hpp>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int QUOTA_EXPIRED;
|
||||
}
|
||||
|
||||
struct QuotaContext::Impl
|
||||
{
|
||||
[[noreturn]] static void throwQuotaExceed(
|
||||
const String & user_name,
|
||||
const String & quota_name,
|
||||
ResourceType resource_type,
|
||||
ResourceAmount used,
|
||||
ResourceAmount max,
|
||||
std::chrono::seconds duration,
|
||||
std::chrono::system_clock::time_point end_of_interval)
|
||||
{
|
||||
std::function<String(UInt64)> amount_to_string = [](UInt64 amount) { return std::to_string(amount); };
|
||||
if (resource_type == Quota::EXECUTION_TIME)
|
||||
amount_to_string = [&](UInt64 amount) { return ext::to_string(std::chrono::nanoseconds(amount)); };
|
||||
|
||||
throw Exception(
|
||||
"Quota for user " + backQuote(user_name) + " for " + ext::to_string(duration) + " has been exceeded: "
|
||||
+ Quota::getNameOfResourceType(resource_type) + " = " + amount_to_string(used) + "/" + amount_to_string(max) + ". "
|
||||
+ "Interval will end at " + ext::to_string(end_of_interval) + ". " + "Name of quota template: " + backQuote(quota_name),
|
||||
ErrorCodes::QUOTA_EXPIRED);
|
||||
}
|
||||
|
||||
|
||||
static std::chrono::system_clock::time_point getEndOfInterval(
|
||||
const Interval & interval, std::chrono::system_clock::time_point current_time, bool * counters_were_reset = nullptr)
|
||||
{
|
||||
auto & end_of_interval = interval.end_of_interval;
|
||||
auto end_loaded = end_of_interval.load();
|
||||
auto end = std::chrono::system_clock::time_point{end_loaded};
|
||||
if (current_time < end)
|
||||
{
|
||||
if (counters_were_reset)
|
||||
*counters_were_reset = false;
|
||||
return end;
|
||||
}
|
||||
|
||||
const auto duration = interval.duration;
|
||||
|
||||
do
|
||||
{
|
||||
end = end + (current_time - end + duration) / duration * duration;
|
||||
if (end_of_interval.compare_exchange_strong(end_loaded, end.time_since_epoch()))
|
||||
{
|
||||
boost::range::fill(interval.used, 0);
|
||||
break;
|
||||
}
|
||||
end = std::chrono::system_clock::time_point{end_loaded};
|
||||
}
|
||||
while (current_time >= end);
|
||||
|
||||
if (counters_were_reset)
|
||||
*counters_were_reset = true;
|
||||
return end;
|
||||
}
|
||||
|
||||
|
||||
static void used(
|
||||
const String & user_name,
|
||||
const Intervals & intervals,
|
||||
ResourceType resource_type,
|
||||
ResourceAmount amount,
|
||||
std::chrono::system_clock::time_point current_time,
|
||||
bool check_exceeded)
|
||||
{
|
||||
for (const auto & interval : intervals.intervals)
|
||||
{
|
||||
ResourceAmount used = (interval.used[resource_type] += amount);
|
||||
ResourceAmount max = interval.max[resource_type];
|
||||
if (max == Quota::UNLIMITED)
|
||||
continue;
|
||||
if (used > max)
|
||||
{
|
||||
bool counters_were_reset = false;
|
||||
auto end_of_interval = getEndOfInterval(interval, current_time, &counters_were_reset);
|
||||
if (counters_were_reset)
|
||||
{
|
||||
used = (interval.used[resource_type] += amount);
|
||||
if ((used > max) && check_exceeded)
|
||||
throwQuotaExceed(user_name, intervals.quota_name, resource_type, used, max, interval.duration, end_of_interval);
|
||||
}
|
||||
else if (check_exceeded)
|
||||
throwQuotaExceed(user_name, intervals.quota_name, resource_type, used, max, interval.duration, end_of_interval);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void checkExceeded(
|
||||
const String & user_name,
|
||||
const Intervals & intervals,
|
||||
ResourceType resource_type,
|
||||
std::chrono::system_clock::time_point current_time)
|
||||
{
|
||||
for (const auto & interval : intervals.intervals)
|
||||
{
|
||||
ResourceAmount used = interval.used[resource_type];
|
||||
ResourceAmount max = interval.max[resource_type];
|
||||
if (max == Quota::UNLIMITED)
|
||||
continue;
|
||||
if (used > max)
|
||||
{
|
||||
bool used_counters_reset = false;
|
||||
std::chrono::system_clock::time_point end_of_interval = getEndOfInterval(interval, current_time, &used_counters_reset);
|
||||
if (!used_counters_reset)
|
||||
throwQuotaExceed(user_name, intervals.quota_name, resource_type, used, max, interval.duration, end_of_interval);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void checkExceeded(
|
||||
const String & user_name,
|
||||
const Intervals & intervals,
|
||||
std::chrono::system_clock::time_point current_time)
|
||||
{
|
||||
for (auto resource_type : ext::range_with_static_cast<Quota::ResourceType>(Quota::MAX_RESOURCE_TYPE))
|
||||
checkExceeded(user_name, intervals, resource_type, current_time);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
QuotaContext::Interval & QuotaContext::Interval::operator =(const Interval & src)
|
||||
{
|
||||
randomize_interval = src.randomize_interval;
|
||||
duration = src.duration;
|
||||
end_of_interval.store(src.end_of_interval.load());
|
||||
for (auto resource_type : ext::range(MAX_RESOURCE_TYPE))
|
||||
{
|
||||
max[resource_type] = src.max[resource_type];
|
||||
used[resource_type].store(src.used[resource_type].load());
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
QuotaUsageInfo QuotaContext::Intervals::getUsageInfo(std::chrono::system_clock::time_point current_time) const
|
||||
{
|
||||
QuotaUsageInfo info;
|
||||
info.quota_id = quota_id;
|
||||
info.quota_name = quota_name;
|
||||
info.quota_key = quota_key;
|
||||
info.intervals.reserve(intervals.size());
|
||||
for (const auto & in : intervals)
|
||||
{
|
||||
info.intervals.push_back({});
|
||||
auto & out = info.intervals.back();
|
||||
out.duration = in.duration;
|
||||
out.randomize_interval = in.randomize_interval;
|
||||
out.end_of_interval = Impl::getEndOfInterval(in, current_time);
|
||||
for (auto resource_type : ext::range(MAX_RESOURCE_TYPE))
|
||||
{
|
||||
out.max[resource_type] = in.max[resource_type];
|
||||
out.used[resource_type] = in.used[resource_type];
|
||||
}
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
|
||||
QuotaContext::QuotaContext()
|
||||
: atomic_intervals(std::make_shared<Intervals>()) /// Unlimited quota.
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
QuotaContext::QuotaContext(
|
||||
const String & user_name_,
|
||||
const Poco::Net::IPAddress & address_,
|
||||
const String & client_key_)
|
||||
: user_name(user_name_), address(address_), client_key(client_key_)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
QuotaContext::~QuotaContext() = default;
|
||||
|
||||
|
||||
void QuotaContext::used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded)
|
||||
{
|
||||
used({resource_type, amount}, check_exceeded);
|
||||
}
|
||||
|
||||
|
||||
void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource, bool check_exceeded)
|
||||
{
|
||||
auto intervals_ptr = std::atomic_load(&atomic_intervals);
|
||||
auto current_time = std::chrono::system_clock::now();
|
||||
Impl::used(user_name, *intervals_ptr, resource.first, resource.second, current_time, check_exceeded);
|
||||
}
|
||||
|
||||
|
||||
void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, bool check_exceeded)
|
||||
{
|
||||
auto intervals_ptr = std::atomic_load(&atomic_intervals);
|
||||
auto current_time = std::chrono::system_clock::now();
|
||||
Impl::used(user_name, *intervals_ptr, resource1.first, resource1.second, current_time, check_exceeded);
|
||||
Impl::used(user_name, *intervals_ptr, resource2.first, resource2.second, current_time, check_exceeded);
|
||||
}
|
||||
|
||||
|
||||
void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, const std::pair<ResourceType, ResourceAmount> & resource3, bool check_exceeded)
|
||||
{
|
||||
auto intervals_ptr = std::atomic_load(&atomic_intervals);
|
||||
auto current_time = std::chrono::system_clock::now();
|
||||
Impl::used(user_name, *intervals_ptr, resource1.first, resource1.second, current_time, check_exceeded);
|
||||
Impl::used(user_name, *intervals_ptr, resource2.first, resource2.second, current_time, check_exceeded);
|
||||
Impl::used(user_name, *intervals_ptr, resource3.first, resource3.second, current_time, check_exceeded);
|
||||
}
|
||||
|
||||
|
||||
void QuotaContext::used(const std::vector<std::pair<ResourceType, ResourceAmount>> & resources, bool check_exceeded)
|
||||
{
|
||||
auto intervals_ptr = std::atomic_load(&atomic_intervals);
|
||||
auto current_time = std::chrono::system_clock::now();
|
||||
for (const auto & resource : resources)
|
||||
Impl::used(user_name, *intervals_ptr, resource.first, resource.second, current_time, check_exceeded);
|
||||
}
|
||||
|
||||
|
||||
void QuotaContext::checkExceeded()
|
||||
{
|
||||
auto intervals_ptr = std::atomic_load(&atomic_intervals);
|
||||
Impl::checkExceeded(user_name, *intervals_ptr, std::chrono::system_clock::now());
|
||||
}
|
||||
|
||||
|
||||
void QuotaContext::checkExceeded(ResourceType resource_type)
|
||||
{
|
||||
auto intervals_ptr = std::atomic_load(&atomic_intervals);
|
||||
Impl::checkExceeded(user_name, *intervals_ptr, resource_type, std::chrono::system_clock::now());
|
||||
}
|
||||
|
||||
|
||||
QuotaUsageInfo QuotaContext::getUsageInfo() const
|
||||
{
|
||||
auto intervals_ptr = std::atomic_load(&atomic_intervals);
|
||||
return intervals_ptr->getUsageInfo(std::chrono::system_clock::now());
|
||||
}
|
||||
|
||||
|
||||
QuotaUsageInfo::QuotaUsageInfo() : quota_id(UUID(UInt128(0)))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
QuotaUsageInfo::Interval::Interval()
|
||||
{
|
||||
boost::range::fill(used, 0);
|
||||
boost::range::fill(max, 0);
|
||||
}
|
||||
}
|
110
dbms/src/Access/QuotaContext.h
Normal file
110
dbms/src/Access/QuotaContext.h
Normal file
@ -0,0 +1,110 @@
|
||||
#pragma once
|
||||
|
||||
#include <Access/Quota.h>
|
||||
#include <Core/UUID.h>
|
||||
#include <Poco/Net/IPAddress.h>
|
||||
#include <ext/shared_ptr_helper.h>
|
||||
#include <boost/noncopyable.hpp>
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <memory>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct QuotaUsageInfo;
|
||||
|
||||
|
||||
/// Instances of `QuotaContext` are used to track resource consumption.
|
||||
class QuotaContext : public boost::noncopyable
|
||||
{
|
||||
public:
|
||||
using ResourceType = Quota::ResourceType;
|
||||
using ResourceAmount = Quota::ResourceAmount;
|
||||
|
||||
/// Default constructors makes an unlimited quota.
|
||||
QuotaContext();
|
||||
|
||||
~QuotaContext();
|
||||
|
||||
/// Tracks resource consumption. If the quota exceeded and `check_exceeded == true`, throws an exception.
|
||||
void used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded = true);
|
||||
void used(const std::pair<ResourceType, ResourceAmount> & resource, bool check_exceeded = true);
|
||||
void used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, bool check_exceeded = true);
|
||||
void used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, const std::pair<ResourceType, ResourceAmount> & resource3, bool check_exceeded = true);
|
||||
void used(const std::vector<std::pair<ResourceType, ResourceAmount>> & resources, bool check_exceeded = true);
|
||||
|
||||
/// Checks if the quota exceeded. If so, throws an exception.
|
||||
void checkExceeded();
|
||||
void checkExceeded(ResourceType resource_type);
|
||||
|
||||
/// Returns the information about this quota context.
|
||||
QuotaUsageInfo getUsageInfo() const;
|
||||
|
||||
private:
|
||||
friend class QuotaContextFactory;
|
||||
friend struct ext::shared_ptr_helper<QuotaContext>;
|
||||
|
||||
/// Instances of this class are created by QuotaContextFactory.
|
||||
QuotaContext(const String & user_name_, const Poco::Net::IPAddress & address_, const String & client_key_);
|
||||
|
||||
static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE;
|
||||
|
||||
struct Interval
|
||||
{
|
||||
mutable std::atomic<ResourceAmount> used[MAX_RESOURCE_TYPE];
|
||||
ResourceAmount max[MAX_RESOURCE_TYPE];
|
||||
std::chrono::seconds duration;
|
||||
bool randomize_interval;
|
||||
mutable std::atomic<std::chrono::system_clock::duration> end_of_interval;
|
||||
|
||||
Interval() {}
|
||||
Interval(const Interval & src) { *this = src; }
|
||||
Interval & operator =(const Interval & src);
|
||||
};
|
||||
|
||||
struct Intervals
|
||||
{
|
||||
std::vector<Interval> intervals;
|
||||
UUID quota_id;
|
||||
String quota_name;
|
||||
String quota_key;
|
||||
|
||||
QuotaUsageInfo getUsageInfo(std::chrono::system_clock::time_point current_time) const;
|
||||
};
|
||||
|
||||
struct Impl;
|
||||
|
||||
const String user_name;
|
||||
const Poco::Net::IPAddress address;
|
||||
const String client_key;
|
||||
std::shared_ptr<const Intervals> atomic_intervals; /// atomically changed by QuotaUsageManager
|
||||
};
|
||||
|
||||
using QuotaContextPtr = std::shared_ptr<QuotaContext>;
|
||||
|
||||
|
||||
/// The information about a quota context.
|
||||
struct QuotaUsageInfo
|
||||
{
|
||||
using ResourceType = Quota::ResourceType;
|
||||
using ResourceAmount = Quota::ResourceAmount;
|
||||
static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE;
|
||||
|
||||
struct Interval
|
||||
{
|
||||
ResourceAmount used[MAX_RESOURCE_TYPE];
|
||||
ResourceAmount max[MAX_RESOURCE_TYPE];
|
||||
std::chrono::seconds duration = std::chrono::seconds::zero();
|
||||
bool randomize_interval = false;
|
||||
std::chrono::system_clock::time_point end_of_interval;
|
||||
Interval();
|
||||
};
|
||||
|
||||
std::vector<Interval> intervals;
|
||||
UUID quota_id;
|
||||
String quota_name;
|
||||
String quota_key;
|
||||
QuotaUsageInfo();
|
||||
};
|
||||
}
|
299
dbms/src/Access/QuotaContextFactory.cpp
Normal file
299
dbms/src/Access/QuotaContextFactory.cpp
Normal file
@ -0,0 +1,299 @@
|
||||
#include <Access/QuotaContext.h>
|
||||
#include <Access/QuotaContextFactory.h>
|
||||
#include <Access/AccessControlManager.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/thread_local_rng.h>
|
||||
#include <ext/range.h>
|
||||
#include <boost/range/adaptor/map.hpp>
|
||||
#include <boost/range/algorithm/copy.hpp>
|
||||
#include <boost/range/algorithm/lower_bound.hpp>
|
||||
#include <boost/range/algorithm/stable_sort.hpp>
|
||||
#include <boost/range/algorithm_ext/erase.hpp>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int QUOTA_REQUIRES_CLIENT_KEY;
|
||||
}
|
||||
|
||||
|
||||
namespace
|
||||
{
|
||||
std::chrono::system_clock::duration randomDuration(std::chrono::seconds max)
|
||||
{
|
||||
auto count = std::chrono::duration_cast<std::chrono::system_clock::duration>(max).count();
|
||||
std::uniform_int_distribution<Int64> distribution{0, count - 1};
|
||||
return std::chrono::system_clock::duration(distribution(thread_local_rng));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void QuotaContextFactory::QuotaInfo::setQuota(const QuotaPtr & quota_, const UUID & quota_id_)
|
||||
{
|
||||
quota = quota_;
|
||||
quota_id = quota_id_;
|
||||
|
||||
boost::range::copy(quota->roles, std::inserter(roles, roles.end()));
|
||||
all_roles = quota->all_roles;
|
||||
boost::range::copy(quota->except_roles, std::inserter(except_roles, except_roles.end()));
|
||||
|
||||
rebuildAllIntervals();
|
||||
}
|
||||
|
||||
|
||||
bool QuotaContextFactory::QuotaInfo::canUseWithContext(const QuotaContext & context) const
|
||||
{
|
||||
if (roles.count(context.user_name))
|
||||
return true;
|
||||
|
||||
if (all_roles && !except_roles.count(context.user_name))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
String QuotaContextFactory::QuotaInfo::calculateKey(const QuotaContext & context) const
|
||||
{
|
||||
using KeyType = Quota::KeyType;
|
||||
switch (quota->key_type)
|
||||
{
|
||||
case KeyType::NONE:
|
||||
return "";
|
||||
case KeyType::USER_NAME:
|
||||
return context.user_name;
|
||||
case KeyType::IP_ADDRESS:
|
||||
return context.address.toString();
|
||||
case KeyType::CLIENT_KEY:
|
||||
{
|
||||
if (!context.client_key.empty())
|
||||
return context.client_key;
|
||||
throw Exception(
|
||||
"Quota " + quota->getName() + " (for user " + context.user_name + ") requires a client supplied key.",
|
||||
ErrorCodes::QUOTA_REQUIRES_CLIENT_KEY);
|
||||
}
|
||||
case KeyType::CLIENT_KEY_OR_USER_NAME:
|
||||
{
|
||||
if (!context.client_key.empty())
|
||||
return context.client_key;
|
||||
return context.user_name;
|
||||
}
|
||||
case KeyType::CLIENT_KEY_OR_IP_ADDRESS:
|
||||
{
|
||||
if (!context.client_key.empty())
|
||||
return context.client_key;
|
||||
return context.address.toString();
|
||||
}
|
||||
}
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
|
||||
std::shared_ptr<const QuotaContext::Intervals> QuotaContextFactory::QuotaInfo::getOrBuildIntervals(const String & key)
|
||||
{
|
||||
auto it = key_to_intervals.find(key);
|
||||
if (it != key_to_intervals.end())
|
||||
return it->second;
|
||||
return rebuildIntervals(key);
|
||||
}
|
||||
|
||||
|
||||
void QuotaContextFactory::QuotaInfo::rebuildAllIntervals()
|
||||
{
|
||||
for (const String & key : key_to_intervals | boost::adaptors::map_keys)
|
||||
rebuildIntervals(key);
|
||||
}
|
||||
|
||||
|
||||
std::shared_ptr<const QuotaContext::Intervals> QuotaContextFactory::QuotaInfo::rebuildIntervals(const String & key)
|
||||
{
|
||||
auto new_intervals = std::make_shared<Intervals>();
|
||||
new_intervals->quota_name = quota->getName();
|
||||
new_intervals->quota_id = quota_id;
|
||||
new_intervals->quota_key = key;
|
||||
auto & intervals = new_intervals->intervals;
|
||||
intervals.reserve(quota->all_limits.size());
|
||||
constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE;
|
||||
for (const auto & limits : quota->all_limits)
|
||||
{
|
||||
intervals.emplace_back();
|
||||
auto & interval = intervals.back();
|
||||
interval.duration = limits.duration;
|
||||
std::chrono::system_clock::time_point end_of_interval{};
|
||||
interval.randomize_interval = limits.randomize_interval;
|
||||
if (limits.randomize_interval)
|
||||
end_of_interval += randomDuration(limits.duration);
|
||||
interval.end_of_interval = end_of_interval.time_since_epoch();
|
||||
for (auto resource_type : ext::range(MAX_RESOURCE_TYPE))
|
||||
{
|
||||
interval.max[resource_type] = limits.max[resource_type];
|
||||
interval.used[resource_type] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Order intervals by durations from largest to smallest.
|
||||
/// To report first about largest interval on what quota was exceeded.
|
||||
struct GreaterByDuration
|
||||
{
|
||||
bool operator()(const Interval & lhs, const Interval & rhs) const { return lhs.duration > rhs.duration; }
|
||||
};
|
||||
boost::range::stable_sort(intervals, GreaterByDuration{});
|
||||
|
||||
auto it = key_to_intervals.find(key);
|
||||
if (it == key_to_intervals.end())
|
||||
{
|
||||
/// Just put new intervals into the map.
|
||||
key_to_intervals.try_emplace(key, new_intervals);
|
||||
}
|
||||
else
|
||||
{
|
||||
/// We need to keep usage information from the old intervals.
|
||||
const auto & old_intervals = it->second->intervals;
|
||||
for (auto & new_interval : new_intervals->intervals)
|
||||
{
|
||||
/// Check if an interval with the same duration is already in use.
|
||||
auto lower_bound = boost::range::lower_bound(old_intervals, new_interval, GreaterByDuration{});
|
||||
if ((lower_bound == old_intervals.end()) || (lower_bound->duration != new_interval.duration))
|
||||
continue;
|
||||
|
||||
/// Found an interval with the same duration, we need to copy its usage information to `result`.
|
||||
auto & current_interval = *lower_bound;
|
||||
for (auto resource_type : ext::range(MAX_RESOURCE_TYPE))
|
||||
{
|
||||
new_interval.used[resource_type].store(current_interval.used[resource_type].load());
|
||||
new_interval.end_of_interval.store(current_interval.end_of_interval.load());
|
||||
}
|
||||
}
|
||||
it->second = new_intervals;
|
||||
}
|
||||
|
||||
return new_intervals;
|
||||
}
|
||||
|
||||
|
||||
QuotaContextFactory::QuotaContextFactory(const AccessControlManager & access_control_manager_)
|
||||
: access_control_manager(access_control_manager_)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
QuotaContextFactory::~QuotaContextFactory()
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
std::shared_ptr<QuotaContext> QuotaContextFactory::createContext(const String & user_name, const Poco::Net::IPAddress & address, const String & client_key)
|
||||
{
|
||||
std::lock_guard lock{mutex};
|
||||
ensureAllQuotasRead();
|
||||
auto context = ext::shared_ptr_helper<QuotaContext>::create(user_name, address, client_key);
|
||||
contexts.push_back(context);
|
||||
chooseQuotaForContext(context);
|
||||
return context;
|
||||
}
|
||||
|
||||
|
||||
void QuotaContextFactory::ensureAllQuotasRead()
|
||||
{
|
||||
/// `mutex` is already locked.
|
||||
if (all_quotas_read)
|
||||
return;
|
||||
all_quotas_read = true;
|
||||
|
||||
subscription = access_control_manager.subscribeForChanges<Quota>(
|
||||
[&](const UUID & id, const AccessEntityPtr & entity)
|
||||
{
|
||||
if (entity)
|
||||
quotaAddedOrChanged(id, typeid_cast<QuotaPtr>(entity));
|
||||
else
|
||||
quotaRemoved(id);
|
||||
});
|
||||
|
||||
for (const UUID & quota_id : access_control_manager.findAll<Quota>())
|
||||
{
|
||||
auto quota = access_control_manager.tryRead<Quota>(quota_id);
|
||||
if (quota)
|
||||
all_quotas.emplace(quota_id, QuotaInfo(quota, quota_id));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void QuotaContextFactory::quotaAddedOrChanged(const UUID & quota_id, const std::shared_ptr<const Quota> & new_quota)
|
||||
{
|
||||
std::lock_guard lock{mutex};
|
||||
auto it = all_quotas.find(quota_id);
|
||||
if (it == all_quotas.end())
|
||||
{
|
||||
it = all_quotas.emplace(quota_id, QuotaInfo(new_quota, quota_id)).first;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (it->second.quota == new_quota)
|
||||
return;
|
||||
}
|
||||
|
||||
auto & info = it->second;
|
||||
info.setQuota(new_quota, quota_id);
|
||||
chooseQuotaForAllContexts();
|
||||
}
|
||||
|
||||
|
||||
void QuotaContextFactory::quotaRemoved(const UUID & quota_id)
|
||||
{
|
||||
std::lock_guard lock{mutex};
|
||||
all_quotas.erase(quota_id);
|
||||
chooseQuotaForAllContexts();
|
||||
}
|
||||
|
||||
|
||||
void QuotaContextFactory::chooseQuotaForAllContexts()
|
||||
{
|
||||
/// `mutex` is already locked.
|
||||
boost::range::remove_erase_if(
|
||||
contexts,
|
||||
[&](const std::weak_ptr<QuotaContext> & weak)
|
||||
{
|
||||
auto context = weak.lock();
|
||||
if (!context)
|
||||
return true; // remove from the `contexts` list.
|
||||
chooseQuotaForContext(context);
|
||||
return false; // keep in the `contexts` list.
|
||||
});
|
||||
}
|
||||
|
||||
void QuotaContextFactory::chooseQuotaForContext(const std::shared_ptr<QuotaContext> & context)
|
||||
{
|
||||
/// `mutex` is already locked.
|
||||
std::shared_ptr<const Intervals> intervals;
|
||||
for (auto & info : all_quotas | boost::adaptors::map_values)
|
||||
{
|
||||
if (info.canUseWithContext(*context))
|
||||
{
|
||||
String key = info.calculateKey(*context);
|
||||
intervals = info.getOrBuildIntervals(key);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!intervals)
|
||||
intervals = std::make_shared<Intervals>(); /// No quota == no limits.
|
||||
|
||||
std::atomic_store(&context->atomic_intervals, intervals);
|
||||
}
|
||||
|
||||
|
||||
std::vector<QuotaUsageInfo> QuotaContextFactory::getUsageInfo() const
|
||||
{
|
||||
std::lock_guard lock{mutex};
|
||||
std::vector<QuotaUsageInfo> all_infos;
|
||||
auto current_time = std::chrono::system_clock::now();
|
||||
for (const auto & info : all_quotas | boost::adaptors::map_values)
|
||||
{
|
||||
for (const auto & intervals : info.key_to_intervals | boost::adaptors::map_values)
|
||||
all_infos.push_back(intervals->getUsageInfo(current_time));
|
||||
}
|
||||
return all_infos;
|
||||
}
|
||||
}
|
62
dbms/src/Access/QuotaContextFactory.h
Normal file
62
dbms/src/Access/QuotaContextFactory.h
Normal file
@ -0,0 +1,62 @@
|
||||
#pragma once
|
||||
|
||||
#include <Access/QuotaContext.h>
|
||||
#include <Access/IAccessStorage.h>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
class AccessControlManager;
|
||||
|
||||
|
||||
/// Stores information how much amount of resources have been consumed and how much are left.
|
||||
class QuotaContextFactory
|
||||
{
|
||||
public:
|
||||
QuotaContextFactory(const AccessControlManager & access_control_manager_);
|
||||
~QuotaContextFactory();
|
||||
|
||||
QuotaContextPtr createContext(const String & user_name, const Poco::Net::IPAddress & address, const String & client_key);
|
||||
std::vector<QuotaUsageInfo> getUsageInfo() const;
|
||||
|
||||
private:
|
||||
using Interval = QuotaContext::Interval;
|
||||
using Intervals = QuotaContext::Intervals;
|
||||
|
||||
struct QuotaInfo
|
||||
{
|
||||
QuotaInfo(const QuotaPtr & quota_, const UUID & quota_id_) { setQuota(quota_, quota_id_); }
|
||||
void setQuota(const QuotaPtr & quota_, const UUID & quota_id_);
|
||||
|
||||
bool canUseWithContext(const QuotaContext & context) const;
|
||||
String calculateKey(const QuotaContext & context) const;
|
||||
std::shared_ptr<const Intervals> getOrBuildIntervals(const String & key);
|
||||
std::shared_ptr<const Intervals> rebuildIntervals(const String & key);
|
||||
void rebuildAllIntervals();
|
||||
|
||||
QuotaPtr quota;
|
||||
UUID quota_id;
|
||||
std::unordered_set<String> roles;
|
||||
bool all_roles = false;
|
||||
std::unordered_set<String> except_roles;
|
||||
std::unordered_map<String /* quota key */, std::shared_ptr<const Intervals>> key_to_intervals;
|
||||
};
|
||||
|
||||
void ensureAllQuotasRead();
|
||||
void quotaAddedOrChanged(const UUID & quota_id, const std::shared_ptr<const Quota> & new_quota);
|
||||
void quotaRemoved(const UUID & quota_id);
|
||||
void chooseQuotaForAllContexts();
|
||||
void chooseQuotaForContext(const std::shared_ptr<QuotaContext> & context);
|
||||
|
||||
const AccessControlManager & access_control_manager;
|
||||
mutable std::mutex mutex;
|
||||
std::unordered_map<UUID /* quota id */, QuotaInfo> all_quotas;
|
||||
bool all_quotas_read = false;
|
||||
IAccessStorage::SubscriptionPtr subscription;
|
||||
std::vector<std::weak_ptr<QuotaContext>> contexts;
|
||||
};
|
||||
}
|
207
dbms/src/Access/UsersConfigAccessStorage.cpp
Normal file
207
dbms/src/Access/UsersConfigAccessStorage.cpp
Normal file
@ -0,0 +1,207 @@
|
||||
#include <Access/UsersConfigAccessStorage.h>
|
||||
#include <Access/Quota.h>
|
||||
#include <Common/StringUtils/StringUtils.h>
|
||||
#include <Common/quoteString.h>
|
||||
#include <Poco/Util/AbstractConfiguration.h>
|
||||
#include <Poco/MD5Engine.h>
|
||||
#include <cstring>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace
|
||||
{
|
||||
char getTypeChar(std::type_index type)
|
||||
{
|
||||
if (type == typeid(Quota))
|
||||
return 'Q';
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
UUID generateID(std::type_index type, const String & name)
|
||||
{
|
||||
Poco::MD5Engine md5;
|
||||
md5.update(name);
|
||||
char type_storage_chars[] = " USRSXML";
|
||||
type_storage_chars[0] = getTypeChar(type);
|
||||
md5.update(type_storage_chars, strlen(type_storage_chars));
|
||||
UUID result;
|
||||
memcpy(&result, md5.digest().data(), md5.digestLength());
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
UUID generateID(const IAccessEntity & entity) { return generateID(entity.getType(), entity.getFullName()); }
|
||||
|
||||
QuotaPtr parseQuota(const Poco::Util::AbstractConfiguration & config, const String & quota_name, const Strings & user_names)
|
||||
{
|
||||
auto quota = std::make_shared<Quota>();
|
||||
quota->setName(quota_name);
|
||||
|
||||
using KeyType = Quota::KeyType;
|
||||
String quota_config = "quotas." + quota_name;
|
||||
if (config.has(quota_config + ".keyed_by_ip"))
|
||||
quota->key_type = KeyType::IP_ADDRESS;
|
||||
else if (config.has(quota_config + ".keyed"))
|
||||
quota->key_type = KeyType::CLIENT_KEY_OR_USER_NAME;
|
||||
else
|
||||
quota->key_type = KeyType::USER_NAME;
|
||||
|
||||
Poco::Util::AbstractConfiguration::Keys interval_keys;
|
||||
config.keys(quota_config, interval_keys);
|
||||
|
||||
for (const String & interval_key : interval_keys)
|
||||
{
|
||||
if (!startsWith(interval_key, "interval"))
|
||||
continue;
|
||||
|
||||
String interval_config = quota_config + "." + interval_key;
|
||||
std::chrono::seconds duration{config.getInt(interval_config + ".duration", 0)};
|
||||
if (duration.count() <= 0) /// Skip quotas with non-positive duration.
|
||||
continue;
|
||||
|
||||
quota->all_limits.emplace_back();
|
||||
auto & limits = quota->all_limits.back();
|
||||
limits.duration = duration;
|
||||
limits.randomize_interval = config.getBool(interval_config + ".randomize", false);
|
||||
|
||||
using ResourceType = Quota::ResourceType;
|
||||
limits.max[ResourceType::QUERIES] = config.getUInt64(interval_config + ".queries", Quota::UNLIMITED);
|
||||
limits.max[ResourceType::ERRORS] = config.getUInt64(interval_config + ".errors", Quota::UNLIMITED);
|
||||
limits.max[ResourceType::RESULT_ROWS] = config.getUInt64(interval_config + ".result_rows", Quota::UNLIMITED);
|
||||
limits.max[ResourceType::RESULT_BYTES] = config.getUInt64(interval_config + ".result_bytes", Quota::UNLIMITED);
|
||||
limits.max[ResourceType::READ_ROWS] = config.getUInt64(interval_config + ".read_rows", Quota::UNLIMITED);
|
||||
limits.max[ResourceType::READ_BYTES] = config.getUInt64(interval_config + ".read_bytes", Quota::UNLIMITED);
|
||||
limits.max[ResourceType::EXECUTION_TIME] = Quota::secondsToExecutionTime(config.getUInt64(interval_config + ".execution_time", Quota::UNLIMITED));
|
||||
}
|
||||
|
||||
quota->roles = user_names;
|
||||
|
||||
return quota;
|
||||
}
|
||||
|
||||
|
||||
std::vector<AccessEntityPtr> parseQuotas(const Poco::Util::AbstractConfiguration & config, Poco::Logger * log)
|
||||
{
|
||||
Poco::Util::AbstractConfiguration::Keys user_names;
|
||||
config.keys("users", user_names);
|
||||
std::unordered_map<String, Strings> quota_to_user_names;
|
||||
for (const auto & user_name : user_names)
|
||||
{
|
||||
if (config.has("users." + user_name + ".quota"))
|
||||
quota_to_user_names[config.getString("users." + user_name + ".quota")].push_back(user_name);
|
||||
}
|
||||
|
||||
Poco::Util::AbstractConfiguration::Keys quota_names;
|
||||
config.keys("quotas", quota_names);
|
||||
std::vector<AccessEntityPtr> quotas;
|
||||
quotas.reserve(quota_names.size());
|
||||
for (const auto & quota_name : quota_names)
|
||||
{
|
||||
try
|
||||
{
|
||||
auto it = quota_to_user_names.find(quota_name);
|
||||
const Strings quota_users = (it != quota_to_user_names.end()) ? std::move(it->second) : Strings{};
|
||||
quotas.push_back(parseQuota(config, quota_name, quota_users));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException(log, "Could not parse quota " + backQuote(quota_name));
|
||||
}
|
||||
}
|
||||
return quotas;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
UsersConfigAccessStorage::UsersConfigAccessStorage() : IAccessStorage("users.xml")
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
UsersConfigAccessStorage::~UsersConfigAccessStorage() {}
|
||||
|
||||
|
||||
void UsersConfigAccessStorage::loadFromConfig(const Poco::Util::AbstractConfiguration & config)
|
||||
{
|
||||
std::vector<std::pair<UUID, AccessEntityPtr>> all_entities;
|
||||
for (const auto & entity : parseQuotas(config, getLogger()))
|
||||
all_entities.emplace_back(generateID(*entity), entity);
|
||||
memory_storage.setAll(all_entities);
|
||||
}
|
||||
|
||||
|
||||
std::optional<UUID> UsersConfigAccessStorage::findImpl(std::type_index type, const String & name) const
|
||||
{
|
||||
return memory_storage.find(type, name);
|
||||
}
|
||||
|
||||
|
||||
std::vector<UUID> UsersConfigAccessStorage::findAllImpl(std::type_index type) const
|
||||
{
|
||||
return memory_storage.findAll(type);
|
||||
}
|
||||
|
||||
|
||||
bool UsersConfigAccessStorage::existsImpl(const UUID & id) const
|
||||
{
|
||||
return memory_storage.exists(id);
|
||||
}
|
||||
|
||||
|
||||
AccessEntityPtr UsersConfigAccessStorage::readImpl(const UUID & id) const
|
||||
{
|
||||
return memory_storage.read(id);
|
||||
}
|
||||
|
||||
|
||||
String UsersConfigAccessStorage::readNameImpl(const UUID & id) const
|
||||
{
|
||||
return memory_storage.readName(id);
|
||||
}
|
||||
|
||||
|
||||
UUID UsersConfigAccessStorage::insertImpl(const AccessEntityPtr & entity, bool)
|
||||
{
|
||||
throwReadonlyCannotInsert(entity->getType(), entity->getFullName());
|
||||
}
|
||||
|
||||
|
||||
void UsersConfigAccessStorage::removeImpl(const UUID & id)
|
||||
{
|
||||
auto entity = read(id);
|
||||
throwReadonlyCannotRemove(entity->getType(), entity->getFullName());
|
||||
}
|
||||
|
||||
|
||||
void UsersConfigAccessStorage::updateImpl(const UUID & id, const UpdateFunc &)
|
||||
{
|
||||
auto entity = read(id);
|
||||
throwReadonlyCannotUpdate(entity->getType(), entity->getFullName());
|
||||
}
|
||||
|
||||
|
||||
IAccessStorage::SubscriptionPtr UsersConfigAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
|
||||
{
|
||||
return memory_storage.subscribeForChanges(id, handler);
|
||||
}
|
||||
|
||||
|
||||
IAccessStorage::SubscriptionPtr UsersConfigAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const
|
||||
{
|
||||
return memory_storage.subscribeForChanges(type, handler);
|
||||
}
|
||||
|
||||
|
||||
bool UsersConfigAccessStorage::hasSubscriptionImpl(const UUID & id) const
|
||||
{
|
||||
return memory_storage.hasSubscription(id);
|
||||
}
|
||||
|
||||
|
||||
bool UsersConfigAccessStorage::hasSubscriptionImpl(std::type_index type) const
|
||||
{
|
||||
return memory_storage.hasSubscription(type);
|
||||
}
|
||||
}
|
42
dbms/src/Access/UsersConfigAccessStorage.h
Normal file
42
dbms/src/Access/UsersConfigAccessStorage.h
Normal file
@ -0,0 +1,42 @@
|
||||
#pragma once
|
||||
|
||||
#include <Access/MemoryAccessStorage.h>
|
||||
|
||||
|
||||
namespace Poco
|
||||
{
|
||||
namespace Util
|
||||
{
|
||||
class AbstractConfiguration;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
/// Implementation of IAccessStorage which loads all from users.xml periodically.
|
||||
class UsersConfigAccessStorage : public IAccessStorage
|
||||
{
|
||||
public:
|
||||
UsersConfigAccessStorage();
|
||||
~UsersConfigAccessStorage() override;
|
||||
|
||||
void loadFromConfig(const Poco::Util::AbstractConfiguration & config);
|
||||
|
||||
private:
|
||||
std::optional<UUID> findImpl(std::type_index type, const String & name) const override;
|
||||
std::vector<UUID> findAllImpl(std::type_index type) const override;
|
||||
bool existsImpl(const UUID & id) const override;
|
||||
AccessEntityPtr readImpl(const UUID & id) const override;
|
||||
String readNameImpl(const UUID & id) const override;
|
||||
UUID insertImpl(const AccessEntityPtr & entity, bool replace_if_exists) override;
|
||||
void removeImpl(const UUID & id) override;
|
||||
void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
|
||||
SubscriptionPtr subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
|
||||
SubscriptionPtr subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override;
|
||||
bool hasSubscriptionImpl(const UUID & id) const override;
|
||||
bool hasSubscriptionImpl(std::type_index type) const override;
|
||||
|
||||
MemoryAccessStorage memory_storage;
|
||||
};
|
||||
}
|
@ -439,6 +439,10 @@ void Connection::sendQuery(
|
||||
|
||||
void Connection::sendCancel()
|
||||
{
|
||||
/// If we already disconnected.
|
||||
if (!out)
|
||||
return;
|
||||
|
||||
//LOG_TRACE(log_wrapper.get(), "Sending cancel");
|
||||
|
||||
writeVarUInt(Protocol::Client::Cancel, *out);
|
||||
|
@ -4,6 +4,9 @@
|
||||
|
||||
#if USE_ICU
|
||||
#include <unicode/ucol.h>
|
||||
#include <unicode/unistr.h>
|
||||
#include <unicode/locid.h>
|
||||
#include <unicode/ucnv.h>
|
||||
#else
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic ignored "-Wunused-private-field"
|
||||
@ -14,6 +17,7 @@
|
||||
#include <Common/Exception.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Poco/String.h>
|
||||
#include <algorithm>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -26,16 +30,81 @@ namespace DB
|
||||
}
|
||||
}
|
||||
|
||||
Collator::Collator(const std::string & locale_) : locale(Poco::toLower(locale_))
|
||||
|
||||
AvailableCollationLocales::AvailableCollationLocales()
|
||||
{
|
||||
#if USE_ICU
|
||||
static const size_t MAX_LANG_LENGTH = 128;
|
||||
size_t available_locales_count = ucol_countAvailable();
|
||||
for (size_t i = 0; i < available_locales_count; ++i)
|
||||
{
|
||||
std::string locale_name = ucol_getAvailable(i);
|
||||
UChar lang_buffer[MAX_LANG_LENGTH];
|
||||
char normal_buf[MAX_LANG_LENGTH];
|
||||
UErrorCode status = U_ZERO_ERROR;
|
||||
|
||||
/// All names will be in English language
|
||||
size_t lang_length = uloc_getDisplayLanguage(
|
||||
locale_name.c_str(), "en", lang_buffer, MAX_LANG_LENGTH, &status);
|
||||
std::optional<std::string> lang;
|
||||
|
||||
if (!U_FAILURE(status))
|
||||
{
|
||||
/// Convert language name from UChar array to normal char array.
|
||||
/// We use English language for name, so all UChar's length is equal to sizeof(char)
|
||||
u_UCharsToChars(lang_buffer, normal_buf, lang_length);
|
||||
lang.emplace(std::string(normal_buf, lang_length));
|
||||
}
|
||||
|
||||
locales_map.emplace(Poco::toLower(locale_name), LocaleAndLanguage{locale_name, lang});
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
const AvailableCollationLocales & AvailableCollationLocales::instance()
|
||||
{
|
||||
static AvailableCollationLocales instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
AvailableCollationLocales::LocalesVector AvailableCollationLocales::getAvailableCollations() const
|
||||
{
|
||||
LocalesVector result;
|
||||
for (const auto & name_and_locale : locales_map)
|
||||
result.push_back(name_and_locale.second);
|
||||
|
||||
auto comparator = [] (const LocaleAndLanguage & f, const LocaleAndLanguage & s)
|
||||
{
|
||||
return f.locale_name < s.locale_name;
|
||||
};
|
||||
std::sort(result.begin(), result.end(), comparator);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool AvailableCollationLocales::isCollationSupported(const std::string & locale_name) const
|
||||
{
|
||||
/// We support locale names in any case, so we have to convert all to lower case
|
||||
return locales_map.count(Poco::toLower(locale_name));
|
||||
}
|
||||
|
||||
Collator::Collator(const std::string & locale_)
|
||||
: locale(Poco::toLower(locale_))
|
||||
{
|
||||
#if USE_ICU
|
||||
/// We check it here, because ucol_open will fallback to default locale for
|
||||
/// almost all random names.
|
||||
if (!AvailableCollationLocales::instance().isCollationSupported(locale))
|
||||
throw DB::Exception("Unsupported collation locale: " + locale, DB::ErrorCodes::UNSUPPORTED_COLLATION_LOCALE);
|
||||
|
||||
UErrorCode status = U_ZERO_ERROR;
|
||||
|
||||
collator = ucol_open(locale.c_str(), &status);
|
||||
if (status != U_ZERO_ERROR)
|
||||
if (U_FAILURE(status))
|
||||
{
|
||||
ucol_close(collator);
|
||||
throw DB::Exception("Unsupported collation locale: " + locale, DB::ErrorCodes::UNSUPPORTED_COLLATION_LOCALE);
|
||||
throw DB::Exception("Failed to open locale: " + locale + " with error: " + u_errorName(status), DB::ErrorCodes::UNSUPPORTED_COLLATION_LOCALE);
|
||||
}
|
||||
#else
|
||||
throw DB::Exception("Collations support is disabled, because ClickHouse was built without ICU library", DB::ErrorCodes::SUPPORT_IS_DISABLED);
|
||||
@ -60,8 +129,8 @@ int Collator::compare(const char * str1, size_t length1, const char * str2, size
|
||||
UErrorCode status = U_ZERO_ERROR;
|
||||
UCollationResult compare_result = ucol_strcollIter(collator, &iter1, &iter2, &status);
|
||||
|
||||
if (status != U_ZERO_ERROR)
|
||||
throw DB::Exception("ICU collation comparison failed with error code: " + DB::toString<int>(status),
|
||||
if (U_FAILURE(status))
|
||||
throw DB::Exception("ICU collation comparison failed with error code: " + std::string(u_errorName(status)),
|
||||
DB::ErrorCodes::COLLATION_COMPARISON_FAILED);
|
||||
|
||||
/** Values of enum UCollationResult are equals to what exactly we need:
|
||||
@ -83,14 +152,3 @@ const std::string & Collator::getLocale() const
|
||||
{
|
||||
return locale;
|
||||
}
|
||||
|
||||
std::vector<std::string> Collator::getAvailableCollations()
|
||||
{
|
||||
std::vector<std::string> result;
|
||||
#if USE_ICU
|
||||
size_t available_locales_count = ucol_countAvailable();
|
||||
for (size_t i = 0; i < available_locales_count; ++i)
|
||||
result.push_back(ucol_getAvailable(i));
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
@ -3,9 +3,39 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <boost/noncopyable.hpp>
|
||||
#include <unordered_map>
|
||||
|
||||
|
||||
struct UCollator;
|
||||
|
||||
/// Class represents available locales for collations.
|
||||
class AvailableCollationLocales : private boost::noncopyable
|
||||
{
|
||||
public:
|
||||
|
||||
struct LocaleAndLanguage
|
||||
{
|
||||
std::string locale_name; /// ISO locale code
|
||||
std::optional<std::string> language; /// full language name in English
|
||||
};
|
||||
|
||||
using AvailableLocalesMap = std::unordered_map<std::string, LocaleAndLanguage>;
|
||||
using LocalesVector = std::vector<LocaleAndLanguage>;
|
||||
|
||||
static const AvailableCollationLocales & instance();
|
||||
|
||||
/// Get all collations with names in sorted order
|
||||
LocalesVector getAvailableCollations() const;
|
||||
|
||||
/// Check that collation is supported
|
||||
bool isCollationSupported(const std::string & locale_name) const;
|
||||
|
||||
private:
|
||||
AvailableCollationLocales();
|
||||
private:
|
||||
AvailableLocalesMap locales_map;
|
||||
};
|
||||
|
||||
class Collator : private boost::noncopyable
|
||||
{
|
||||
public:
|
||||
@ -15,10 +45,8 @@ public:
|
||||
int compare(const char * str1, size_t length1, const char * str2, size_t length2) const;
|
||||
|
||||
const std::string & getLocale() const;
|
||||
|
||||
static std::vector<std::string> getAvailableCollations();
|
||||
|
||||
private:
|
||||
|
||||
std::string locale;
|
||||
UCollator * collator;
|
||||
};
|
||||
|
@ -105,6 +105,11 @@ public:
|
||||
return data->getFloat64(0);
|
||||
}
|
||||
|
||||
Float32 getFloat32(size_t) const override
|
||||
{
|
||||
return data->getFloat32(0);
|
||||
}
|
||||
|
||||
bool isNullAt(size_t) const override
|
||||
{
|
||||
return data->isNullAt(0);
|
||||
@ -219,6 +224,7 @@ public:
|
||||
|
||||
Field getField() const { return getDataColumn()[0]; }
|
||||
|
||||
/// The constant value. It is valid even if the size of the column is 0.
|
||||
template <typename T>
|
||||
T getValue() const { return getField().safeGet<NearestFieldType<T>>(); }
|
||||
};
|
||||
|
@ -96,6 +96,7 @@ public:
|
||||
void insertFrom(const IColumn & src, size_t n) override { data.push_back(static_cast<const Self &>(src).getData()[n]); }
|
||||
void insertData(const char * pos, size_t /*length*/) override;
|
||||
void insertDefault() override { data.push_back(T()); }
|
||||
virtual void insertManyDefaults(size_t length) override { data.resize_fill(data.size() + length); }
|
||||
void insert(const Field & x) override { data.push_back(DB::get<NearestFieldType<T>>(x)); }
|
||||
void insertRangeFrom(const IColumn & src, size_t start, size_t length) override;
|
||||
|
||||
@ -144,7 +145,7 @@ public:
|
||||
}
|
||||
|
||||
|
||||
void insert(const T value) { data.push_back(value); }
|
||||
void insertValue(const T value) { data.push_back(value); }
|
||||
Container & getData() { return data; }
|
||||
const Container & getData() const { return data; }
|
||||
const T & getElement(size_t n) const { return data[n]; }
|
||||
|
@ -92,6 +92,11 @@ public:
|
||||
chars.resize_fill(chars.size() + n);
|
||||
}
|
||||
|
||||
virtual void insertManyDefaults(size_t length) override
|
||||
{
|
||||
chars.resize_fill(chars.size() + n * length);
|
||||
}
|
||||
|
||||
void popBack(size_t elems) override
|
||||
{
|
||||
chars.resize_assume_reserved(chars.size() - n * elems);
|
||||
|
@ -59,6 +59,7 @@ public:
|
||||
UInt64 getUInt(size_t n) const override { return getDictionary().getUInt(getIndexes().getUInt(n)); }
|
||||
Int64 getInt(size_t n) const override { return getDictionary().getInt(getIndexes().getUInt(n)); }
|
||||
Float64 getFloat64(size_t n) const override { return getDictionary().getInt(getIndexes().getFloat64(n)); }
|
||||
Float32 getFloat32(size_t n) const override { return getDictionary().getInt(getIndexes().getFloat32(n)); }
|
||||
bool getBool(size_t n) const override { return getDictionary().getInt(getIndexes().getBool(n)); }
|
||||
bool isNullAt(size_t n) const override { return getDictionary().isNullAt(getIndexes().getUInt(n)); }
|
||||
ColumnPtr cut(size_t start, size_t length) const override
|
||||
|
@ -205,6 +205,13 @@ public:
|
||||
offsets.push_back(offsets.back() + 1);
|
||||
}
|
||||
|
||||
virtual void insertManyDefaults(size_t length) override
|
||||
{
|
||||
chars.resize_fill(chars.size() + length);
|
||||
for (size_t i = 0; i < length; ++i)
|
||||
offsets.push_back(offsets.back() + 1);
|
||||
}
|
||||
|
||||
int compareAt(size_t n, size_t m, const IColumn & rhs_, int /*nan_direction_hint*/) const override
|
||||
{
|
||||
const ColumnString & rhs = assert_cast<const ColumnString &>(rhs_);
|
||||
|
@ -66,6 +66,7 @@ public:
|
||||
UInt64 getUInt(size_t n) const override { return getNestedColumn()->getUInt(n); }
|
||||
Int64 getInt(size_t n) const override { return getNestedColumn()->getInt(n); }
|
||||
Float64 getFloat64(size_t n) const override { return getNestedColumn()->getFloat64(n); }
|
||||
Float32 getFloat32(size_t n) const override { return getNestedColumn()->getFloat32(n); }
|
||||
bool getBool(size_t n) const override { return getNestedColumn()->getBool(n); }
|
||||
bool isNullAt(size_t n) const override { return is_nullable && n == getNullValueIndex(); }
|
||||
StringRef serializeValueIntoArena(size_t n, Arena & arena, char const *& begin) const override;
|
||||
|
@ -222,6 +222,12 @@ Float64 ColumnVector<T>::getFloat64(size_t n) const
|
||||
return static_cast<Float64>(data[n]);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Float32 ColumnVector<T>::getFloat32(size_t n) const
|
||||
{
|
||||
return static_cast<Float32>(data[n]);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ColumnVector<T>::insertRangeFrom(const IColumn & src, size_t start, size_t length)
|
||||
{
|
||||
|
@ -144,6 +144,11 @@ public:
|
||||
data.push_back(T());
|
||||
}
|
||||
|
||||
virtual void insertManyDefaults(size_t length) override
|
||||
{
|
||||
data.resize_fill(data.size() + length, T());
|
||||
}
|
||||
|
||||
void popBack(size_t n) override
|
||||
{
|
||||
data.resize_assume_reserved(data.size() - n);
|
||||
@ -205,6 +210,7 @@ public:
|
||||
UInt64 get64(size_t n) const override;
|
||||
|
||||
Float64 getFloat64(size_t n) const override;
|
||||
Float32 getFloat32(size_t n) const override;
|
||||
|
||||
UInt64 getUInt(size_t n) const override
|
||||
{
|
||||
|
@ -100,6 +100,11 @@ public:
|
||||
throw Exception("Method getFloat64 is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
virtual Float32 getFloat32(size_t /*n*/) const
|
||||
{
|
||||
throw Exception("Method getFloat32 is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
/** If column is numeric, return value of n-th element, casted to UInt64.
|
||||
* For NULL values of Nullable column it is allowed to return arbitrary value.
|
||||
* Otherwise throw an exception.
|
||||
|
@ -34,6 +34,7 @@ namespace ErrorCodes
|
||||
extern const int CANNOT_STATVFS;
|
||||
extern const int NOT_ENOUGH_SPACE;
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
extern const int NO_SUCH_DATA_PART;
|
||||
extern const int SYSTEM_ERROR;
|
||||
extern const int UNKNOWN_ELEMENT_IN_CONFIG;
|
||||
extern const int EXCESSIVE_ELEMENT_IN_CONFIG;
|
||||
|
@ -464,6 +464,15 @@ namespace ErrorCodes
|
||||
extern const int CANNOT_GET_CREATE_DICTIONARY_QUERY = 487;
|
||||
extern const int UNKNOWN_DICTIONARY = 488;
|
||||
extern const int INCORRECT_DICTIONARY_DEFINITION = 489;
|
||||
extern const int CANNOT_FORMAT_DATETIME = 490;
|
||||
extern const int UNACCEPTABLE_URL = 491;
|
||||
extern const int ACCESS_ENTITY_NOT_FOUND = 492;
|
||||
extern const int ACCESS_ENTITY_ALREADY_EXISTS = 493;
|
||||
extern const int ACCESS_ENTITY_FOUND_DUPLICATES = 494;
|
||||
extern const int ACCESS_ENTITY_STORAGE_READONLY = 495;
|
||||
extern const int QUOTA_REQUIRES_CLIENT_KEY = 496;
|
||||
extern const int NOT_ENOUGH_PRIVILEGES = 497;
|
||||
extern const int LIMIT_BY_WITH_TIES_IS_NOT_SUPPORTED = 498;
|
||||
|
||||
extern const int KEEPER_EXCEPTION = 999;
|
||||
extern const int POCO_EXCEPTION = 1000;
|
||||
|
@ -34,97 +34,23 @@ struct StaticVisitor
|
||||
|
||||
/// F is template parameter, to allow universal reference for field, that is useful for const and non-const values.
|
||||
template <typename Visitor, typename F>
|
||||
typename std::decay_t<Visitor>::ResultType applyVisitor(Visitor && visitor, F && field)
|
||||
auto applyVisitor(Visitor && visitor, F && field)
|
||||
{
|
||||
switch (field.getType())
|
||||
{
|
||||
case Field::Types::Null: return visitor(field.template get<Null>());
|
||||
case Field::Types::UInt64: return visitor(field.template get<UInt64>());
|
||||
case Field::Types::UInt128: return visitor(field.template get<UInt128>());
|
||||
case Field::Types::Int64: return visitor(field.template get<Int64>());
|
||||
case Field::Types::Float64: return visitor(field.template get<Float64>());
|
||||
case Field::Types::String: return visitor(field.template get<String>());
|
||||
case Field::Types::Array: return visitor(field.template get<Array>());
|
||||
case Field::Types::Tuple: return visitor(field.template get<Tuple>());
|
||||
case Field::Types::Decimal32: return visitor(field.template get<DecimalField<Decimal32>>());
|
||||
case Field::Types::Decimal64: return visitor(field.template get<DecimalField<Decimal64>>());
|
||||
case Field::Types::Decimal128: return visitor(field.template get<DecimalField<Decimal128>>());
|
||||
case Field::Types::AggregateFunctionState: return visitor(field.template get<AggregateFunctionStateData>());
|
||||
|
||||
default:
|
||||
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Visitor, typename F1, typename F2>
|
||||
static typename std::decay_t<Visitor>::ResultType applyBinaryVisitorImpl(Visitor && visitor, F1 && field1, F2 && field2)
|
||||
{
|
||||
switch (field2.getType())
|
||||
{
|
||||
case Field::Types::Null: return visitor(field1, field2.template get<Null>());
|
||||
case Field::Types::UInt64: return visitor(field1, field2.template get<UInt64>());
|
||||
case Field::Types::UInt128: return visitor(field1, field2.template get<UInt128>());
|
||||
case Field::Types::Int64: return visitor(field1, field2.template get<Int64>());
|
||||
case Field::Types::Float64: return visitor(field1, field2.template get<Float64>());
|
||||
case Field::Types::String: return visitor(field1, field2.template get<String>());
|
||||
case Field::Types::Array: return visitor(field1, field2.template get<Array>());
|
||||
case Field::Types::Tuple: return visitor(field1, field2.template get<Tuple>());
|
||||
case Field::Types::Decimal32: return visitor(field1, field2.template get<DecimalField<Decimal32>>());
|
||||
case Field::Types::Decimal64: return visitor(field1, field2.template get<DecimalField<Decimal64>>());
|
||||
case Field::Types::Decimal128: return visitor(field1, field2.template get<DecimalField<Decimal128>>());
|
||||
case Field::Types::AggregateFunctionState: return visitor(field1, field2.template get<AggregateFunctionStateData>());
|
||||
|
||||
default:
|
||||
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
|
||||
}
|
||||
return Field::dispatch(visitor, field);
|
||||
}
|
||||
|
||||
template <typename Visitor, typename F1, typename F2>
|
||||
typename std::decay_t<Visitor>::ResultType applyVisitor(Visitor && visitor, F1 && field1, F2 && field2)
|
||||
auto applyVisitor(Visitor && visitor, F1 && field1, F2 && field2)
|
||||
{
|
||||
switch (field1.getType())
|
||||
{
|
||||
case Field::Types::Null:
|
||||
return applyBinaryVisitorImpl(
|
||||
std::forward<Visitor>(visitor), field1.template get<Null>(), std::forward<F2>(field2));
|
||||
case Field::Types::UInt64:
|
||||
return applyBinaryVisitorImpl(
|
||||
std::forward<Visitor>(visitor), field1.template get<UInt64>(), std::forward<F2>(field2));
|
||||
case Field::Types::UInt128:
|
||||
return applyBinaryVisitorImpl(
|
||||
std::forward<Visitor>(visitor), field1.template get<UInt128>(), std::forward<F2>(field2));
|
||||
case Field::Types::Int64:
|
||||
return applyBinaryVisitorImpl(
|
||||
std::forward<Visitor>(visitor), field1.template get<Int64>(), std::forward<F2>(field2));
|
||||
case Field::Types::Float64:
|
||||
return applyBinaryVisitorImpl(
|
||||
std::forward<Visitor>(visitor), field1.template get<Float64>(), std::forward<F2>(field2));
|
||||
case Field::Types::String:
|
||||
return applyBinaryVisitorImpl(
|
||||
std::forward<Visitor>(visitor), field1.template get<String>(), std::forward<F2>(field2));
|
||||
case Field::Types::Array:
|
||||
return applyBinaryVisitorImpl(
|
||||
std::forward<Visitor>(visitor), field1.template get<Array>(), std::forward<F2>(field2));
|
||||
case Field::Types::Tuple:
|
||||
return applyBinaryVisitorImpl(
|
||||
std::forward<Visitor>(visitor), field1.template get<Tuple>(), std::forward<F2>(field2));
|
||||
case Field::Types::Decimal32:
|
||||
return applyBinaryVisitorImpl(
|
||||
std::forward<Visitor>(visitor), field1.template get<DecimalField<Decimal32>>(), std::forward<F2>(field2));
|
||||
case Field::Types::Decimal64:
|
||||
return applyBinaryVisitorImpl(
|
||||
std::forward<Visitor>(visitor), field1.template get<DecimalField<Decimal64>>(), std::forward<F2>(field2));
|
||||
case Field::Types::Decimal128:
|
||||
return applyBinaryVisitorImpl(
|
||||
std::forward<Visitor>(visitor), field1.template get<DecimalField<Decimal128>>(), std::forward<F2>(field2));
|
||||
case Field::Types::AggregateFunctionState:
|
||||
return applyBinaryVisitorImpl(
|
||||
std::forward<Visitor>(visitor), field1.template get<AggregateFunctionStateData>(), std::forward<F2>(field2));
|
||||
|
||||
default:
|
||||
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
|
||||
}
|
||||
return Field::dispatch([&](auto & field1_value)
|
||||
{
|
||||
return Field::dispatch([&](auto & field2_value)
|
||||
{
|
||||
return visitor(field1_value, field2_value);
|
||||
},
|
||||
field2);
|
||||
},
|
||||
field1);
|
||||
}
|
||||
|
||||
|
||||
@ -473,8 +399,14 @@ private:
|
||||
public:
|
||||
explicit FieldVisitorSum(const Field & rhs_) : rhs(rhs_) {}
|
||||
|
||||
bool operator() (UInt64 & x) const { x += get<UInt64>(rhs); return x != 0; }
|
||||
bool operator() (Int64 & x) const { x += get<Int64>(rhs); return x != 0; }
|
||||
// We can add all ints as unsigned regardless of their actual signedness.
|
||||
bool operator() (Int64 & x) const { return this->operator()(reinterpret_cast<UInt64 &>(x)); }
|
||||
bool operator() (UInt64 & x) const
|
||||
{
|
||||
x += rhs.reinterpret<UInt64>();
|
||||
return x != 0;
|
||||
}
|
||||
|
||||
bool operator() (Float64 & x) const { x += get<Float64>(rhs); return x != 0; }
|
||||
|
||||
bool operator() (Null &) const { throw Exception("Cannot sum Nulls", ErrorCodes::LOGICAL_ERROR); }
|
||||
|
@ -84,6 +84,23 @@ struct DefaultHash<T, std::enable_if_t<is_arithmetic_v<T>>>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DefaultHash<T, std::enable_if_t<DB::IsDecimalNumber<T> && sizeof(T) <= 8>>
|
||||
{
|
||||
size_t operator() (T key) const
|
||||
{
|
||||
return DefaultHash64<typename T::NativeType>(key);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DefaultHash<T, std::enable_if_t<DB::IsDecimalNumber<T> && sizeof(T) == 16>>
|
||||
{
|
||||
size_t operator() (T key) const
|
||||
{
|
||||
return DefaultHash64<Int64>(key >> 64) ^ DefaultHash64<Int64>(key);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct HashCRC32;
|
||||
|
||||
|
162
dbms/src/Common/IntervalKind.cpp
Normal file
162
dbms/src/Common/IntervalKind.cpp
Normal file
@ -0,0 +1,162 @@
|
||||
#include <Common/IntervalKind.h>
|
||||
#include <Common/Exception.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int SYNTAX_ERROR;
|
||||
}
|
||||
|
||||
const char * IntervalKind::toString() const
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case IntervalKind::Second: return "Second";
|
||||
case IntervalKind::Minute: return "Minute";
|
||||
case IntervalKind::Hour: return "Hour";
|
||||
case IntervalKind::Day: return "Day";
|
||||
case IntervalKind::Week: return "Week";
|
||||
case IntervalKind::Month: return "Month";
|
||||
case IntervalKind::Quarter: return "Quarter";
|
||||
case IntervalKind::Year: return "Year";
|
||||
}
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
|
||||
Int32 IntervalKind::toAvgSeconds() const
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case IntervalKind::Second: return 1;
|
||||
case IntervalKind::Minute: return 60;
|
||||
case IntervalKind::Hour: return 3600;
|
||||
case IntervalKind::Day: return 86400;
|
||||
case IntervalKind::Week: return 604800;
|
||||
case IntervalKind::Month: return 2629746; /// Exactly 1/12 of a year.
|
||||
case IntervalKind::Quarter: return 7889238; /// Exactly 1/4 of a year.
|
||||
case IntervalKind::Year: return 31556952; /// The average length of a Gregorian year is equal to 365.2425 days
|
||||
}
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
|
||||
IntervalKind IntervalKind::fromAvgSeconds(Int64 num_seconds)
|
||||
{
|
||||
if (num_seconds)
|
||||
{
|
||||
if (!(num_seconds % 31556952))
|
||||
return IntervalKind::Year;
|
||||
if (!(num_seconds % 7889238))
|
||||
return IntervalKind::Quarter;
|
||||
if (!(num_seconds % 604800))
|
||||
return IntervalKind::Week;
|
||||
if (!(num_seconds % 2629746))
|
||||
return IntervalKind::Month;
|
||||
if (!(num_seconds % 86400))
|
||||
return IntervalKind::Day;
|
||||
if (!(num_seconds % 3600))
|
||||
return IntervalKind::Hour;
|
||||
if (!(num_seconds % 60))
|
||||
return IntervalKind::Minute;
|
||||
}
|
||||
return IntervalKind::Second;
|
||||
}
|
||||
|
||||
|
||||
const char * IntervalKind::toKeyword() const
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case IntervalKind::Second: return "SECOND";
|
||||
case IntervalKind::Minute: return "MINUTE";
|
||||
case IntervalKind::Hour: return "HOUR";
|
||||
case IntervalKind::Day: return "DAY";
|
||||
case IntervalKind::Week: return "WEEK";
|
||||
case IntervalKind::Month: return "MONTH";
|
||||
case IntervalKind::Quarter: return "QUARTER";
|
||||
case IntervalKind::Year: return "YEAR";
|
||||
}
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
|
||||
const char * IntervalKind::toDateDiffUnit() const
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case IntervalKind::Second:
|
||||
return "second";
|
||||
case IntervalKind::Minute:
|
||||
return "minute";
|
||||
case IntervalKind::Hour:
|
||||
return "hour";
|
||||
case IntervalKind::Day:
|
||||
return "day";
|
||||
case IntervalKind::Week:
|
||||
return "week";
|
||||
case IntervalKind::Month:
|
||||
return "month";
|
||||
case IntervalKind::Quarter:
|
||||
return "quarter";
|
||||
case IntervalKind::Year:
|
||||
return "year";
|
||||
}
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
|
||||
const char * IntervalKind::toNameOfFunctionToIntervalDataType() const
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case IntervalKind::Second:
|
||||
return "toIntervalSecond";
|
||||
case IntervalKind::Minute:
|
||||
return "toIntervalMinute";
|
||||
case IntervalKind::Hour:
|
||||
return "toIntervalHour";
|
||||
case IntervalKind::Day:
|
||||
return "toIntervalDay";
|
||||
case IntervalKind::Week:
|
||||
return "toIntervalWeek";
|
||||
case IntervalKind::Month:
|
||||
return "toIntervalMonth";
|
||||
case IntervalKind::Quarter:
|
||||
return "toIntervalQuarter";
|
||||
case IntervalKind::Year:
|
||||
return "toIntervalYear";
|
||||
}
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
|
||||
const char * IntervalKind::toNameOfFunctionExtractTimePart() const
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case IntervalKind::Second:
|
||||
return "toSecond";
|
||||
case IntervalKind::Minute:
|
||||
return "toMinute";
|
||||
case IntervalKind::Hour:
|
||||
return "toHour";
|
||||
case IntervalKind::Day:
|
||||
return "toDayOfMonth";
|
||||
case IntervalKind::Week:
|
||||
// TODO: SELECT toRelativeWeekNum(toDate('2017-06-15')) - toRelativeWeekNum(toStartOfYear(toDate('2017-06-15')))
|
||||
// else if (ParserKeyword("WEEK").ignore(pos, expected))
|
||||
// function_name = "toRelativeWeekNum";
|
||||
throw Exception("The syntax 'EXTRACT(WEEK FROM date)' is not supported, cannot extract the number of a week", ErrorCodes::SYNTAX_ERROR);
|
||||
case IntervalKind::Month:
|
||||
return "toMonth";
|
||||
case IntervalKind::Quarter:
|
||||
return "toQuarter";
|
||||
case IntervalKind::Year:
|
||||
return "toYear";
|
||||
}
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
54
dbms/src/Common/IntervalKind.h
Normal file
54
dbms/src/Common/IntervalKind.h
Normal file
@ -0,0 +1,54 @@
|
||||
#pragma once
|
||||
|
||||
#include <Core/Types.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
/// Kind of a temporal interval.
|
||||
struct IntervalKind
|
||||
{
|
||||
enum Kind
|
||||
{
|
||||
Second,
|
||||
Minute,
|
||||
Hour,
|
||||
Day,
|
||||
Week,
|
||||
Month,
|
||||
Quarter,
|
||||
Year,
|
||||
};
|
||||
Kind kind = Second;
|
||||
|
||||
IntervalKind(Kind kind_ = Second) : kind(kind_) {}
|
||||
operator Kind() const { return kind; }
|
||||
|
||||
const char * toString() const;
|
||||
|
||||
/// Returns number of seconds in one interval.
|
||||
/// For `Month`, `Quarter` and `Year` the function returns an average number of seconds.
|
||||
Int32 toAvgSeconds() const;
|
||||
|
||||
/// Chooses an interval kind based on number of seconds.
|
||||
/// For example, `IntervalKind::fromAvgSeconds(3600)` returns `IntervalKind::Hour`.
|
||||
static IntervalKind fromAvgSeconds(Int64 num_seconds);
|
||||
|
||||
/// Returns an uppercased version of what `toString()` returns.
|
||||
const char * toKeyword() const;
|
||||
|
||||
/// Returns the string which can be passed to the `unit` parameter of the dateDiff() function.
|
||||
/// For example, `IntervalKind{IntervalKind::Day}.getDateDiffParameter()` returns "day".
|
||||
const char * toDateDiffUnit() const;
|
||||
|
||||
/// Returns the name of the function converting a number to the interval data type.
|
||||
/// For example, `IntervalKind{IntervalKind::Day}.getToIntervalDataTypeFunctionName()`
|
||||
/// returns "toIntervalDay".
|
||||
const char * toNameOfFunctionToIntervalDataType() const;
|
||||
|
||||
/// Returns the name of the function extracting time part from a date or a time.
|
||||
/// For example, `IntervalKind{IntervalKind::Day}.getExtractTimePartFunctionName()`
|
||||
/// returns "toDayOfMonth".
|
||||
const char * toNameOfFunctionExtractTimePart() const;
|
||||
};
|
||||
}
|
62
dbms/src/Common/RemoteHostFilter.cpp
Normal file
62
dbms/src/Common/RemoteHostFilter.cpp
Normal file
@ -0,0 +1,62 @@
|
||||
#include <re2/re2.h>
|
||||
#include <Common/RemoteHostFilter.h>
|
||||
#include <Poco/URI.h>
|
||||
#include <Formats/FormatFactory.h>
|
||||
#include <Poco/Util/AbstractConfiguration.h>
|
||||
#include <Common/StringUtils/StringUtils.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int UNACCEPTABLE_URL;
|
||||
}
|
||||
|
||||
void RemoteHostFilter::checkURL(const Poco::URI & uri) const
|
||||
{
|
||||
if (!checkForDirectEntry(uri.getHost()) &&
|
||||
!checkForDirectEntry(uri.getHost() + ":" + toString(uri.getPort())))
|
||||
throw Exception("URL \"" + uri.toString() + "\" is not allowed in config.xml", ErrorCodes::UNACCEPTABLE_URL);
|
||||
}
|
||||
|
||||
void RemoteHostFilter::checkHostAndPort(const std::string & host, const std::string & port) const
|
||||
{
|
||||
if (!checkForDirectEntry(host) &&
|
||||
!checkForDirectEntry(host + ":" + port))
|
||||
throw Exception("URL \"" + host + ":" + port + "\" is not allowed in config.xml", ErrorCodes::UNACCEPTABLE_URL);
|
||||
}
|
||||
|
||||
void RemoteHostFilter::setValuesFromConfig(const Poco::Util::AbstractConfiguration & config)
|
||||
{
|
||||
if (config.has("remote_url_allow_hosts"))
|
||||
{
|
||||
std::vector<std::string> keys;
|
||||
config.keys("remote_url_allow_hosts", keys);
|
||||
for (auto key : keys)
|
||||
{
|
||||
if (startsWith(key, "host_regexp"))
|
||||
regexp_hosts.push_back(config.getString("remote_url_allow_hosts." + key));
|
||||
else if (startsWith(key, "host"))
|
||||
primary_hosts.insert(config.getString("remote_url_allow_hosts." + key));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool RemoteHostFilter::checkForDirectEntry(const std::string & str) const
|
||||
{
|
||||
if (!primary_hosts.empty() || !regexp_hosts.empty())
|
||||
{
|
||||
if (primary_hosts.find(str) == primary_hosts.end())
|
||||
{
|
||||
for (size_t i = 0; i < regexp_hosts.size(); ++i)
|
||||
if (re2::RE2::FullMatch(str, regexp_hosts[i]))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
30
dbms/src/Common/RemoteHostFilter.h
Normal file
30
dbms/src/Common/RemoteHostFilter.h
Normal file
@ -0,0 +1,30 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
#include <Poco/URI.h>
|
||||
#include <Poco/Util/AbstractConfiguration.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
class RemoteHostFilter
|
||||
{
|
||||
/**
|
||||
* This class checks if url is allowed.
|
||||
* If primary_hosts and regexp_hosts are empty all urls are allowed.
|
||||
*/
|
||||
public:
|
||||
void checkURL(const Poco::URI & uri) const; /// If URL not allowed in config.xml throw UNACCEPTABLE_URL Exception
|
||||
|
||||
void setValuesFromConfig(const Poco::Util::AbstractConfiguration & config);
|
||||
|
||||
void checkHostAndPort(const std::string & host, const std::string & port) const; /// Does the same as checkURL, but for host and port.
|
||||
|
||||
private:
|
||||
std::unordered_set<std::string> primary_hosts; /// Allowed primary (<host>) URL from config.xml
|
||||
std::vector<std::string> regexp_hosts; /// Allowed regexp (<hots_regexp>) URL from config.xml
|
||||
|
||||
bool checkForDirectEntry(const std::string & str) const; /// Checks if the primary_hosts and regexp_hosts contain str. If primary_hosts and regexp_hosts are empty return true.
|
||||
};
|
||||
}
|
@ -158,7 +158,7 @@ std::string signalToErrorMessage(int sig, const siginfo_t & info, const ucontext
|
||||
break;
|
||||
}
|
||||
|
||||
case SIGPROF:
|
||||
case SIGTSTP:
|
||||
{
|
||||
error << "This is a signal used for debugging purposes by the user.";
|
||||
break;
|
||||
|
@ -3,11 +3,18 @@
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
#include <Core/Defines.h>
|
||||
|
||||
// Also defined in Core/Defines.h
|
||||
#if !defined(NO_SANITIZE_UNDEFINED)
|
||||
#if defined(__clang__)
|
||||
#define NO_SANITIZE_UNDEFINED __attribute__((__no_sanitize__("undefined")))
|
||||
#else
|
||||
#define NO_SANITIZE_UNDEFINED
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
/// On overlow, the function returns unspecified value.
|
||||
|
||||
inline NO_SANITIZE_UNDEFINED uint64_t intExp2(int x)
|
||||
{
|
||||
return 1ULL << x;
|
||||
|
40
dbms/src/Common/malloc.cpp
Normal file
40
dbms/src/Common/malloc.cpp
Normal file
@ -0,0 +1,40 @@
|
||||
#if defined(OS_LINUX)
|
||||
#include <stdlib.h>
|
||||
|
||||
/// Interposing these symbols explicitly. The idea works like this: malloc.cpp compiles to a
|
||||
/// dedicated object (namely clickhouse_malloc.o), and it will show earlier in the link command
|
||||
/// than malloc libs like libjemalloc.a. As a result, these symbols get picked in time right after.
|
||||
extern "C"
|
||||
{
|
||||
void *malloc(size_t size);
|
||||
void free(void *ptr);
|
||||
void *calloc(size_t nmemb, size_t size);
|
||||
void *realloc(void *ptr, size_t size);
|
||||
int posix_memalign(void **memptr, size_t alignment, size_t size);
|
||||
void *aligned_alloc(size_t alignment, size_t size);
|
||||
void *valloc(size_t size);
|
||||
void *memalign(size_t alignment, size_t size);
|
||||
void *pvalloc(size_t size);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline void ignore(T x __attribute__((unused)))
|
||||
{
|
||||
}
|
||||
|
||||
static void dummyFunctionForInterposing() __attribute__((used));
|
||||
static void dummyFunctionForInterposing()
|
||||
{
|
||||
void* dummy;
|
||||
/// Suppression for PVS-Studio.
|
||||
free(nullptr); // -V575
|
||||
ignore(malloc(0)); // -V575
|
||||
ignore(calloc(0, 0)); // -V575
|
||||
ignore(realloc(nullptr, 0)); // -V575
|
||||
ignore(posix_memalign(&dummy, 0, 0)); // -V575
|
||||
ignore(aligned_alloc(0, 0)); // -V575
|
||||
ignore(valloc(0)); // -V575
|
||||
ignore(memalign(0, 0)); // -V575
|
||||
ignore(pvalloc(0)); // -V575
|
||||
}
|
||||
#endif
|
@ -3,8 +3,10 @@
|
||||
#include <type_traits>
|
||||
#include <typeinfo>
|
||||
#include <typeindex>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <ext/shared_ptr_helper.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <common/demangle.h>
|
||||
|
||||
@ -27,7 +29,7 @@ std::enable_if_t<std::is_reference_v<To>, To> typeid_cast(From & from)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (typeid(from) == typeid(To))
|
||||
if ((typeid(From) == typeid(To)) || (typeid(from) == typeid(To)))
|
||||
return static_cast<To>(from);
|
||||
}
|
||||
catch (const std::exception & e)
|
||||
@ -39,12 +41,13 @@ std::enable_if_t<std::is_reference_v<To>, To> typeid_cast(From & from)
|
||||
DB::ErrorCodes::BAD_CAST);
|
||||
}
|
||||
|
||||
|
||||
template <typename To, typename From>
|
||||
To typeid_cast(From * from)
|
||||
std::enable_if_t<std::is_pointer_v<To>, To> typeid_cast(From * from)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (typeid(*from) == typeid(std::remove_pointer_t<To>))
|
||||
if ((typeid(From) == typeid(std::remove_pointer_t<To>)) || (typeid(*from) == typeid(std::remove_pointer_t<To>)))
|
||||
return static_cast<To>(from);
|
||||
else
|
||||
return nullptr;
|
||||
@ -54,3 +57,20 @@ To typeid_cast(From * from)
|
||||
throw DB::Exception(e.what(), DB::ErrorCodes::BAD_CAST);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename To, typename From>
|
||||
std::enable_if_t<ext::is_shared_ptr_v<To>, To> typeid_cast(const std::shared_ptr<From> & from)
|
||||
{
|
||||
try
|
||||
{
|
||||
if ((typeid(From) == typeid(typename To::element_type)) || (typeid(*from) == typeid(typename To::element_type)))
|
||||
return std::static_pointer_cast<typename To::element_type>(from);
|
||||
else
|
||||
return nullptr;
|
||||
}
|
||||
catch (const std::exception & e)
|
||||
{
|
||||
throw DB::Exception(e.what(), DB::ErrorCodes::BAD_CAST);
|
||||
}
|
||||
}
|
||||
|
@ -88,9 +88,9 @@ public:
|
||||
|
||||
Shift shift;
|
||||
if (scale_a < scale_b)
|
||||
shift.a = DataTypeDecimal<B>(maxDecimalPrecision<B>(), scale_b).getScaleMultiplier(scale_b - scale_a);
|
||||
shift.a = B::getScaleMultiplier(scale_b - scale_a);
|
||||
if (scale_a > scale_b)
|
||||
shift.b = DataTypeDecimal<A>(maxDecimalPrecision<A>(), scale_a).getScaleMultiplier(scale_a - scale_b);
|
||||
shift.b = A::getScaleMultiplier(scale_a - scale_b);
|
||||
|
||||
return applyWithScale(a, b, shift);
|
||||
}
|
||||
|
@ -151,8 +151,8 @@
|
||||
#endif
|
||||
|
||||
/// Marks that extra information is sent to a shard. It could be any magic numbers.
|
||||
#define DBMS_DISTRIBUTED_SIGNATURE_EXTRA_INFO 0xCAFEDACEull
|
||||
#define DBMS_DISTRIBUTED_SIGNATURE_SETTINGS_OLD_FORMAT 0xCAFECABEull
|
||||
#define DBMS_DISTRIBUTED_SIGNATURE_HEADER 0xCAFEDACEull
|
||||
#define DBMS_DISTRIBUTED_SIGNATURE_HEADER_OLD_FORMAT 0xCAFECABEull
|
||||
|
||||
#if !__has_include(<sanitizer/asan_interface.h>)
|
||||
# define ASAN_UNPOISON_MEMORY_REGION(a, b)
|
||||
|
@ -295,26 +295,11 @@ namespace DB
|
||||
|
||||
void writeFieldText(const Field & x, WriteBuffer & buf)
|
||||
{
|
||||
DB::String res = applyVisitor(DB::FieldVisitorToString(), x);
|
||||
DB::String res = Field::dispatch(DB::FieldVisitorToString(), x);
|
||||
buf.write(res.data(), res.size());
|
||||
}
|
||||
|
||||
|
||||
template <> Decimal32 DecimalField<Decimal32>::getScaleMultiplier() const
|
||||
{
|
||||
return DataTypeDecimal<Decimal32>::getScaleMultiplier(scale);
|
||||
}
|
||||
|
||||
template <> Decimal64 DecimalField<Decimal64>::getScaleMultiplier() const
|
||||
{
|
||||
return DataTypeDecimal<Decimal64>::getScaleMultiplier(scale);
|
||||
}
|
||||
|
||||
template <> Decimal128 DecimalField<Decimal128>::getScaleMultiplier() const
|
||||
{
|
||||
return DataTypeDecimal<Decimal128>::getScaleMultiplier(scale);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static bool decEqual(T x, T y, UInt32 x_scale, UInt32 y_scale)
|
||||
{
|
||||
|
@ -27,7 +27,7 @@ namespace ErrorCodes
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename SFINAE = void>
|
||||
struct NearestFieldTypeImpl;
|
||||
|
||||
template <typename T>
|
||||
@ -102,7 +102,7 @@ public:
|
||||
|
||||
operator T() const { return dec; }
|
||||
T getValue() const { return dec; }
|
||||
T getScaleMultiplier() const;
|
||||
T getScaleMultiplier() const { return T::getScaleMultiplier(scale); }
|
||||
UInt32 getScale() const { return scale; }
|
||||
|
||||
template <typename U>
|
||||
@ -151,6 +151,54 @@ private:
|
||||
UInt32 scale;
|
||||
};
|
||||
|
||||
/// char may be signed or unsigned, and behave identically to signed char or unsigned char,
|
||||
/// but they are always three different types.
|
||||
/// signedness of char is different in Linux on x86 and Linux on ARM.
|
||||
template <> struct NearestFieldTypeImpl<char> { using Type = std::conditional_t<is_signed_v<char>, Int64, UInt64>; };
|
||||
template <> struct NearestFieldTypeImpl<signed char> { using Type = Int64; };
|
||||
template <> struct NearestFieldTypeImpl<unsigned char> { using Type = UInt64; };
|
||||
|
||||
template <> struct NearestFieldTypeImpl<UInt16> { using Type = UInt64; };
|
||||
template <> struct NearestFieldTypeImpl<UInt32> { using Type = UInt64; };
|
||||
|
||||
template <> struct NearestFieldTypeImpl<DayNum> { using Type = UInt64; };
|
||||
template <> struct NearestFieldTypeImpl<UInt128> { using Type = UInt128; };
|
||||
template <> struct NearestFieldTypeImpl<UUID> { using Type = UInt128; };
|
||||
template <> struct NearestFieldTypeImpl<Int16> { using Type = Int64; };
|
||||
template <> struct NearestFieldTypeImpl<Int32> { using Type = Int64; };
|
||||
|
||||
/// long and long long are always different types that may behave identically or not.
|
||||
/// This is different on Linux and Mac.
|
||||
template <> struct NearestFieldTypeImpl<long> { using Type = Int64; };
|
||||
template <> struct NearestFieldTypeImpl<long long> { using Type = Int64; };
|
||||
template <> struct NearestFieldTypeImpl<unsigned long> { using Type = UInt64; };
|
||||
template <> struct NearestFieldTypeImpl<unsigned long long> { using Type = UInt64; };
|
||||
|
||||
template <> struct NearestFieldTypeImpl<Int128> { using Type = Int128; };
|
||||
template <> struct NearestFieldTypeImpl<Decimal32> { using Type = DecimalField<Decimal32>; };
|
||||
template <> struct NearestFieldTypeImpl<Decimal64> { using Type = DecimalField<Decimal64>; };
|
||||
template <> struct NearestFieldTypeImpl<Decimal128> { using Type = DecimalField<Decimal128>; };
|
||||
template <> struct NearestFieldTypeImpl<DecimalField<Decimal32>> { using Type = DecimalField<Decimal32>; };
|
||||
template <> struct NearestFieldTypeImpl<DecimalField<Decimal64>> { using Type = DecimalField<Decimal64>; };
|
||||
template <> struct NearestFieldTypeImpl<DecimalField<Decimal128>> { using Type = DecimalField<Decimal128>; };
|
||||
template <> struct NearestFieldTypeImpl<Float32> { using Type = Float64; };
|
||||
template <> struct NearestFieldTypeImpl<Float64> { using Type = Float64; };
|
||||
template <> struct NearestFieldTypeImpl<const char *> { using Type = String; };
|
||||
template <> struct NearestFieldTypeImpl<String> { using Type = String; };
|
||||
template <> struct NearestFieldTypeImpl<Array> { using Type = Array; };
|
||||
template <> struct NearestFieldTypeImpl<Tuple> { using Type = Tuple; };
|
||||
template <> struct NearestFieldTypeImpl<bool> { using Type = UInt64; };
|
||||
template <> struct NearestFieldTypeImpl<Null> { using Type = Null; };
|
||||
|
||||
template <> struct NearestFieldTypeImpl<AggregateFunctionStateData> { using Type = AggregateFunctionStateData; };
|
||||
|
||||
// For enum types, use the field type that corresponds to their underlying type.
|
||||
template <typename T>
|
||||
struct NearestFieldTypeImpl<T, std::enable_if_t<std::is_enum_v<T>>>
|
||||
{
|
||||
using Type = NearestFieldType<std::underlying_type_t<T>>;
|
||||
};
|
||||
|
||||
/** 32 is enough. Round number is used for alignment and for better arithmetic inside std::vector.
|
||||
* NOTE: Actually, sizeof(std::string) is 32 when using libc++, so Field is 40 bytes.
|
||||
*/
|
||||
@ -314,18 +362,24 @@ public:
|
||||
bool isNull() const { return which == Types::Null; }
|
||||
|
||||
|
||||
template <typename T> T & get()
|
||||
template <typename T>
|
||||
T & get();
|
||||
|
||||
template <typename T>
|
||||
const T & get() const
|
||||
{
|
||||
using TWithoutRef = std::remove_reference_t<T>;
|
||||
TWithoutRef * MAY_ALIAS ptr = reinterpret_cast<TWithoutRef*>(&storage);
|
||||
return *ptr;
|
||||
auto mutable_this = const_cast<std::decay_t<decltype(*this)> *>(this);
|
||||
return mutable_this->get<T>();
|
||||
}
|
||||
|
||||
template <typename T> const T & get() const
|
||||
template <typename T>
|
||||
T & reinterpret();
|
||||
|
||||
template <typename T>
|
||||
const T & reinterpret() const
|
||||
{
|
||||
using TWithoutRef = std::remove_reference_t<T>;
|
||||
const TWithoutRef * MAY_ALIAS ptr = reinterpret_cast<const TWithoutRef*>(&storage);
|
||||
return *ptr;
|
||||
auto mutable_this = const_cast<std::decay_t<decltype(*this)> *>(this);
|
||||
return mutable_this->reinterpret<T>();
|
||||
}
|
||||
|
||||
template <typename T> bool tryGet(T & result)
|
||||
@ -427,6 +481,8 @@ public:
|
||||
return rhs <= *this;
|
||||
}
|
||||
|
||||
// More like bitwise equality as opposed to semantic equality:
|
||||
// Null equals Null and NaN equals NaN.
|
||||
bool operator== (const Field & rhs) const
|
||||
{
|
||||
if (which != rhs.which)
|
||||
@ -435,9 +491,13 @@ public:
|
||||
switch (which)
|
||||
{
|
||||
case Types::Null: return true;
|
||||
case Types::UInt64:
|
||||
case Types::Int64:
|
||||
case Types::Float64: return get<UInt64>() == rhs.get<UInt64>();
|
||||
case Types::UInt64: return get<UInt64>() == rhs.get<UInt64>();
|
||||
case Types::Int64: return get<Int64>() == rhs.get<Int64>();
|
||||
case Types::Float64:
|
||||
{
|
||||
// Compare as UInt64 so that NaNs compare as equal.
|
||||
return reinterpret<UInt64>() == rhs.reinterpret<UInt64>();
|
||||
}
|
||||
case Types::String: return get<String>() == rhs.get<String>();
|
||||
case Types::Array: return get<Array>() == rhs.get<Array>();
|
||||
case Types::Tuple: return get<Tuple>() == rhs.get<Tuple>();
|
||||
@ -457,6 +517,42 @@ public:
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
/// Field is template parameter, to allow universal reference for field,
|
||||
/// that is useful for const and non-const .
|
||||
template <typename F, typename FieldRef>
|
||||
static auto dispatch(F && f, FieldRef && field)
|
||||
{
|
||||
switch (field.which)
|
||||
{
|
||||
case Types::Null: return f(field.template get<Null>());
|
||||
case Types::UInt64: return f(field.template get<UInt64>());
|
||||
case Types::UInt128: return f(field.template get<UInt128>());
|
||||
case Types::Int64: return f(field.template get<Int64>());
|
||||
case Types::Float64: return f(field.template get<Float64>());
|
||||
case Types::String: return f(field.template get<String>());
|
||||
case Types::Array: return f(field.template get<Array>());
|
||||
case Types::Tuple: return f(field.template get<Tuple>());
|
||||
case Types::Decimal32: return f(field.template get<DecimalField<Decimal32>>());
|
||||
case Types::Decimal64: return f(field.template get<DecimalField<Decimal64>>());
|
||||
case Types::Decimal128: return f(field.template get<DecimalField<Decimal128>>());
|
||||
case Types::AggregateFunctionState: return f(field.template get<AggregateFunctionStateData>());
|
||||
case Types::Int128:
|
||||
// TODO: investigate where we need Int128 Fields. There are no
|
||||
// field visitors that support them, and they only arise indirectly
|
||||
// in some functions that use Decimal columns: they get the
|
||||
// underlying Field value with get<Int128>(). Probably should be
|
||||
// switched to DecimalField, but this is a whole endeavor in itself.
|
||||
throw Exception("Unexpected Int128 in Field::dispatch()", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
// GCC 9 complains that control reaches the end, despite that we handle
|
||||
// all the cases above (maybe because of throw?). Return something to
|
||||
// silence it.
|
||||
Null null{};
|
||||
return f(null);
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
std::aligned_union_t<DBMS_MIN_FIELD_SIZE - sizeof(Types::Which),
|
||||
Null, UInt64, UInt128, Int64, Int128, Float64, String, Array, Tuple,
|
||||
@ -493,37 +589,6 @@ private:
|
||||
}
|
||||
|
||||
|
||||
template <typename F, typename Field> /// Field template parameter may be const or non-const Field.
|
||||
static void dispatch(F && f, Field & field)
|
||||
{
|
||||
switch (field.which)
|
||||
{
|
||||
case Types::Null: f(field.template get<Null>()); return;
|
||||
|
||||
// gcc 7.3.0
|
||||
#if !__clang__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
|
||||
#endif
|
||||
case Types::UInt64: f(field.template get<UInt64>()); return;
|
||||
case Types::UInt128: f(field.template get<UInt128>()); return;
|
||||
case Types::Int64: f(field.template get<Int64>()); return;
|
||||
case Types::Int128: f(field.template get<Int128>()); return;
|
||||
case Types::Float64: f(field.template get<Float64>()); return;
|
||||
#if !__clang__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
case Types::String: f(field.template get<String>()); return;
|
||||
case Types::Array: f(field.template get<Array>()); return;
|
||||
case Types::Tuple: f(field.template get<Tuple>()); return;
|
||||
case Types::Decimal32: f(field.template get<DecimalField<Decimal32>>()); return;
|
||||
case Types::Decimal64: f(field.template get<DecimalField<Decimal64>>()); return;
|
||||
case Types::Decimal128: f(field.template get<DecimalField<Decimal128>>()); return;
|
||||
case Types::AggregateFunctionState: f(field.template get<AggregateFunctionStateData>()); return;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void create(const Field & x)
|
||||
{
|
||||
dispatch([this] (auto & value) { createConcrete(value); }, x);
|
||||
@ -621,6 +686,22 @@ template <> struct Field::EnumToType<Field::Types::Decimal64> { using Type = Dec
|
||||
template <> struct Field::EnumToType<Field::Types::Decimal128> { using Type = DecimalField<Decimal128>; };
|
||||
template <> struct Field::EnumToType<Field::Types::AggregateFunctionState> { using Type = DecimalField<AggregateFunctionStateData>; };
|
||||
|
||||
template <typename T>
|
||||
T & Field::get()
|
||||
{
|
||||
using ValueType = std::decay_t<T>;
|
||||
//assert(TypeToEnum<NearestFieldType<ValueType>>::value == which);
|
||||
ValueType * MAY_ALIAS ptr = reinterpret_cast<ValueType *>(&storage);
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T & Field::reinterpret()
|
||||
{
|
||||
using ValueType = std::decay_t<T>;
|
||||
ValueType * MAY_ALIAS ptr = reinterpret_cast<ValueType *>(&storage);
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T get(const Field & field)
|
||||
@ -651,49 +732,6 @@ template <> struct TypeName<Array> { static std::string get() { return "Array";
|
||||
template <> struct TypeName<Tuple> { static std::string get() { return "Tuple"; } };
|
||||
template <> struct TypeName<AggregateFunctionStateData> { static std::string get() { return "AggregateFunctionState"; } };
|
||||
|
||||
|
||||
|
||||
/// char may be signed or unsigned, and behave identically to signed char or unsigned char,
|
||||
/// but they are always three different types.
|
||||
/// signedness of char is different in Linux on x86 and Linux on ARM.
|
||||
template <> struct NearestFieldTypeImpl<char> { using Type = std::conditional_t<is_signed_v<char>, Int64, UInt64>; };
|
||||
template <> struct NearestFieldTypeImpl<signed char> { using Type = Int64; };
|
||||
template <> struct NearestFieldTypeImpl<unsigned char> { using Type = UInt64; };
|
||||
|
||||
template <> struct NearestFieldTypeImpl<UInt16> { using Type = UInt64; };
|
||||
template <> struct NearestFieldTypeImpl<UInt32> { using Type = UInt64; };
|
||||
|
||||
template <> struct NearestFieldTypeImpl<DayNum> { using Type = UInt64; };
|
||||
template <> struct NearestFieldTypeImpl<UInt128> { using Type = UInt128; };
|
||||
template <> struct NearestFieldTypeImpl<UUID> { using Type = UInt128; };
|
||||
template <> struct NearestFieldTypeImpl<Int16> { using Type = Int64; };
|
||||
template <> struct NearestFieldTypeImpl<Int32> { using Type = Int64; };
|
||||
|
||||
/// long and long long are always different types that may behave identically or not.
|
||||
/// This is different on Linux and Mac.
|
||||
template <> struct NearestFieldTypeImpl<long> { using Type = Int64; };
|
||||
template <> struct NearestFieldTypeImpl<long long> { using Type = Int64; };
|
||||
template <> struct NearestFieldTypeImpl<unsigned long> { using Type = UInt64; };
|
||||
template <> struct NearestFieldTypeImpl<unsigned long long> { using Type = UInt64; };
|
||||
|
||||
template <> struct NearestFieldTypeImpl<Int128> { using Type = Int128; };
|
||||
template <> struct NearestFieldTypeImpl<Decimal32> { using Type = DecimalField<Decimal32>; };
|
||||
template <> struct NearestFieldTypeImpl<Decimal64> { using Type = DecimalField<Decimal64>; };
|
||||
template <> struct NearestFieldTypeImpl<Decimal128> { using Type = DecimalField<Decimal128>; };
|
||||
template <> struct NearestFieldTypeImpl<DecimalField<Decimal32>> { using Type = DecimalField<Decimal32>; };
|
||||
template <> struct NearestFieldTypeImpl<DecimalField<Decimal64>> { using Type = DecimalField<Decimal64>; };
|
||||
template <> struct NearestFieldTypeImpl<DecimalField<Decimal128>> { using Type = DecimalField<Decimal128>; };
|
||||
template <> struct NearestFieldTypeImpl<Float32> { using Type = Float64; };
|
||||
template <> struct NearestFieldTypeImpl<Float64> { using Type = Float64; };
|
||||
template <> struct NearestFieldTypeImpl<const char *> { using Type = String; };
|
||||
template <> struct NearestFieldTypeImpl<String> { using Type = String; };
|
||||
template <> struct NearestFieldTypeImpl<Array> { using Type = Array; };
|
||||
template <> struct NearestFieldTypeImpl<Tuple> { using Type = Tuple; };
|
||||
template <> struct NearestFieldTypeImpl<bool> { using Type = UInt64; };
|
||||
template <> struct NearestFieldTypeImpl<Null> { using Type = Null; };
|
||||
|
||||
template <> struct NearestFieldTypeImpl<AggregateFunctionStateData> { using Type = AggregateFunctionStateData; };
|
||||
|
||||
template <typename T>
|
||||
decltype(auto) castToNearestFieldType(T && x)
|
||||
{
|
||||
|
@ -100,4 +100,69 @@ size_t getLengthEncodedStringSize(const String & s)
|
||||
return getLengthEncodedNumberSize(s.size()) + s.size();
|
||||
}
|
||||
|
||||
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex type_index)
|
||||
{
|
||||
ColumnType column_type;
|
||||
int flags = 0;
|
||||
switch (type_index)
|
||||
{
|
||||
case TypeIndex::UInt8:
|
||||
column_type = ColumnType::MYSQL_TYPE_TINY;
|
||||
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
|
||||
break;
|
||||
case TypeIndex::UInt16:
|
||||
column_type = ColumnType::MYSQL_TYPE_SHORT;
|
||||
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
|
||||
break;
|
||||
case TypeIndex::UInt32:
|
||||
column_type = ColumnType::MYSQL_TYPE_LONG;
|
||||
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
|
||||
break;
|
||||
case TypeIndex::UInt64:
|
||||
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
|
||||
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
|
||||
break;
|
||||
case TypeIndex::Int8:
|
||||
column_type = ColumnType::MYSQL_TYPE_TINY;
|
||||
flags = ColumnDefinitionFlags::BINARY_FLAG;
|
||||
break;
|
||||
case TypeIndex::Int16:
|
||||
column_type = ColumnType::MYSQL_TYPE_SHORT;
|
||||
flags = ColumnDefinitionFlags::BINARY_FLAG;
|
||||
break;
|
||||
case TypeIndex::Int32:
|
||||
column_type = ColumnType::MYSQL_TYPE_LONG;
|
||||
flags = ColumnDefinitionFlags::BINARY_FLAG;
|
||||
break;
|
||||
case TypeIndex::Int64:
|
||||
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
|
||||
flags = ColumnDefinitionFlags::BINARY_FLAG;
|
||||
break;
|
||||
case TypeIndex::Float32:
|
||||
column_type = ColumnType::MYSQL_TYPE_FLOAT;
|
||||
flags = ColumnDefinitionFlags::BINARY_FLAG;
|
||||
break;
|
||||
case TypeIndex::Float64:
|
||||
column_type = ColumnType::MYSQL_TYPE_DOUBLE;
|
||||
flags = ColumnDefinitionFlags::BINARY_FLAG;
|
||||
break;
|
||||
case TypeIndex::Date:
|
||||
column_type = ColumnType::MYSQL_TYPE_DATE;
|
||||
flags = ColumnDefinitionFlags::BINARY_FLAG;
|
||||
break;
|
||||
case TypeIndex::DateTime:
|
||||
column_type = ColumnType::MYSQL_TYPE_DATETIME;
|
||||
flags = ColumnDefinitionFlags::BINARY_FLAG;
|
||||
break;
|
||||
case TypeIndex::String:
|
||||
case TypeIndex::FixedString:
|
||||
column_type = ColumnType::MYSQL_TYPE_STRING;
|
||||
break;
|
||||
default:
|
||||
column_type = ColumnType::MYSQL_TYPE_STRING;
|
||||
break;
|
||||
}
|
||||
return ColumnDefinition(column_name, CharacterSet::binary, 0, column_type, flags, 0);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -130,6 +130,14 @@ enum ColumnType
|
||||
};
|
||||
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html
|
||||
enum ColumnDefinitionFlags
|
||||
{
|
||||
UNSIGNED_FLAG = 32,
|
||||
BINARY_FLAG = 128
|
||||
};
|
||||
|
||||
|
||||
class ProtocolError : public DB::Exception
|
||||
{
|
||||
public:
|
||||
@ -824,19 +832,40 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex index);
|
||||
|
||||
|
||||
namespace ProtocolText
|
||||
{
|
||||
|
||||
class ResultsetRow : public WritePacket
|
||||
{
|
||||
std::vector<String> columns;
|
||||
const Columns & columns;
|
||||
int row_num;
|
||||
size_t payload_size = 0;
|
||||
std::vector<String> serialized;
|
||||
public:
|
||||
ResultsetRow() = default;
|
||||
|
||||
void appendColumn(String && value)
|
||||
ResultsetRow(const DataTypes & data_types, const Columns & columns_, int row_num_)
|
||||
: columns(columns_)
|
||||
, row_num(row_num_)
|
||||
{
|
||||
payload_size += getLengthEncodedStringSize(value);
|
||||
columns.emplace_back(std::move(value));
|
||||
for (size_t i = 0; i < columns.size(); i++)
|
||||
{
|
||||
if (columns[i]->isNullAt(row_num))
|
||||
{
|
||||
payload_size += 1;
|
||||
serialized.emplace_back("\xfb");
|
||||
}
|
||||
else
|
||||
{
|
||||
WriteBufferFromOwnString ostr;
|
||||
data_types[i]->serializeAsText(*columns[i], row_num, ostr, FormatSettings());
|
||||
payload_size += getLengthEncodedStringSize(ostr.str());
|
||||
serialized.push_back(std::move(ostr.str()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
size_t getPayloadSize() const override
|
||||
{
|
||||
@ -845,11 +874,18 @@ protected:
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
for (const String & column : columns)
|
||||
writeLengthEncodedString(column, buffer);
|
||||
for (size_t i = 0; i < columns.size(); i++)
|
||||
{
|
||||
if (columns[i]->isNullAt(row_num))
|
||||
buffer.write(serialized[i].data(), 1);
|
||||
else
|
||||
writeLengthEncodedString(serialized[i], buffer);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
namespace Authentication
|
||||
{
|
||||
|
||||
@ -917,10 +953,7 @@ public:
|
||||
|
||||
auto user = context.getUser(user_name);
|
||||
|
||||
if (user->authentication.getType() != DB::Authentication::DOUBLE_SHA1_PASSWORD)
|
||||
throw Exception("Cannot use " + getName() + " auth plugin for user " + user_name + " since its password isn't specified using double SHA1.", ErrorCodes::UNKNOWN_EXCEPTION);
|
||||
|
||||
Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordHashBinary();
|
||||
Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordDoubleSHA1();
|
||||
assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE);
|
||||
|
||||
Poco::SHA1Engine engine;
|
||||
|
@ -62,7 +62,7 @@ void SettingNumber<Type>::set(const Field & x)
|
||||
template <typename Type>
|
||||
void SettingNumber<Type>::set(const String & x)
|
||||
{
|
||||
set(parse<Type>(x));
|
||||
set(completeParse<Type>(x));
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -5,6 +5,9 @@
|
||||
namespace DB
|
||||
{
|
||||
|
||||
using TypeListNumbers = TypeList<UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64>;
|
||||
using TypeListNativeNumbers = TypeList<UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64>;
|
||||
using TypeListDecimalNumbers = TypeList<Decimal32, Decimal64, Decimal128>;
|
||||
using TypeListNumbers = TypeList<UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64,
|
||||
Decimal32, Decimal64, Decimal128>;
|
||||
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <common/Types.h>
|
||||
#include <Common/intExp.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -145,6 +146,8 @@ struct Decimal
|
||||
const Decimal<T> & operator /= (const T & x) { value /= x; return *this; }
|
||||
const Decimal<T> & operator %= (const T & x) { value %= x; return *this; }
|
||||
|
||||
static T getScaleMultiplier(UInt32 scale);
|
||||
|
||||
T value;
|
||||
};
|
||||
|
||||
@ -170,6 +173,10 @@ template <> struct NativeType<Decimal32> { using Type = Int32; };
|
||||
template <> struct NativeType<Decimal64> { using Type = Int64; };
|
||||
template <> struct NativeType<Decimal128> { using Type = Int128; };
|
||||
|
||||
template <> inline Int32 Decimal32::getScaleMultiplier(UInt32 scale) { return common::exp10_i32(scale); }
|
||||
template <> inline Int64 Decimal64::getScaleMultiplier(UInt32 scale) { return common::exp10_i64(scale); }
|
||||
template <> inline Int128 Decimal128::getScaleMultiplier(UInt32 scale) { return common::exp10_i128(scale); }
|
||||
|
||||
inline const char * getTypeName(TypeIndex idx)
|
||||
{
|
||||
switch (idx)
|
||||
|
@ -52,7 +52,7 @@ struct Less
|
||||
{
|
||||
for (auto it = left_columns.begin(), jt = right_columns.begin(); it != left_columns.end(); ++it, ++jt)
|
||||
{
|
||||
int res = it->second.direction * it->first->compareAt(a, b, *jt->first, it->second.nulls_direction);
|
||||
int res = it->description.direction * it->column->compareAt(a, b, *jt->column, it->description.nulls_direction);
|
||||
if (res < 0)
|
||||
return true;
|
||||
else if (res > 0)
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include <Core/Field.h>
|
||||
#include <Interpreters/ProcessList.h>
|
||||
#include <Interpreters/Quota.h>
|
||||
#include <Access/QuotaContext.h>
|
||||
#include <Common/CurrentThread.h>
|
||||
#include <common/sleep.h>
|
||||
|
||||
@ -70,7 +70,7 @@ Block IBlockInputStream::read()
|
||||
if (limits.mode == LIMITS_CURRENT && !limits.size_limits.check(info.rows, info.bytes, "result", ErrorCodes::TOO_MANY_ROWS_OR_BYTES))
|
||||
limit_exceeded_need_break = true;
|
||||
|
||||
if (quota != nullptr)
|
||||
if (quota)
|
||||
checkQuota(res);
|
||||
}
|
||||
else
|
||||
@ -240,12 +240,8 @@ void IBlockInputStream::checkQuota(Block & block)
|
||||
|
||||
case LIMITS_CURRENT:
|
||||
{
|
||||
time_t current_time = time(nullptr);
|
||||
double total_elapsed = info.total_stopwatch.elapsedSeconds();
|
||||
|
||||
quota->checkAndAddResultRowsBytes(current_time, block.rows(), block.bytes());
|
||||
quota->checkAndAddExecutionTime(current_time, Poco::Timespan((total_elapsed - prev_elapsed) * 1000000.0));
|
||||
|
||||
UInt64 total_elapsed = info.total_stopwatch.elapsedNanoseconds();
|
||||
quota->used({Quota::RESULT_ROWS, block.rows()}, {Quota::RESULT_BYTES, block.bytes()}, {Quota::EXECUTION_TIME, total_elapsed - prev_elapsed});
|
||||
prev_elapsed = total_elapsed;
|
||||
break;
|
||||
}
|
||||
@ -291,10 +287,8 @@ void IBlockInputStream::progressImpl(const Progress & value)
|
||||
|
||||
limits.speed_limits.throttle(progress.read_rows, progress.read_bytes, total_rows, total_elapsed_microseconds);
|
||||
|
||||
if (quota != nullptr && limits.mode == LIMITS_TOTAL)
|
||||
{
|
||||
quota->checkAndAddReadRowsBytes(time(nullptr), value.read_rows, value.read_bytes);
|
||||
}
|
||||
if (quota && limits.mode == LIMITS_TOTAL)
|
||||
quota->used({Quota::READ_ROWS, value.read_rows}, {Quota::READ_BYTES, value.read_bytes});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -23,7 +23,7 @@ namespace ErrorCodes
|
||||
}
|
||||
|
||||
class ProcessListElement;
|
||||
class QuotaForIntervals;
|
||||
class QuotaContext;
|
||||
class QueryStatus;
|
||||
struct SortColumnDescription;
|
||||
using SortDescription = std::vector<SortColumnDescription>;
|
||||
@ -220,9 +220,9 @@ public:
|
||||
/** Set the quota. If you set a quota on the amount of raw data,
|
||||
* then you should also set mode = LIMITS_TOTAL to LocalLimits with setLimits.
|
||||
*/
|
||||
virtual void setQuota(QuotaForIntervals & quota_)
|
||||
virtual void setQuota(const std::shared_ptr<QuotaContext> & quota_)
|
||||
{
|
||||
quota = "a_;
|
||||
quota = quota_;
|
||||
}
|
||||
|
||||
/// Enable calculation of minimums and maximums by the result columns.
|
||||
@ -263,6 +263,11 @@ protected:
|
||||
*/
|
||||
bool checkTimeLimit();
|
||||
|
||||
#ifndef NDEBUG
|
||||
bool read_prefix_is_called = false;
|
||||
bool read_suffix_is_called = false;
|
||||
#endif
|
||||
|
||||
private:
|
||||
bool enabled_extremes = false;
|
||||
|
||||
@ -273,8 +278,8 @@ private:
|
||||
|
||||
LocalLimits limits;
|
||||
|
||||
QuotaForIntervals * quota = nullptr; /// If nullptr - the quota is not used.
|
||||
double prev_elapsed = 0;
|
||||
std::shared_ptr<QuotaContext> quota; /// If nullptr - the quota is not used.
|
||||
UInt64 prev_elapsed = 0;
|
||||
|
||||
/// The approximate total number of rows to read. For progress bar.
|
||||
size_t total_rows_approx = 0;
|
||||
@ -315,10 +320,6 @@ private:
|
||||
return;
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
bool read_prefix_is_called = false;
|
||||
bool read_suffix_is_called = false;
|
||||
#endif
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -57,6 +57,20 @@ NativeBlockInputStream::NativeBlockInputStream(ReadBuffer & istr_, UInt64 server
|
||||
}
|
||||
}
|
||||
|
||||
// also resets few vars from IBlockInputStream (I didn't want to propagate resetParser upthere)
|
||||
void NativeBlockInputStream::resetParser()
|
||||
{
|
||||
istr_concrete = nullptr;
|
||||
use_index = false;
|
||||
|
||||
#ifndef NDEBUG
|
||||
read_prefix_is_called = false;
|
||||
read_suffix_is_called = false;
|
||||
#endif
|
||||
|
||||
is_cancelled.store(false);
|
||||
is_killed.store(false);
|
||||
}
|
||||
|
||||
void NativeBlockInputStream::readData(const IDataType & type, IColumn & column, ReadBuffer & istr, size_t rows, double avg_value_size_hint)
|
||||
{
|
||||
|
@ -78,6 +78,9 @@ public:
|
||||
|
||||
Block getHeader() const override;
|
||||
|
||||
void resetParser();
|
||||
|
||||
|
||||
protected:
|
||||
Block readImpl() override;
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
#include <DataStreams/ParallelParsingBlockInputStream.h>
|
||||
#include "ParallelParsingBlockInputStream.h"
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -15,7 +14,7 @@ void ParallelParsingBlockInputStream::segmentatorThreadFunction()
|
||||
auto & unit = processing_units[current_unit_number];
|
||||
|
||||
{
|
||||
std::unique_lock lock(mutex);
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
segmentator_condvar.wait(lock,
|
||||
[&]{ return unit.status == READY_TO_INSERT || finished; });
|
||||
}
|
||||
@ -85,7 +84,7 @@ void ParallelParsingBlockInputStream::parserThreadFunction(size_t current_unit_n
|
||||
// except at the end of file. Also see a matching assert in readImpl().
|
||||
assert(unit.is_last || unit.block_ext.block.size() > 0);
|
||||
|
||||
std::unique_lock lock(mutex);
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
unit.status = READY_TO_READ;
|
||||
reader_condvar.notify_all();
|
||||
}
|
||||
@ -99,7 +98,7 @@ void ParallelParsingBlockInputStream::onBackgroundException()
|
||||
{
|
||||
tryLogCurrentException(__PRETTY_FUNCTION__);
|
||||
|
||||
std::unique_lock lock(mutex);
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
if (!background_exception)
|
||||
{
|
||||
background_exception = std::current_exception();
|
||||
@ -116,7 +115,7 @@ Block ParallelParsingBlockInputStream::readImpl()
|
||||
/**
|
||||
* Check for background exception and rethrow it before we return.
|
||||
*/
|
||||
std::unique_lock lock(mutex);
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
if (background_exception)
|
||||
{
|
||||
lock.unlock();
|
||||
@ -134,7 +133,7 @@ Block ParallelParsingBlockInputStream::readImpl()
|
||||
{
|
||||
// We have read out all the Blocks from the previous Processing Unit,
|
||||
// wait for the current one to become ready.
|
||||
std::unique_lock lock(mutex);
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
reader_condvar.wait(lock, [&](){ return unit.status == READY_TO_READ || finished; });
|
||||
|
||||
if (finished)
|
||||
@ -190,7 +189,7 @@ Block ParallelParsingBlockInputStream::readImpl()
|
||||
else
|
||||
{
|
||||
// Pass the unit back to the segmentator.
|
||||
std::unique_lock lock(mutex);
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
unit.status = READY_TO_INSERT;
|
||||
segmentator_condvar.notify_all();
|
||||
}
|
||||
|
@ -227,7 +227,7 @@ private:
|
||||
finished = true;
|
||||
|
||||
{
|
||||
std::unique_lock lock(mutex);
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
segmentator_condvar.notify_all();
|
||||
reader_condvar.notify_all();
|
||||
}
|
||||
@ -255,4 +255,4 @@ private:
|
||||
void onBackgroundException();
|
||||
};
|
||||
|
||||
};
|
||||
}
|
||||
|
@ -78,7 +78,9 @@ SummingSortedBlockInputStream::SummingSortedBlockInputStream(
|
||||
else
|
||||
{
|
||||
bool is_agg_func = WhichDataType(column.type).isAggregateFunction();
|
||||
if (!column.type->isSummable() && !is_agg_func)
|
||||
|
||||
/// There are special const columns for example after prewere sections.
|
||||
if ((!column.type->isSummable() && !is_agg_func) || isColumnConst(*column.column))
|
||||
{
|
||||
column_numbers_not_to_aggregate.push_back(i);
|
||||
continue;
|
||||
@ -198,6 +200,10 @@ SummingSortedBlockInputStream::SummingSortedBlockInputStream(
|
||||
|
||||
void SummingSortedBlockInputStream::insertCurrentRowIfNeeded(MutableColumns & merged_columns)
|
||||
{
|
||||
/// We have nothing to aggregate. It means that it could be non-zero, because we have columns_not_to_aggregate.
|
||||
if (columns_to_aggregate.empty())
|
||||
current_row_is_zero = false;
|
||||
|
||||
for (auto & desc : columns_to_aggregate)
|
||||
{
|
||||
// Do not insert if the aggregation state hasn't been created
|
||||
|
@ -13,14 +13,14 @@ bool DataTypeInterval::equals(const IDataType & rhs) const
|
||||
|
||||
void registerDataTypeInterval(DataTypeFactory & factory)
|
||||
{
|
||||
factory.registerSimpleDataType("IntervalSecond", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Second)); });
|
||||
factory.registerSimpleDataType("IntervalMinute", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Minute)); });
|
||||
factory.registerSimpleDataType("IntervalHour", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Hour)); });
|
||||
factory.registerSimpleDataType("IntervalDay", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Day)); });
|
||||
factory.registerSimpleDataType("IntervalWeek", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Week)); });
|
||||
factory.registerSimpleDataType("IntervalMonth", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Month)); });
|
||||
factory.registerSimpleDataType("IntervalQuarter", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Quarter)); });
|
||||
factory.registerSimpleDataType("IntervalYear", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Year)); });
|
||||
factory.registerSimpleDataType("IntervalSecond", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Second)); });
|
||||
factory.registerSimpleDataType("IntervalMinute", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Minute)); });
|
||||
factory.registerSimpleDataType("IntervalHour", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Hour)); });
|
||||
factory.registerSimpleDataType("IntervalDay", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Day)); });
|
||||
factory.registerSimpleDataType("IntervalWeek", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Week)); });
|
||||
factory.registerSimpleDataType("IntervalMonth", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Month)); });
|
||||
factory.registerSimpleDataType("IntervalQuarter", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Quarter)); });
|
||||
factory.registerSimpleDataType("IntervalYear", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Year)); });
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <DataTypes/DataTypeNumberBase.h>
|
||||
#include <Common/IntervalKind.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -16,47 +17,17 @@ namespace DB
|
||||
*/
|
||||
class DataTypeInterval final : public DataTypeNumberBase<Int64>
|
||||
{
|
||||
public:
|
||||
enum Kind
|
||||
{
|
||||
Second,
|
||||
Minute,
|
||||
Hour,
|
||||
Day,
|
||||
Week,
|
||||
Month,
|
||||
Quarter,
|
||||
Year
|
||||
};
|
||||
|
||||
private:
|
||||
Kind kind;
|
||||
IntervalKind kind;
|
||||
|
||||
public:
|
||||
static constexpr bool is_parametric = true;
|
||||
|
||||
Kind getKind() const { return kind; }
|
||||
IntervalKind getKind() const { return kind; }
|
||||
|
||||
const char * kindToString() const
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case Second: return "Second";
|
||||
case Minute: return "Minute";
|
||||
case Hour: return "Hour";
|
||||
case Day: return "Day";
|
||||
case Week: return "Week";
|
||||
case Month: return "Month";
|
||||
case Quarter: return "Quarter";
|
||||
case Year: return "Year";
|
||||
}
|
||||
DataTypeInterval(IntervalKind kind_) : kind(kind_) {}
|
||||
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
DataTypeInterval(Kind kind_) : kind(kind_) {}
|
||||
|
||||
std::string doGetName() const override { return std::string("Interval") + kindToString(); }
|
||||
std::string doGetName() const override { return std::string("Interval") + kind.toString(); }
|
||||
const char * getFamilyName() const override { return "Interval"; }
|
||||
TypeIndex getTypeId() const override { return TypeIndex::Interval; }
|
||||
|
||||
|
@ -894,7 +894,7 @@ MutableColumnUniquePtr DataTypeLowCardinality::createColumnUniqueImpl(const IDat
|
||||
if (isColumnedAsNumber(type))
|
||||
{
|
||||
MutableColumnUniquePtr column;
|
||||
TypeListNumbers::forEach(CreateColumnVector(column, *type, creator));
|
||||
TypeListNativeNumbers::forEach(CreateColumnVector(column, *type, creator));
|
||||
|
||||
if (!column)
|
||||
throw Exception("Unexpected numeric type: " + type->getName(), ErrorCodes::LOGICAL_ERROR);
|
||||
|
@ -58,7 +58,7 @@ bool DataTypeDecimal<T>::tryReadText(T & x, ReadBuffer & istr, UInt32 precision,
|
||||
{
|
||||
UInt32 unread_scale = scale;
|
||||
bool done = tryReadDecimalText(istr, x, precision, unread_scale);
|
||||
x *= getScaleMultiplier(unread_scale);
|
||||
x *= T::getScaleMultiplier(unread_scale);
|
||||
return done;
|
||||
}
|
||||
|
||||
@ -70,7 +70,7 @@ void DataTypeDecimal<T>::readText(T & x, ReadBuffer & istr, UInt32 precision, UI
|
||||
readCSVDecimalText(istr, x, precision, unread_scale);
|
||||
else
|
||||
readDecimalText(istr, x, precision, unread_scale);
|
||||
x *= getScaleMultiplier(unread_scale);
|
||||
x *= T::getScaleMultiplier(unread_scale);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -96,7 +96,7 @@ T DataTypeDecimal<T>::parseFromString(const String & str) const
|
||||
T x;
|
||||
UInt32 unread_scale = scale;
|
||||
readDecimalText(buf, x, precision, unread_scale, true);
|
||||
x *= getScaleMultiplier(unread_scale);
|
||||
x *= T::getScaleMultiplier(unread_scale);
|
||||
return x;
|
||||
}
|
||||
|
||||
@ -271,25 +271,6 @@ void registerDataTypeDecimal(DataTypeFactory & factory)
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
Decimal32 DataTypeDecimal<Decimal32>::getScaleMultiplier(UInt32 scale_)
|
||||
{
|
||||
return decimalScaleMultiplier<Int32>(scale_);
|
||||
}
|
||||
|
||||
template <>
|
||||
Decimal64 DataTypeDecimal<Decimal64>::getScaleMultiplier(UInt32 scale_)
|
||||
{
|
||||
return decimalScaleMultiplier<Int64>(scale_);
|
||||
}
|
||||
|
||||
template <>
|
||||
Decimal128 DataTypeDecimal<Decimal128>::getScaleMultiplier(UInt32 scale_)
|
||||
{
|
||||
return decimalScaleMultiplier<Int128>(scale_);
|
||||
}
|
||||
|
||||
|
||||
/// Explicit template instantiations.
|
||||
template class DataTypeDecimal<Decimal32>;
|
||||
template class DataTypeDecimal<Decimal64>;
|
||||
|
@ -130,7 +130,7 @@ public:
|
||||
|
||||
UInt32 getPrecision() const { return precision; }
|
||||
UInt32 getScale() const { return scale; }
|
||||
T getScaleMultiplier() const { return getScaleMultiplier(scale); }
|
||||
T getScaleMultiplier() const { return T::getScaleMultiplier(scale); }
|
||||
|
||||
T wholePart(T x) const
|
||||
{
|
||||
@ -148,7 +148,7 @@ public:
|
||||
return x % getScaleMultiplier();
|
||||
}
|
||||
|
||||
T maxWholeValue() const { return getScaleMultiplier(maxPrecision() - scale) - T(1); }
|
||||
T maxWholeValue() const { return T::getScaleMultiplier(maxPrecision() - scale) - T(1); }
|
||||
|
||||
bool canStoreWhole(T x) const
|
||||
{
|
||||
@ -165,7 +165,7 @@ public:
|
||||
if (getScale() < x.getScale())
|
||||
throw Exception("Decimal result's scale is less then argiment's one", ErrorCodes::ARGUMENT_OUT_OF_BOUND);
|
||||
UInt32 scale_delta = getScale() - x.getScale(); /// scale_delta >= 0
|
||||
return getScaleMultiplier(scale_delta);
|
||||
return T::getScaleMultiplier(scale_delta);
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
@ -181,7 +181,6 @@ public:
|
||||
void readText(T & x, ReadBuffer & istr, bool csv = false) const { readText(x, istr, precision, scale, csv); }
|
||||
static void readText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale, bool csv = false);
|
||||
static bool tryReadText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale);
|
||||
static T getScaleMultiplier(UInt32 scale);
|
||||
|
||||
private:
|
||||
const UInt32 precision;
|
||||
@ -264,12 +263,12 @@ convertDecimals(const typename FromDataType::FieldType & value, UInt32 scale_fro
|
||||
MaxNativeType converted_value;
|
||||
if (scale_to > scale_from)
|
||||
{
|
||||
converted_value = DataTypeDecimal<MaxFieldType>::getScaleMultiplier(scale_to - scale_from);
|
||||
converted_value = MaxFieldType::getScaleMultiplier(scale_to - scale_from);
|
||||
if (common::mulOverflow(static_cast<MaxNativeType>(value), converted_value, converted_value))
|
||||
throw Exception("Decimal convert overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
}
|
||||
else
|
||||
converted_value = value / DataTypeDecimal<MaxFieldType>::getScaleMultiplier(scale_from - scale_to);
|
||||
converted_value = value / MaxFieldType::getScaleMultiplier(scale_from - scale_to);
|
||||
|
||||
if constexpr (sizeof(FromFieldType) > sizeof(ToFieldType))
|
||||
{
|
||||
@ -289,7 +288,7 @@ convertFromDecimal(const typename FromDataType::FieldType & value, UInt32 scale)
|
||||
using ToFieldType = typename ToDataType::FieldType;
|
||||
|
||||
if constexpr (std::is_floating_point_v<ToFieldType>)
|
||||
return static_cast<ToFieldType>(value) / FromDataType::getScaleMultiplier(scale);
|
||||
return static_cast<ToFieldType>(value) / FromFieldType::getScaleMultiplier(scale);
|
||||
else
|
||||
{
|
||||
FromFieldType converted_value = convertDecimals<FromDataType, FromDataType>(value, scale, 0);
|
||||
@ -320,14 +319,15 @@ inline std::enable_if_t<IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDa
|
||||
convertToDecimal(const typename FromDataType::FieldType & value, UInt32 scale)
|
||||
{
|
||||
using FromFieldType = typename FromDataType::FieldType;
|
||||
using ToNativeType = typename ToDataType::FieldType::NativeType;
|
||||
using ToFieldType = typename ToDataType::FieldType;
|
||||
using ToNativeType = typename ToFieldType::NativeType;
|
||||
|
||||
if constexpr (std::is_floating_point_v<FromFieldType>)
|
||||
{
|
||||
if (!std::isfinite(value))
|
||||
throw Exception("Decimal convert overflow. Cannot convert infinity or NaN to decimal", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
|
||||
auto out = value * ToDataType::getScaleMultiplier(scale);
|
||||
auto out = value * ToFieldType::getScaleMultiplier(scale);
|
||||
if constexpr (std::is_same_v<ToNativeType, Int128>)
|
||||
{
|
||||
static constexpr __int128 min_int128 = __int128(0x8000000000000000ll) << 64;
|
||||
|
@ -48,7 +48,7 @@ public:
|
||||
|
||||
double getLoadFactor() const override { return static_cast<double>(element_count.load(std::memory_order_relaxed)) / size; }
|
||||
|
||||
bool isCached() const override { return true; }
|
||||
bool supportUpdates() const override { return false; }
|
||||
|
||||
std::shared_ptr<const IExternalLoadable> clone() const override
|
||||
{
|
||||
|
@ -3,8 +3,8 @@
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Common/ProfilingScopedRWLock.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <common/DateLUT.h>
|
||||
#include <DataStreams/IBlockInputStream.h>
|
||||
#include <ext/chrono_io.h>
|
||||
#include <ext/map.h>
|
||||
#include <ext/range.h>
|
||||
#include <ext/size.h>
|
||||
@ -334,7 +334,7 @@ void CacheDictionary::update(
|
||||
backoff_end_time = now + std::chrono::seconds(calculateDurationWithBackoff(rnd_engine, error_count));
|
||||
|
||||
tryLogException(last_exception, log, "Could not update cache dictionary '" + getName() +
|
||||
"', next update is scheduled at " + DateLUT::instance().timeToString(std::chrono::system_clock::to_time_t(backoff_end_time)));
|
||||
"', next update is scheduled at " + ext::to_string(backoff_end_time));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -71,7 +71,7 @@ public:
|
||||
|
||||
double getLoadFactor() const override { return static_cast<double>(element_count.load(std::memory_order_relaxed)) / size; }
|
||||
|
||||
bool isCached() const override { return true; }
|
||||
bool supportUpdates() const override { return false; }
|
||||
|
||||
std::shared_ptr<const IExternalLoadable> clone() const override
|
||||
{
|
||||
|
@ -46,8 +46,6 @@ public:
|
||||
|
||||
double getLoadFactor() const override { return static_cast<double>(element_count) / bucket_count; }
|
||||
|
||||
bool isCached() const override { return false; }
|
||||
|
||||
std::shared_ptr<const IExternalLoadable> clone() const override
|
||||
{
|
||||
return std::make_shared<ComplexKeyHashedDictionary>(name, dict_struct, source_ptr->clone(), dict_lifetime, require_nonempty, saved_block);
|
||||
|
@ -43,8 +43,6 @@ public:
|
||||
|
||||
double getLoadFactor() const override { return static_cast<double>(element_count) / bucket_count; }
|
||||
|
||||
bool isCached() const override { return false; }
|
||||
|
||||
std::shared_ptr<const IExternalLoadable> clone() const override
|
||||
{
|
||||
return std::make_shared<FlatDictionary>(name, dict_struct, source_ptr->clone(), dict_lifetime, require_nonempty, saved_block);
|
||||
|
@ -5,6 +5,8 @@
|
||||
#include <IO/ConnectionTimeouts.h>
|
||||
#include <IO/ReadWriteBufferFromHTTP.h>
|
||||
#include <IO/WriteBufferFromOStream.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Poco/Net/HTTPRequest.h>
|
||||
#include <common/logger_useful.h>
|
||||
@ -82,12 +84,9 @@ void HTTPDictionarySource::getUpdateFieldAndDate(Poco::URI & uri)
|
||||
auto tmp_time = update_time;
|
||||
update_time = std::chrono::system_clock::now();
|
||||
time_t hr_time = std::chrono::system_clock::to_time_t(tmp_time) - 1;
|
||||
char buffer[80];
|
||||
struct tm * timeinfo;
|
||||
timeinfo = localtime(&hr_time);
|
||||
strftime(buffer, 80, "%Y-%m-%d %H:%M:%S", timeinfo);
|
||||
std::string str_time(buffer);
|
||||
uri.addQueryParameter(update_field, str_time);
|
||||
WriteBufferFromOwnString out;
|
||||
writeDateTimeText(hr_time, out);
|
||||
uri.addQueryParameter(update_field, out.str());
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -48,8 +48,6 @@ public:
|
||||
|
||||
double getLoadFactor() const override { return static_cast<double>(element_count) / bucket_count; }
|
||||
|
||||
bool isCached() const override { return false; }
|
||||
|
||||
std::shared_ptr<const IExternalLoadable> clone() const override
|
||||
{
|
||||
return std::make_shared<HashedDictionary>(name, dict_struct, source_ptr->clone(), dict_lifetime, require_nonempty, sparse, saved_block);
|
||||
|
@ -37,8 +37,6 @@ struct IDictionaryBase : public IExternalLoadable
|
||||
|
||||
virtual double getLoadFactor() const = 0;
|
||||
|
||||
virtual bool isCached() const = 0;
|
||||
|
||||
virtual const IDictionarySource * getSource() const = 0;
|
||||
|
||||
virtual const DictionaryStructure & getStructure() const = 0;
|
||||
@ -47,7 +45,7 @@ struct IDictionaryBase : public IExternalLoadable
|
||||
|
||||
virtual BlockInputStreamPtr getBlockInputStream(const Names & column_names, size_t max_block_size) const = 0;
|
||||
|
||||
bool supportUpdates() const override { return !isCached(); }
|
||||
bool supportUpdates() const override { return true; }
|
||||
|
||||
bool isModified() const override
|
||||
{
|
||||
|
@ -38,8 +38,6 @@ public:
|
||||
|
||||
double getLoadFactor() const override { return static_cast<double>(element_count) / bucket_count; }
|
||||
|
||||
bool isCached() const override { return false; }
|
||||
|
||||
std::shared_ptr<const IExternalLoadable> clone() const override
|
||||
{
|
||||
return std::make_shared<RangeHashedDictionary>(dictionary_name, dict_struct, source_ptr->clone(), dict_lifetime, require_nonempty);
|
||||
|
@ -47,8 +47,6 @@ public:
|
||||
|
||||
double getLoadFactor() const override { return static_cast<double>(element_count) / bucket_count; }
|
||||
|
||||
bool isCached() const override { return false; }
|
||||
|
||||
std::shared_ptr<const IExternalLoadable> clone() const override
|
||||
{
|
||||
return std::make_shared<TrieDictionary>(name, dict_struct, source_ptr->clone(), dict_lifetime, require_nonempty);
|
||||
|
@ -176,7 +176,7 @@ void buildSingleAttribute(
|
||||
AutoPtr<Element> null_value_element(doc->createElement("null_value"));
|
||||
String null_value_str;
|
||||
if (dict_attr->default_value)
|
||||
null_value_str = queryToString(dict_attr->default_value);
|
||||
null_value_str = getUnescapedFieldString(dict_attr->default_value->as<ASTLiteral>()->value);
|
||||
AutoPtr<Text> null_value(doc->createTextNode(null_value_str));
|
||||
null_value_element->appendChild(null_value);
|
||||
attribute_element->appendChild(null_value_element);
|
||||
@ -184,7 +184,19 @@ void buildSingleAttribute(
|
||||
if (dict_attr->expression != nullptr)
|
||||
{
|
||||
AutoPtr<Element> expression_element(doc->createElement("expression"));
|
||||
AutoPtr<Text> expression(doc->createTextNode(queryToString(dict_attr->expression)));
|
||||
|
||||
/// EXPRESSION PROPERTY should be expression or string
|
||||
String expression_str;
|
||||
if (const auto * literal = dict_attr->expression->as<ASTLiteral>();
|
||||
literal && literal->value.getType() == Field::Types::String)
|
||||
{
|
||||
expression_str = getUnescapedFieldString(literal->value);
|
||||
}
|
||||
else
|
||||
expression_str = queryToString(dict_attr->expression);
|
||||
|
||||
|
||||
AutoPtr<Text> expression(doc->createTextNode(expression_str));
|
||||
expression_element->appendChild(expression);
|
||||
attribute_element->appendChild(expression_element);
|
||||
}
|
||||
|
@ -281,6 +281,8 @@ void registerInputFormatProcessorTSKV(FormatFactory & factory);
|
||||
void registerOutputFormatProcessorTSKV(FormatFactory & factory);
|
||||
void registerInputFormatProcessorJSONEachRow(FormatFactory & factory);
|
||||
void registerOutputFormatProcessorJSONEachRow(FormatFactory & factory);
|
||||
void registerInputFormatProcessorJSONCompactEachRow(FormatFactory & factory);
|
||||
void registerOutputFormatProcessorJSONCompactEachRow(FormatFactory & factory);
|
||||
void registerInputFormatProcessorParquet(FormatFactory & factory);
|
||||
void registerInputFormatProcessorORC(FormatFactory & factory);
|
||||
void registerOutputFormatProcessorParquet(FormatFactory & factory);
|
||||
@ -336,6 +338,8 @@ FormatFactory::FormatFactory()
|
||||
registerOutputFormatProcessorTSKV(*this);
|
||||
registerInputFormatProcessorJSONEachRow(*this);
|
||||
registerOutputFormatProcessorJSONEachRow(*this);
|
||||
registerInputFormatProcessorJSONCompactEachRow(*this);
|
||||
registerOutputFormatProcessorJSONCompactEachRow(*this);
|
||||
registerInputFormatProcessorProtobuf(*this);
|
||||
registerOutputFormatProcessorProtobuf(*this);
|
||||
registerInputFormatProcessorCapnProto(*this);
|
||||
|
@ -508,7 +508,7 @@ class FunctionBinaryArithmetic : public IFunction
|
||||
}
|
||||
|
||||
std::stringstream function_name;
|
||||
function_name << (function_is_plus ? "add" : "subtract") << interval_data_type->kindToString() << 's';
|
||||
function_name << (function_is_plus ? "add" : "subtract") << interval_data_type->getKind().toString() << 's';
|
||||
|
||||
return FunctionFactory::instance().get(function_name.str(), context);
|
||||
}
|
||||
|
@ -735,7 +735,7 @@ struct NameToDecimal128 { static constexpr auto name = "toDecimal128"; };
|
||||
struct NameToInterval ## INTERVAL_KIND \
|
||||
{ \
|
||||
static constexpr auto name = "toInterval" #INTERVAL_KIND; \
|
||||
static constexpr int kind = DataTypeInterval::INTERVAL_KIND; \
|
||||
static constexpr auto kind = IntervalKind::INTERVAL_KIND; \
|
||||
};
|
||||
|
||||
DEFINE_NAME_TO_INTERVAL(Second)
|
||||
@ -786,7 +786,7 @@ public:
|
||||
|
||||
if constexpr (std::is_same_v<ToDataType, DataTypeInterval>)
|
||||
{
|
||||
return std::make_shared<DataTypeInterval>(DataTypeInterval::Kind(Name::kind));
|
||||
return std::make_shared<DataTypeInterval>(Name::kind);
|
||||
}
|
||||
else if constexpr (to_decimal)
|
||||
{
|
||||
|
@ -707,6 +707,20 @@ private:
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
|
||||
template <bool first>
|
||||
void executeGeneric(const IColumn * column, typename ColumnVector<ToType>::Container & vec_to)
|
||||
{
|
||||
for (size_t i = 0, size = column->size(); i < size; ++i)
|
||||
{
|
||||
StringRef bytes = column->getDataAt(i);
|
||||
const ToType h = Impl::apply(bytes.data, bytes.size);
|
||||
if (first)
|
||||
vec_to[i] = h;
|
||||
else
|
||||
vec_to[i] = Impl::combineHashes(vec_to[i], h);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool first>
|
||||
void executeString(const IColumn * column, typename ColumnVector<ToType>::Container & vec_to)
|
||||
{
|
||||
@ -843,8 +857,7 @@ private:
|
||||
else if (which.isFixedString()) executeString<first>(icolumn, vec_to);
|
||||
else if (which.isArray()) executeArray<first>(from_type, icolumn, vec_to);
|
||||
else
|
||||
throw Exception("Unexpected type " + from_type->getName() + " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
executeGeneric<first>(icolumn, vec_to);
|
||||
}
|
||||
|
||||
void executeForArgument(const IDataType * type, const IColumn * column, typename ColumnVector<ToType>::Container & vec_to, bool & is_first)
|
||||
|
@ -20,6 +20,7 @@ void registerFunctionsJSON(FunctionFactory & factory)
|
||||
factory.registerFunction<FunctionJSON<NameJSONExtract, JSONExtractImpl>>();
|
||||
factory.registerFunction<FunctionJSON<NameJSONExtractKeysAndValues, JSONExtractKeysAndValuesImpl>>();
|
||||
factory.registerFunction<FunctionJSON<NameJSONExtractRaw, JSONExtractRawImpl>>();
|
||||
factory.registerFunction<FunctionJSON<NameJSONExtractArrayRaw, JSONExtractArrayRawImpl>>();
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -291,6 +291,7 @@ struct NameJSONExtractString { static constexpr auto name{"JSONExtractString"};
|
||||
struct NameJSONExtract { static constexpr auto name{"JSONExtract"}; };
|
||||
struct NameJSONExtractKeysAndValues { static constexpr auto name{"JSONExtractKeysAndValues"}; };
|
||||
struct NameJSONExtractRaw { static constexpr auto name{"JSONExtractRaw"}; };
|
||||
struct NameJSONExtractArrayRaw { static constexpr auto name{"JSONExtractArrayRaw"}; };
|
||||
|
||||
|
||||
template <typename JSONParser>
|
||||
@ -1088,4 +1089,39 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
template <typename JSONParser>
|
||||
class JSONExtractArrayRawImpl
|
||||
{
|
||||
public:
|
||||
static DataTypePtr getType(const char *, const ColumnsWithTypeAndName &)
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeString>());
|
||||
}
|
||||
|
||||
using Iterator = typename JSONParser::Iterator;
|
||||
static bool addValueToColumn(IColumn & dest, const Iterator & it)
|
||||
{
|
||||
if (!JSONParser::isArray(it))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
ColumnArray & col_res = assert_cast<ColumnArray &>(dest);
|
||||
Iterator array_it = it;
|
||||
size_t size = 0;
|
||||
if (JSONParser::firstArrayElement(array_it))
|
||||
{
|
||||
do
|
||||
{
|
||||
JSONExtractRawImpl<JSONParser>::addValueToColumn(col_res.getData(), array_it);
|
||||
++size;
|
||||
} while (JSONParser::nextArrayElement(array_it));
|
||||
}
|
||||
|
||||
col_res.getOffsets().push_back(col_res.getOffsets().back() + size);
|
||||
return true;
|
||||
}
|
||||
|
||||
static constexpr size_t num_extra_arguments = 0;
|
||||
static void prepare(const char *, const Block &, const ColumnNumbers &, size_t) {}
|
||||
};
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <Core/Types.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include "Sources.h"
|
||||
#include "Sinks.h"
|
||||
@ -79,8 +80,16 @@ inline ALWAYS_INLINE void writeSlice(const NumericArraySlice<T> & slice, Generic
|
||||
{
|
||||
for (size_t i = 0; i < slice.size; ++i)
|
||||
{
|
||||
Field field = T(slice.data[i]);
|
||||
sink.elements.insert(field);
|
||||
if constexpr (IsDecimalNumber<T>)
|
||||
{
|
||||
DecimalField field(T(slice.data[i]), 0); /// TODO: Decimal scale
|
||||
sink.elements.insert(field);
|
||||
}
|
||||
else
|
||||
{
|
||||
Field field = T(slice.data[i]);
|
||||
sink.elements.insert(field);
|
||||
}
|
||||
}
|
||||
sink.current_offset += slice.size;
|
||||
}
|
||||
@ -422,9 +431,18 @@ bool sliceHasImpl(const FirstSliceType & first, const SecondSliceType & second,
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
bool sliceEqualElements(const NumericArraySlice<T> & first, const NumericArraySlice<U> & second, size_t first_ind, size_t second_ind)
|
||||
bool sliceEqualElements(const NumericArraySlice<T> & first [[maybe_unused]],
|
||||
const NumericArraySlice<U> & second [[maybe_unused]],
|
||||
size_t first_ind [[maybe_unused]],
|
||||
size_t second_ind [[maybe_unused]])
|
||||
{
|
||||
return accurate::equalsOp(first.data[first_ind], second.data[second_ind]);
|
||||
/// TODO: Decimal scale
|
||||
if constexpr (IsDecimalNumber<T> && IsDecimalNumber<U>)
|
||||
return accurate::equalsOp(typename T::NativeType(first.data[first_ind]), typename U::NativeType(second.data[second_ind]));
|
||||
else if constexpr (IsDecimalNumber<T> || IsDecimalNumber<U>)
|
||||
return false;
|
||||
else
|
||||
return accurate::equalsOp(first.data[first_ind], second.data[second_ind]);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user