Merge branch 'master' into database_atomic

This commit is contained in:
Alexander Tokmakov 2019-12-11 23:05:53 +03:00
commit 4d23c5e4d4
695 changed files with 14914 additions and 4228 deletions

View File

@ -198,11 +198,11 @@ if(WITH_COVERAGE AND COMPILER_GCC)
endif()
set (CMAKE_BUILD_COLOR_MAKEFILE ON)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${COMPILER_FLAGS} ${PLATFORM_EXTRA_CXX_FLAG} -fno-omit-frame-pointer ${COMMON_WARNING_FLAGS} ${CXX_WARNING_FLAGS}")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${COMPILER_FLAGS} ${PLATFORM_EXTRA_CXX_FLAG} ${COMMON_WARNING_FLAGS} ${CXX_WARNING_FLAGS}")
set (CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -O3 ${CMAKE_CXX_FLAGS_ADD}")
set (CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb3 -fno-inline ${CMAKE_CXX_FLAGS_ADD}")
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${COMPILER_FLAGS} -fno-omit-frame-pointer ${COMMON_WARNING_FLAGS} ${CMAKE_C_FLAGS_ADD}")
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${COMPILER_FLAGS} ${COMMON_WARNING_FLAGS} ${CMAKE_C_FLAGS_ADD}")
set (CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO} -O3 ${CMAKE_C_FLAGS_ADD}")
set (CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O0 -g3 -ggdb3 -fno-inline ${CMAKE_C_FLAGS_ADD}")
@ -382,16 +382,12 @@ add_subdirectory (contrib EXCLUDE_FROM_ALL)
macro (add_executable target)
# invoke built-in add_executable
_add_executable (${ARGV})
# explicitly acquire and interpose malloc symbols by clickhouse_malloc
_add_executable (${ARGV} $<TARGET_OBJECTS:clickhouse_malloc>)
get_target_property (type ${target} TYPE)
if (${type} STREQUAL EXECUTABLE)
file (RELATIVE_PATH dir ${CMAKE_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
if (${dir} MATCHES "^dbms")
# Only interpose operator::new/delete for dbms executables (MemoryTracker stuff)
target_link_libraries (${target} PRIVATE clickhouse_new_delete ${MALLOC_LIBRARIES})
else ()
target_link_libraries (${target} PRIVATE ${MALLOC_LIBRARIES})
endif ()
# operator::new/delete for executables (MemoryTracker stuff)
target_link_libraries (${target} PRIVATE clickhouse_new_delete ${MALLOC_LIBRARIES})
endif()
endmacro()

View File

@ -14,5 +14,4 @@ ClickHouse is an open-source column-oriented database management system that all
## Upcoming Events
* [ClickHouse Meetup in San Francisco](https://www.eventbrite.com/e/clickhouse-december-meetup-registration-78642047481) on December 3.
* [ClickHouse Meetup in Moscow](https://yandex.ru/promo/clickhouse/moscow-december-2019) on December 11.

View File

@ -1,5 +1,8 @@
find_library (TERMCAP_LIBRARY termcap)
find_library (TERMCAP_LIBRARY tinfo)
if (NOT TERMCAP_LIBRARY)
find_library (TERMCAP_LIBRARY tinfo)
find_library (TERMCAP_LIBRARY ncurses)
endif()
if (NOT TERMCAP_LIBRARY)
find_library (TERMCAP_LIBRARY termcap)
endif()
message (STATUS "Using termcap: ${TERMCAP_LIBRARY}")

View File

@ -20,16 +20,38 @@ else ()
message (WARNING "You are using an unsupported compiler. Compilation has only been tested with Clang 6+ and GCC 7+.")
endif ()
STRING(REGEX MATCHALL "[0-9]+" COMPILER_VERSION_LIST ${CMAKE_CXX_COMPILER_VERSION})
LIST(GET COMPILER_VERSION_LIST 0 COMPILER_VERSION_MAJOR)
option (LINKER_NAME "Linker name or full path")
if (COMPILER_GCC)
find_program (LLD_PATH NAMES "ld.lld")
find_program (GOLD_PATH NAMES "ld.gold")
else ()
find_program (LLD_PATH NAMES "ld.lld-${COMPILER_VERSION_MAJOR}" "lld-${COMPILER_VERSION_MAJOR}" "ld.lld" "lld")
find_program (GOLD_PATH NAMES "ld.gold" "gold")
endif ()
find_program (LLD_PATH NAMES "ld.lld" "lld")
find_program (GOLD_PATH NAMES "ld.gold" "gold")
# We prefer LLD linker over Gold or BFD.
if (NOT LINKER_NAME)
if (LLD_PATH)
set (LINKER_NAME "lld")
elseif (GOLD_PATH)
set (LINKER_NAME "gold")
if (COMPILER_GCC)
# GCC driver requires one of supported linker names like "lld".
set (LINKER_NAME "lld")
else ()
# Clang driver simply allows full linker path.
set (LINKER_NAME ${LLD_PATH})
endif ()
endif ()
endif ()
if (NOT LINKER_NAME)
if (GOLD_PATH)
if (COMPILER_GCC)
set (LINKER_NAME "gold")
else ()
set (LINKER_NAME ${GOLD_PATH})
endif ()
endif ()
endif ()

View File

@ -52,6 +52,7 @@ if (USE_INTERNAL_BTRIE_LIBRARY)
endif ()
if (USE_INTERNAL_ZLIB_LIBRARY)
unset (BUILD_SHARED_LIBS CACHE)
set (ZLIB_ENABLE_TESTS 0 CACHE INTERNAL "")
set (SKIP_INSTALL_ALL 1 CACHE INTERNAL "")
set (ZLIB_COMPAT 1 CACHE INTERNAL "") # also enables WITH_GZFILEOP

2
contrib/poco vendored

@ -1 +1 @@
Subproject commit 2b273bfe9db89429b2040c024484dee0197e48c7
Subproject commit d478f62bd93c9cd14eb343756ef73a4ae622ddf5

2
contrib/zlib-ng vendored

@ -1 +1 @@
Subproject commit cff0f500d9399d7cd3b9461a693d211e4b86fcc9
Subproject commit bba56a73be249514acfbc7d49aa2a68994dad8ab

View File

@ -100,7 +100,7 @@ set(dbms_sources)
add_headers_and_sources(clickhouse_common_io src/Common)
add_headers_and_sources(clickhouse_common_io src/Common/HashTable)
add_headers_and_sources(clickhouse_common_io src/IO)
list (REMOVE_ITEM clickhouse_common_io_sources src/Common/new_delete.cpp)
list (REMOVE_ITEM clickhouse_common_io_sources src/Common/malloc.cpp src/Common/new_delete.cpp)
if(USE_RDKAFKA)
add_headers_and_sources(dbms src/Storages/Kafka)
@ -140,6 +140,9 @@ endif ()
add_library(clickhouse_common_io ${clickhouse_common_io_headers} ${clickhouse_common_io_sources})
add_library (clickhouse_malloc OBJECT src/Common/malloc.cpp)
set_source_files_properties(src/Common/malloc.cpp PROPERTIES COMPILE_FLAGS "-fno-builtin")
add_library (clickhouse_new_delete STATIC src/Common/new_delete.cpp)
target_link_libraries (clickhouse_new_delete PRIVATE clickhouse_common_io)
@ -376,6 +379,10 @@ if (USE_POCO_MONGODB)
dbms_target_link_libraries (PRIVATE ${Poco_MongoDB_LIBRARY})
endif()
if (USE_POCO_REDIS)
dbms_target_link_libraries (PRIVATE ${Poco_Redis_LIBRARY})
endif()
if (USE_POCO_NETSSL)
target_link_libraries (clickhouse_common_io PRIVATE ${Poco_NetSSL_LIBRARY} ${Poco_Crypto_LIBRARY})
dbms_target_link_libraries (PRIVATE ${Poco_NetSSL_LIBRARY} ${Poco_Crypto_LIBRARY})
@ -428,6 +435,8 @@ if (USE_JEMALLOC)
if(NOT MAKE_STATIC_LIBRARIES AND ${JEMALLOC_LIBRARIES} MATCHES "${CMAKE_STATIC_LIBRARY_SUFFIX}$")
# mallctl in dbms/src/Interpreters/AsynchronousMetrics.cpp
# Actually we link JEMALLOC to almost all libraries.
# This is just hotfix for some uninvestigated problem.
target_link_libraries(clickhouse_interpreters PRIVATE ${JEMALLOC_LIBRARIES})
endif()
endif ()

View File

@ -1,11 +1,11 @@
# This strings autochanged from release_lib.sh:
set(VERSION_REVISION 54429)
set(VERSION_REVISION 54430)
set(VERSION_MAJOR 19)
set(VERSION_MINOR 18)
set(VERSION_MINOR 19)
set(VERSION_PATCH 1)
set(VERSION_GITHASH 4e68211879480b637683ae66dbcc89a2714682af)
set(VERSION_DESCRIBE v19.18.1.1-prestable)
set(VERSION_STRING 19.18.1.1)
set(VERSION_GITHASH 8bd9709d1dec3366e35d2efeab213435857f67a9)
set(VERSION_DESCRIBE v19.19.1.1-prestable)
set(VERSION_STRING 19.19.1.1)
# end of autochange
set(VERSION_EXTRA "" CACHE STRING "")

View File

@ -34,7 +34,6 @@
#include <IO/WriteBufferFromTemporaryFile.h>
#include <DataStreams/IBlockInputStream.h>
#include <Interpreters/executeQuery.h>
#include <Interpreters/Quota.h>
#include <Common/typeid_cast.h>
#include <Poco/Net/HTTPStream.h>

View File

@ -15,6 +15,7 @@
#include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferFromPocoSocket.h>
#include <Storages/IStorage.h>
#include <boost/algorithm/string/replace.hpp>
#if USE_POCO_NETSSL
#include <Poco/Net/SecureStreamSocket.h>
@ -220,7 +221,8 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl
{
// For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used.
auto user = connection_context.getUser(user_name);
if (user->authentication.getType() != DB::Authentication::DOUBLE_SHA1_PASSWORD)
const DB::Authentication::Type user_auth_type = user->authentication.getType();
if (user_auth_type != DB::Authentication::DOUBLE_SHA1_PASSWORD && user_auth_type != DB::Authentication::PLAINTEXT_PASSWORD && user_auth_type != DB::Authentication::NO_PASSWORD)
{
authPluginSSL();
}
@ -267,29 +269,49 @@ void MySQLHandler::comPing()
packet_sender->sendPacket(OK_Packet(0x0, client_capability_flags, 0, 0, 0), true);
}
static bool isFederatedServerSetupCommand(const String & query);
void MySQLHandler::comQuery(ReadBuffer & payload)
{
bool with_output = false;
std::function<void(const String &)> set_content_type = [&with_output](const String &) -> void {
with_output = true;
};
String query = String(payload.position(), payload.buffer().end());
const String query("select ''");
ReadBufferFromString empty_select(query);
bool should_replace = false;
// Translate query from MySQL to ClickHouse.
// This is a temporary workaround until ClickHouse supports the syntax "@@var_name".
if (std::string(payload.position(), payload.buffer().end()) == "select @@version_comment limit 1") // MariaDB client starts session with that query
// This is a workaround in order to support adding ClickHouse to MySQL using federated server.
// As Clickhouse doesn't support these statements, we just send OK packet in response.
if (isFederatedServerSetupCommand(query))
{
should_replace = true;
}
Context query_context = connection_context;
executeQuery(should_replace ? empty_select : payload, *out, true, query_context, set_content_type, nullptr);
if (!with_output)
packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true);
}
else
{
bool with_output = false;
std::function<void(const String &)> set_content_type = [&with_output](const String &) -> void {
with_output = true;
};
String replacement_query = "select ''";
bool should_replace = false;
// Translate query from MySQL to ClickHouse.
// This is a temporary workaround until ClickHouse supports the syntax "@@var_name".
if (query == "select @@version_comment limit 1") // MariaDB client starts session with that query
{
should_replace = true;
}
// This is a workaround in order to support adding ClickHouse to MySQL using federated server.
if (0 == strncasecmp("SHOW TABLE STATUS LIKE", query.c_str(), 22))
{
should_replace = true;
replacement_query = boost::replace_all_copy(query, "SHOW TABLE STATUS LIKE ", show_table_status_replacement_query);
}
ReadBufferFromString replacement(replacement_query);
Context query_context = connection_context;
executeQuery(should_replace ? replacement : payload, *out, true, query_context, set_content_type, nullptr);
if (!with_output)
packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true);
}
}
void MySQLHandler::authPluginSSL()
@ -335,4 +357,33 @@ void MySQLHandlerSSL::finishHandshakeSSL(size_t packet_size, char * buf, size_t
#endif
static bool isFederatedServerSetupCommand(const String & query)
{
return 0 == strncasecmp("SET NAMES", query.c_str(), 9) || 0 == strncasecmp("SET character_set_results", query.c_str(), 25)
|| 0 == strncasecmp("SET FOREIGN_KEY_CHECKS", query.c_str(), 22) || 0 == strncasecmp("SET AUTOCOMMIT", query.c_str(), 14)
|| 0 == strncasecmp("SET SESSION TRANSACTION ISOLATION LEVEL", query.c_str(), 39);
}
const String MySQLHandler::show_table_status_replacement_query("SELECT"
" name AS Name,"
" engine AS Engine,"
" '10' AS Version,"
" 'Dynamic' AS Row_format,"
" 0 AS Rows,"
" 0 AS Avg_row_length,"
" 0 AS Data_length,"
" 0 AS Max_data_length,"
" 0 AS Index_length,"
" 0 AS Data_free,"
" 'NULL' AS Auto_increment,"
" metadata_modification_time AS Create_time,"
" metadata_modification_time AS Update_time,"
" metadata_modification_time AS Check_time,"
" 'utf8_bin' AS Collation,"
" 'NULL' AS Checksum,"
" '' AS Create_options,"
" '' AS Comment"
" FROM system.tables"
" WHERE name LIKE ");
}

View File

@ -11,7 +11,6 @@
namespace DB
{
/// Handler for MySQL wire protocol connections. Allows to connect to ClickHouse using MySQL client.
class MySQLHandler : public Poco::Net::TCPServerConnection
{
@ -59,6 +58,9 @@ protected:
std::shared_ptr<WriteBuffer> out;
bool secure_connection = false;
private:
static const String show_table_status_replacement_query;
};
#if USE_SSL && USE_POCO_NETSSL

View File

@ -243,6 +243,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
}
#endif
global_context->setRemoteHostFilter(config());
std::string path = getCanonicalPath(config().getString("path", DBMS_DEFAULT_PATH));
std::string default_database = config().getString("default_database", "default");

View File

@ -19,7 +19,6 @@
#include <DataStreams/NativeBlockInputStream.h>
#include <DataStreams/NativeBlockOutputStream.h>
#include <Interpreters/executeQuery.h>
#include <Interpreters/Quota.h>
#include <Interpreters/TablesStatus.h>
#include <Interpreters/InternalTextLogsQueue.h>
#include <Storages/StorageMemory.h>
@ -201,6 +200,8 @@ void TCPHandler::runImpl()
/// So, the stream has been marked as cancelled and we can't read from it anymore.
state.block_in.reset();
state.maybe_compressed_in.reset(); /// For more accurate accounting by MemoryTracker.
state.temporary_tables_read = true;
});
/// Send structure of columns to client for function input()
@ -340,6 +341,18 @@ void TCPHandler::runImpl()
LOG_WARNING(log, "Client has gone away.");
}
try
{
if (exception && !state.temporary_tables_read)
query_context->initializeExternalTablesIfSet();
}
catch (...)
{
network_error = true;
LOG_WARNING(log, "Can't read external tables after query failure.");
}
try
{
query_scope.reset();

View File

@ -63,6 +63,8 @@ struct QueryState
bool sent_all_data = false;
/// Request requires data from the client (INSERT, but not INSERT SELECT).
bool need_receive_data_for_insert = false;
/// Temporary tables read
bool temporary_tables_read = false;
/// Request requires data from client for function input()
bool need_receive_data_for_input = false;

View File

@ -3,6 +3,25 @@
NOTE: User and query level settings are set up in "users.xml" file.
-->
<yandex>
<!-- The list of hosts allowed to use in URL-related storage engines and table functions.
If this section is not present in configuration, all hosts are allowed.
-->
<remote_url_allow_hosts>
<!-- Host should be specified exactly as in URL. The name is checked before DNS resolution.
Example: "yandex.ru", "yandex.ru." and "www.yandex.ru" are different hosts.
If port is explicitly specified in URL, the host:port is checked as a whole.
If host specified here without port, any port with this host allowed.
"yandex.ru" -> "yandex.ru:443", "yandex.ru:80" etc. is allowed, but "yandex.ru:80" -> only "yandex.ru:80" is allowed.
If the host is specified as IP address, it is checked as specified in URL. Example: "[2a02:6b8:a::a]".
If there are redirects and support for redirects is enabled, every redirect (the Location field) is checked.
-->
<!-- Regular expression can be specified. RE2 engine is used for regexps.
Regexps are not aligned: don't forget to add ^ and $. Also don't forget to escape dot (.) metacharacter
(forgetting to do so is a common source of error).
-->
</remote_url_allow_hosts>
<logger>
<!-- Possible levels: https://github.com/pocoproject/poco/blob/develop/Foundation/include/Poco/Logger.h#L105 -->
<level>trace</level>
@ -15,7 +34,6 @@
<!--display_name>production</display_name--> <!-- It is the name that will be shown in the client -->
<http_port>8123</http_port>
<tcp_port>9000</tcp_port>
<!-- For HTTPS and SSL over native protocol. -->
<!--
<https_port>8443</https_port>

View File

@ -0,0 +1,52 @@
#include <Access/AccessControlManager.h>
#include <Access/MultipleAccessStorage.h>
#include <Access/MemoryAccessStorage.h>
#include <Access/UsersConfigAccessStorage.h>
#include <Access/QuotaContextFactory.h>
namespace DB
{
namespace
{
std::vector<std::unique_ptr<IAccessStorage>> createStorages()
{
std::vector<std::unique_ptr<IAccessStorage>> list;
list.emplace_back(std::make_unique<MemoryAccessStorage>());
list.emplace_back(std::make_unique<UsersConfigAccessStorage>());
return list;
}
}
AccessControlManager::AccessControlManager()
: MultipleAccessStorage(createStorages()),
quota_context_factory(std::make_unique<QuotaContextFactory>(*this))
{
}
AccessControlManager::~AccessControlManager()
{
}
void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguration & users_config)
{
auto & users_config_access_storage = dynamic_cast<UsersConfigAccessStorage &>(getStorageByIndex(1));
users_config_access_storage.loadFromConfig(users_config);
}
std::shared_ptr<QuotaContext> AccessControlManager::createQuotaContext(
const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key)
{
return quota_context_factory->createContext(user_name, address, custom_quota_key);
}
std::vector<QuotaUsageInfo> AccessControlManager::getQuotaUsageInfo() const
{
return quota_context_factory->getUsageInfo();
}
}

View File

@ -0,0 +1,45 @@
#pragma once
#include <Access/MultipleAccessStorage.h>
#include <Poco/AutoPtr.h>
#include <memory>
namespace Poco
{
namespace Net
{
class IPAddress;
}
namespace Util
{
class AbstractConfiguration;
}
}
namespace DB
{
class QuotaContext;
class QuotaContextFactory;
struct QuotaUsageInfo;
/// Manages access control entities.
class AccessControlManager : public MultipleAccessStorage
{
public:
AccessControlManager();
~AccessControlManager();
void loadFromConfig(const Poco::Util::AbstractConfiguration & users_config);
std::shared_ptr<QuotaContext>
createQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key);
std::vector<QuotaUsageInfo> getQuotaUsageInfo() const;
private:
std::unique_ptr<QuotaContextFactory> quota_context_factory;
};
}

View File

@ -160,6 +160,35 @@ void Authentication::setPasswordHashBinary(const Digest & hash)
}
Digest Authentication::getPasswordDoubleSHA1() const
{
switch (type)
{
case NO_PASSWORD:
{
Poco::SHA1Engine engine;
return engine.digest();
}
case PLAINTEXT_PASSWORD:
{
Poco::SHA1Engine engine;
engine.update(getPassword());
const Digest & first_sha1 = engine.digest();
engine.update(first_sha1.data(), first_sha1.size());
return engine.digest();
}
case SHA256_PASSWORD:
throw Exception("Cannot get password double SHA1 for user with 'SHA256_PASSWORD' authentication.", ErrorCodes::BAD_ARGUMENTS);
case DOUBLE_SHA1_PASSWORD:
return password_hash;
}
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
}
bool Authentication::isCorrectPassword(const String & password_) const
{
switch (type)
@ -168,7 +197,14 @@ bool Authentication::isCorrectPassword(const String & password_) const
return true;
case PLAINTEXT_PASSWORD:
return password_ == StringRef{reinterpret_cast<const char *>(password_hash.data()), password_hash.size()};
{
if (password_ == StringRef{reinterpret_cast<const char *>(password_hash.data()), password_hash.size()})
return true;
// For compatibility with MySQL clients which support only native authentication plugin, SHA1 can be passed instead of password.
auto password_sha1 = encodeSHA1(password_hash);
return password_ == StringRef{reinterpret_cast<const char *>(password_sha1.data()), password_sha1.size()};
}
case SHA256_PASSWORD:
return encodeSHA256(password_) == password_hash;

View File

@ -49,6 +49,10 @@ public:
void setPasswordHashBinary(const Digest & hash);
const Digest & getPasswordHashBinary() const { return password_hash; }
/// Returns SHA1(SHA1(password)) used by MySQL compatibility server for authentication.
/// Allowed to use for Type::NO_PASSWORD, Type::PLAINTEXT_PASSWORD, Type::DOUBLE_SHA1_PASSWORD.
Digest getPasswordDoubleSHA1() const;
/// Checks if the provided password is correct. Returns false if not.
bool isCorrectPassword(const String & password) const;

View File

@ -0,0 +1,19 @@
#include <Access/IAccessEntity.h>
#include <Access/Quota.h>
#include <common/demangle.h>
namespace DB
{
String IAccessEntity::getTypeName(std::type_index type)
{
if (type == typeid(Quota))
return "Quota";
return demangle(type.name());
}
bool IAccessEntity::equal(const IAccessEntity & other) const
{
return (full_name == other.full_name) && (getType() == other.getType());
}
}

View File

@ -0,0 +1,49 @@
#pragma once
#include <Core/Types.h>
#include <Common/typeid_cast.h>
#include <memory>
#include <typeindex>
namespace DB
{
/// Access entity is a set of data which have a name and a type. Access entity control something related to the access control.
/// Entities can be stored to a file or another storage, see IAccessStorage.
struct IAccessEntity
{
IAccessEntity() = default;
IAccessEntity(const IAccessEntity &) = default;
virtual ~IAccessEntity() = default;
virtual std::shared_ptr<IAccessEntity> clone() const = 0;
std::type_index getType() const { return typeid(*this); }
static String getTypeName(std::type_index type);
const String getTypeName() const { return getTypeName(getType()); }
template <typename EntityType>
bool isTypeOf() const { return isTypeOf(typeid(EntityType)); }
bool isTypeOf(std::type_index type) const { return type == getType(); }
virtual void setName(const String & name_) { full_name = name_; }
virtual String getName() const { return full_name; }
String getFullName() const { return full_name; }
friend bool operator ==(const IAccessEntity & lhs, const IAccessEntity & rhs) { return lhs.equal(rhs); }
friend bool operator !=(const IAccessEntity & lhs, const IAccessEntity & rhs) { return !(lhs == rhs); }
protected:
String full_name;
virtual bool equal(const IAccessEntity & other) const;
/// Helper function to define clone() in the derived classes.
template <typename EntityType>
std::shared_ptr<IAccessEntity> cloneImpl() const
{
return std::make_shared<EntityType>(typeid_cast<const EntityType &>(*this));
}
};
using AccessEntityPtr = std::shared_ptr<const IAccessEntity>;
}

View File

@ -0,0 +1,450 @@
#include <Access/IAccessStorage.h>
#include <Common/Exception.h>
#include <Common/quoteString.h>
#include <IO/WriteHelpers.h>
#include <Poco/UUIDGenerator.h>
#include <Poco/Logger.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_CAST;
extern const int ACCESS_ENTITY_NOT_FOUND;
extern const int ACCESS_ENTITY_ALREADY_EXISTS;
extern const int ACCESS_ENTITY_FOUND_DUPLICATES;
extern const int ACCESS_ENTITY_STORAGE_READONLY;
}
std::vector<UUID> IAccessStorage::findAll(std::type_index type) const
{
return findAllImpl(type);
}
std::optional<UUID> IAccessStorage::find(std::type_index type, const String & name) const
{
return findImpl(type, name);
}
std::vector<UUID> IAccessStorage::find(std::type_index type, const Strings & names) const
{
std::vector<UUID> ids;
ids.reserve(names.size());
for (const String & name : names)
{
auto id = findImpl(type, name);
if (id)
ids.push_back(*id);
}
return ids;
}
UUID IAccessStorage::getID(std::type_index type, const String & name) const
{
auto id = findImpl(type, name);
if (id)
return *id;
throwNotFound(type, name);
}
std::vector<UUID> IAccessStorage::getIDs(std::type_index type, const Strings & names) const
{
std::vector<UUID> ids;
ids.reserve(names.size());
for (const String & name : names)
ids.push_back(getID(type, name));
return ids;
}
bool IAccessStorage::exists(const UUID & id) const
{
return existsImpl(id);
}
AccessEntityPtr IAccessStorage::tryReadBase(const UUID & id) const
{
try
{
return readImpl(id);
}
catch (Exception &)
{
return nullptr;
}
}
String IAccessStorage::readName(const UUID & id) const
{
return readNameImpl(id);
}
std::optional<String> IAccessStorage::tryReadName(const UUID & id) const
{
try
{
return readNameImpl(id);
}
catch (Exception &)
{
return {};
}
}
UUID IAccessStorage::insert(const AccessEntityPtr & entity)
{
return insertImpl(entity, false);
}
std::vector<UUID> IAccessStorage::insert(const std::vector<AccessEntityPtr> & multiple_entities)
{
std::vector<UUID> ids;
ids.reserve(multiple_entities.size());
String error_message;
for (const auto & entity : multiple_entities)
{
try
{
ids.push_back(insertImpl(entity, false));
}
catch (Exception & e)
{
if (e.code() != ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS)
throw;
error_message += (error_message.empty() ? "" : ". ") + e.message();
}
}
if (!error_message.empty())
throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
return ids;
}
std::optional<UUID> IAccessStorage::tryInsert(const AccessEntityPtr & entity)
{
try
{
return insertImpl(entity, false);
}
catch (Exception &)
{
return {};
}
}
std::vector<UUID> IAccessStorage::tryInsert(const std::vector<AccessEntityPtr> & multiple_entities)
{
std::vector<UUID> ids;
ids.reserve(multiple_entities.size());
for (const auto & entity : multiple_entities)
{
try
{
ids.push_back(insertImpl(entity, false));
}
catch (Exception &)
{
}
}
return ids;
}
UUID IAccessStorage::insertOrReplace(const AccessEntityPtr & entity)
{
return insertImpl(entity, true);
}
std::vector<UUID> IAccessStorage::insertOrReplace(const std::vector<AccessEntityPtr> & multiple_entities)
{
std::vector<UUID> ids;
ids.reserve(multiple_entities.size());
for (const auto & entity : multiple_entities)
ids.push_back(insertImpl(entity, true));
return ids;
}
void IAccessStorage::remove(const UUID & id)
{
removeImpl(id);
}
void IAccessStorage::remove(const std::vector<UUID> & ids)
{
String error_message;
for (const auto & id : ids)
{
try
{
removeImpl(id);
}
catch (Exception & e)
{
if (e.code() != ErrorCodes::ACCESS_ENTITY_NOT_FOUND)
throw;
error_message += (error_message.empty() ? "" : ". ") + e.message();
}
}
if (!error_message.empty())
throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
}
bool IAccessStorage::tryRemove(const UUID & id)
{
try
{
removeImpl(id);
return true;
}
catch (Exception &)
{
return false;
}
}
std::vector<UUID> IAccessStorage::tryRemove(const std::vector<UUID> & ids)
{
std::vector<UUID> removed;
removed.reserve(ids.size());
for (const auto & id : ids)
{
try
{
removeImpl(id);
removed.push_back(id);
}
catch (Exception &)
{
}
}
return removed;
}
void IAccessStorage::update(const UUID & id, const UpdateFunc & update_func)
{
updateImpl(id, update_func);
}
void IAccessStorage::update(const std::vector<UUID> & ids, const UpdateFunc & update_func)
{
String error_message;
for (const auto & id : ids)
{
try
{
updateImpl(id, update_func);
}
catch (Exception & e)
{
if (e.code() != ErrorCodes::ACCESS_ENTITY_NOT_FOUND)
throw;
error_message += (error_message.empty() ? "" : ". ") + e.message();
}
}
if (!error_message.empty())
throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
}
bool IAccessStorage::tryUpdate(const UUID & id, const UpdateFunc & update_func)
{
try
{
updateImpl(id, update_func);
return true;
}
catch (Exception &)
{
return false;
}
}
std::vector<UUID> IAccessStorage::tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func)
{
std::vector<UUID> updated;
updated.reserve(ids.size());
for (const auto & id : ids)
{
try
{
updateImpl(id, update_func);
updated.push_back(id);
}
catch (Exception &)
{
}
}
return updated;
}
IAccessStorage::SubscriptionPtr IAccessStorage::subscribeForChanges(std::type_index type, const OnChangedHandler & handler) const
{
return subscribeForChangesImpl(type, handler);
}
IAccessStorage::SubscriptionPtr IAccessStorage::subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const
{
return subscribeForChangesImpl(id, handler);
}
IAccessStorage::SubscriptionPtr IAccessStorage::subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const
{
if (ids.empty())
return nullptr;
if (ids.size() == 1)
return subscribeForChangesImpl(ids[0], handler);
std::vector<SubscriptionPtr> subscriptions;
subscriptions.reserve(ids.size());
for (const auto & id : ids)
{
auto subscription = subscribeForChangesImpl(id, handler);
if (subscription)
subscriptions.push_back(std::move(subscription));
}
class SubscriptionImpl : public Subscription
{
public:
SubscriptionImpl(std::vector<SubscriptionPtr> subscriptions_)
: subscriptions(std::move(subscriptions_)) {}
private:
std::vector<SubscriptionPtr> subscriptions;
};
return std::make_unique<SubscriptionImpl>(std::move(subscriptions));
}
bool IAccessStorage::hasSubscription(std::type_index type) const
{
return hasSubscriptionImpl(type);
}
bool IAccessStorage::hasSubscription(const UUID & id) const
{
return hasSubscriptionImpl(id);
}
void IAccessStorage::notify(const Notifications & notifications)
{
for (const auto & [fn, id, new_entity] : notifications)
fn(id, new_entity);
}
UUID IAccessStorage::generateRandomID()
{
static Poco::UUIDGenerator generator;
UUID id;
generator.createRandom().copyTo(reinterpret_cast<char *>(&id));
return id;
}
Poco::Logger * IAccessStorage::getLogger() const
{
Poco::Logger * ptr = log.load();
if (!ptr)
log.store(ptr = &Poco::Logger::get("Access(" + storage_name + ")"), std::memory_order_relaxed);
return ptr;
}
void IAccessStorage::throwNotFound(const UUID & id) const
{
throw Exception("ID {" + toString(id) + "} not found in " + getStorageName(), ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
}
void IAccessStorage::throwNotFound(std::type_index type, const String & name) const
{
throw Exception(
getTypeName(type) + " " + backQuote(name) + " not found in " + getStorageName(), ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
}
void IAccessStorage::throwBadCast(const UUID & id, std::type_index type, const String & name, std::type_index required_type) const
{
throw Exception(
"ID {" + toString(id) + "}: " + getTypeName(type) + backQuote(name) + " expected to be of type " + getTypeName(required_type),
ErrorCodes::BAD_CAST);
}
void IAccessStorage::throwIDCollisionCannotInsert(const UUID & id, std::type_index type, const String & name, std::type_index existing_type, const String & existing_name) const
{
throw Exception(
getTypeName(type) + " " + backQuote(name) + ": cannot insert because the ID {" + toString(id) + "} is already used by "
+ getTypeName(existing_type) + " " + backQuote(existing_name) + " in " + getStorageName(),
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
}
void IAccessStorage::throwNameCollisionCannotInsert(std::type_index type, const String & name) const
{
throw Exception(
getTypeName(type) + " " + backQuote(name) + ": cannot insert because " + getTypeName(type) + " " + backQuote(name)
+ " already exists in " + getStorageName(),
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
}
void IAccessStorage::throwNameCollisionCannotRename(std::type_index type, const String & old_name, const String & new_name) const
{
throw Exception(
getTypeName(type) + " " + backQuote(old_name) + ": cannot rename to " + backQuote(new_name) + " because " + getTypeName(type) + " "
+ backQuote(new_name) + " already exists in " + getStorageName(),
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
}
void IAccessStorage::throwReadonlyCannotInsert(std::type_index type, const String & name) const
{
throw Exception(
"Cannot insert " + getTypeName(type) + " " + backQuote(name) + " to " + getStorageName() + " because this storage is readonly",
ErrorCodes::ACCESS_ENTITY_STORAGE_READONLY);
}
void IAccessStorage::throwReadonlyCannotUpdate(std::type_index type, const String & name) const
{
throw Exception(
"Cannot update " + getTypeName(type) + " " + backQuote(name) + " in " + getStorageName() + " because this storage is readonly",
ErrorCodes::ACCESS_ENTITY_STORAGE_READONLY);
}
void IAccessStorage::throwReadonlyCannotRemove(std::type_index type, const String & name) const
{
throw Exception(
"Cannot remove " + getTypeName(type) + " " + backQuote(name) + " from " + getStorageName() + " because this storage is readonly",
ErrorCodes::ACCESS_ENTITY_STORAGE_READONLY);
}
}

View File

@ -0,0 +1,209 @@
#pragma once
#include <Access/IAccessEntity.h>
#include <Core/Types.h>
#include <Core/UUID.h>
#include <functional>
#include <optional>
#include <vector>
#include <atomic>
namespace Poco { class Logger; }
namespace DB
{
/// Contains entities, i.e. instances of classes derived from IAccessEntity.
/// The implementations of this class MUST be thread-safe.
class IAccessStorage
{
public:
IAccessStorage(const String & storage_name_) : storage_name(storage_name_) {}
virtual ~IAccessStorage() {}
/// Returns the name of this storage.
const String & getStorageName() const { return storage_name; }
/// Returns the identifiers of all the entities of a specified type contained in the storage.
std::vector<UUID> findAll(std::type_index type) const;
template <typename EntityType>
std::vector<UUID> findAll() const { return findAll(typeid(EntityType)); }
/// Searchs for an entity with specified type and name. Returns std::nullopt if not found.
std::optional<UUID> find(std::type_index type, const String & name) const;
template <typename EntityType>
std::optional<UUID> find(const String & name) const { return find(typeid(EntityType), name); }
std::vector<UUID> find(std::type_index type, const Strings & names) const;
template <typename EntityType>
std::vector<UUID> find(const Strings & names) const { return find(typeid(EntityType), names); }
/// Searchs for an entity with specified name and type. Throws an exception if not found.
UUID getID(std::type_index type, const String & name) const;
template <typename EntityType>
UUID getID(const String & name) const { return getID(typeid(EntityType), name); }
std::vector<UUID> getIDs(std::type_index type, const Strings & names) const;
template <typename EntityType>
std::vector<UUID> getIDs(const Strings & names) const { return getIDs(typeid(EntityType), names); }
/// Returns whether there is an entity with such identifier in the storage.
bool exists(const UUID & id) const;
/// Reads an entity. Throws an exception if not found.
template <typename EntityType = IAccessEntity>
std::shared_ptr<const EntityType> read(const UUID & id) const;
template <typename EntityType = IAccessEntity>
std::shared_ptr<const EntityType> read(const String & name) const;
/// Reads an entity. Returns nullptr if not found.
template <typename EntityType = IAccessEntity>
std::shared_ptr<const EntityType> tryRead(const UUID & id) const;
template <typename EntityType = IAccessEntity>
std::shared_ptr<const EntityType> tryRead(const String & name) const;
/// Reads only name of an entity.
String readName(const UUID & id) const;
std::optional<String> tryReadName(const UUID & id) const;
/// Inserts an entity to the storage. Returns ID of a new entry in the storage.
/// Throws an exception if the specified name already exists.
UUID insert(const AccessEntityPtr & entity);
std::vector<UUID> insert(const std::vector<AccessEntityPtr> & multiple_entities);
/// Inserts an entity to the storage. Returns ID of a new entry in the storage.
std::optional<UUID> tryInsert(const AccessEntityPtr & entity);
std::vector<UUID> tryInsert(const std::vector<AccessEntityPtr> & multiple_entities);
/// Inserts an entity to the storage. Return ID of a new entry in the storage.
/// Replaces an existing entry in the storage if the specified name already exists.
UUID insertOrReplace(const AccessEntityPtr & entity);
std::vector<UUID> insertOrReplace(const std::vector<AccessEntityPtr> & multiple_entities);
/// Removes an entity from the storage. Throws an exception if couldn't remove.
void remove(const UUID & id);
void remove(const std::vector<UUID> & ids);
/// Removes an entity from the storage. Returns false if couldn't remove.
bool tryRemove(const UUID & id);
/// Removes multiple entities from the storage. Returns the list of successfully dropped.
std::vector<UUID> tryRemove(const std::vector<UUID> & ids);
using UpdateFunc = std::function<AccessEntityPtr(const AccessEntityPtr &)>;
/// Updates an entity stored in the storage. Throws an exception if couldn't update.
void update(const UUID & id, const UpdateFunc & update_func);
void update(const std::vector<UUID> & ids, const UpdateFunc & update_func);
/// Updates an entity stored in the storage. Returns false if couldn't update.
bool tryUpdate(const UUID & id, const UpdateFunc & update_func);
/// Updates multiple entities in the storage. Returns the list of successfully updated.
std::vector<UUID> tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func);
class Subscription
{
public:
virtual ~Subscription() {}
};
using SubscriptionPtr = std::unique_ptr<Subscription>;
using OnChangedHandler = std::function<void(const UUID & /* id */, const AccessEntityPtr & /* new or changed entity, null if removed */)>;
/// Subscribes for all changes.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
SubscriptionPtr subscribeForChanges(std::type_index type, const OnChangedHandler & handler) const;
template <typename EntityType>
SubscriptionPtr subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(typeid(EntityType), handler); }
/// Subscribes for changes of a specific entry.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
SubscriptionPtr subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const;
SubscriptionPtr subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const;
bool hasSubscription(std::type_index type) const;
bool hasSubscription(const UUID & id) const;
protected:
virtual std::optional<UUID> findImpl(std::type_index type, const String & name) const = 0;
virtual std::vector<UUID> findAllImpl(std::type_index type) const = 0;
virtual bool existsImpl(const UUID & id) const = 0;
virtual AccessEntityPtr readImpl(const UUID & id) const = 0;
virtual String readNameImpl(const UUID & id) const = 0;
virtual UUID insertImpl(const AccessEntityPtr & entity, bool replace_if_exists) = 0;
virtual void removeImpl(const UUID & id) = 0;
virtual void updateImpl(const UUID & id, const UpdateFunc & update_func) = 0;
virtual SubscriptionPtr subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const = 0;
virtual SubscriptionPtr subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const = 0;
virtual bool hasSubscriptionImpl(const UUID & id) const = 0;
virtual bool hasSubscriptionImpl(std::type_index type) const = 0;
static UUID generateRandomID();
Poco::Logger * getLogger() const;
static String getTypeName(std::type_index type) { return IAccessEntity::getTypeName(type); }
[[noreturn]] void throwNotFound(const UUID & id) const;
[[noreturn]] void throwNotFound(std::type_index type, const String & name) const;
[[noreturn]] void throwBadCast(const UUID & id, std::type_index type, const String & name, std::type_index required_type) const;
[[noreturn]] void throwIDCollisionCannotInsert(const UUID & id, std::type_index type, const String & name, std::type_index existing_type, const String & existing_name) const;
[[noreturn]] void throwNameCollisionCannotInsert(std::type_index type, const String & name) const;
[[noreturn]] void throwNameCollisionCannotRename(std::type_index type, const String & old_name, const String & new_name) const;
[[noreturn]] void throwReadonlyCannotInsert(std::type_index type, const String & name) const;
[[noreturn]] void throwReadonlyCannotUpdate(std::type_index type, const String & name) const;
[[noreturn]] void throwReadonlyCannotRemove(std::type_index type, const String & name) const;
using Notification = std::tuple<OnChangedHandler, UUID, AccessEntityPtr>;
using Notifications = std::vector<Notification>;
static void notify(const Notifications & notifications);
private:
AccessEntityPtr tryReadBase(const UUID & id) const;
const String storage_name;
mutable std::atomic<Poco::Logger *> log = nullptr;
};
template <typename EntityType>
std::shared_ptr<const EntityType> IAccessStorage::read(const UUID & id) const
{
auto entity = readImpl(id);
auto ptr = typeid_cast<std::shared_ptr<const EntityType>>(entity);
if (ptr)
return ptr;
throwBadCast(id, entity->getType(), entity->getFullName(), typeid(EntityType));
}
template <typename EntityType>
std::shared_ptr<const EntityType> IAccessStorage::read(const String & name) const
{
return read<EntityType>(getID<EntityType>(name));
}
template <typename EntityType>
std::shared_ptr<const EntityType> IAccessStorage::tryRead(const UUID & id) const
{
auto entity = tryReadBase(id);
if (!entity)
return nullptr;
return typeid_cast<std::shared_ptr<const EntityType>>(entity);
}
template <typename EntityType>
std::shared_ptr<const EntityType> IAccessStorage::tryRead(const String & name) const
{
auto id = find<EntityType>(name);
return id ? tryRead<EntityType>(*id) : nullptr;
}
}

View File

@ -0,0 +1,358 @@
#include <Access/MemoryAccessStorage.h>
#include <ext/scope_guard.h>
#include <unordered_set>
namespace DB
{
MemoryAccessStorage::MemoryAccessStorage(const String & storage_name_)
: IAccessStorage(storage_name_), shared_ptr_to_this{std::make_shared<const MemoryAccessStorage *>(this)}
{
}
MemoryAccessStorage::~MemoryAccessStorage() {}
std::optional<UUID> MemoryAccessStorage::findImpl(std::type_index type, const String & name) const
{
std::lock_guard lock{mutex};
auto it = names.find({name, type});
if (it == names.end())
return {};
Entry & entry = *(it->second);
return entry.id;
}
std::vector<UUID> MemoryAccessStorage::findAllImpl(std::type_index type) const
{
std::lock_guard lock{mutex};
std::vector<UUID> result;
result.reserve(entries.size());
for (const auto & [id, entry] : entries)
if (entry.entity->isTypeOf(type))
result.emplace_back(id);
return result;
}
bool MemoryAccessStorage::existsImpl(const UUID & id) const
{
std::lock_guard lock{mutex};
return entries.count(id);
}
AccessEntityPtr MemoryAccessStorage::readImpl(const UUID & id) const
{
std::lock_guard lock{mutex};
auto it = entries.find(id);
if (it == entries.end())
throwNotFound(id);
const Entry & entry = it->second;
return entry.entity;
}
String MemoryAccessStorage::readNameImpl(const UUID & id) const
{
return readImpl(id)->getFullName();
}
UUID MemoryAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool replace_if_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
UUID id = generateRandomID();
std::lock_guard lock{mutex};
insertNoLock(generateRandomID(), new_entity, replace_if_exists, notifications);
return id;
}
void MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, Notifications & notifications)
{
const String & name = new_entity->getFullName();
std::type_index type = new_entity->getType();
/// Check that we can insert.
auto it = entries.find(id);
if (it != entries.end())
{
const auto & existing_entry = it->second;
throwIDCollisionCannotInsert(id, type, name, existing_entry.entity->getType(), existing_entry.entity->getFullName());
}
auto it2 = names.find({name, type});
if (it2 != names.end())
{
const auto & existing_entry = *(it2->second);
if (replace_if_exists)
removeNoLock(existing_entry.id, notifications);
else
throwNameCollisionCannotInsert(type, name);
}
/// Do insertion.
auto & entry = entries[id];
entry.id = id;
entry.entity = new_entity;
names[std::pair{name, type}] = &entry;
prepareNotifications(entry, false, notifications);
}
void MemoryAccessStorage::removeImpl(const UUID & id)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
removeNoLock(id, notifications);
}
void MemoryAccessStorage::removeNoLock(const UUID & id, Notifications & notifications)
{
auto it = entries.find(id);
if (it == entries.end())
throwNotFound(id);
Entry & entry = it->second;
const String & name = entry.entity->getFullName();
std::type_index type = entry.entity->getType();
prepareNotifications(entry, true, notifications);
/// Do removing.
names.erase({name, type});
entries.erase(it);
}
void MemoryAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
updateNoLock(id, update_func, notifications);
}
void MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, Notifications & notifications)
{
auto it = entries.find(id);
if (it == entries.end())
throwNotFound(id);
Entry & entry = it->second;
auto old_entity = entry.entity;
auto new_entity = update_func(old_entity);
if (*new_entity == *old_entity)
return;
entry.entity = new_entity;
if (new_entity->getFullName() != old_entity->getFullName())
{
auto it2 = names.find({new_entity->getFullName(), new_entity->getType()});
if (it2 != names.end())
throwNameCollisionCannotRename(old_entity->getType(), old_entity->getFullName(), new_entity->getFullName());
names.erase({old_entity->getFullName(), old_entity->getType()});
names[std::pair{new_entity->getFullName(), new_entity->getType()}] = &entry;
}
prepareNotifications(entry, false, notifications);
}
void MemoryAccessStorage::setAll(const std::vector<AccessEntityPtr> & all_entities)
{
std::vector<std::pair<UUID, AccessEntityPtr>> entities_with_ids;
entities_with_ids.reserve(all_entities.size());
for (const auto & entity : all_entities)
entities_with_ids.emplace_back(generateRandomID(), entity);
setAll(entities_with_ids);
}
void MemoryAccessStorage::setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
setAllNoLock(all_entities, notifications);
}
void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications)
{
/// Get list of the currently used IDs. Later we will remove those of them which are not used anymore.
std::unordered_set<UUID> not_used_ids;
for (const auto & id_and_entry : entries)
not_used_ids.emplace(id_and_entry.first);
/// Remove conflicting entities.
for (const auto & [id, entity] : all_entities)
{
auto it = entries.find(id);
if (it != entries.end())
{
not_used_ids.erase(id); /// ID is used.
Entry & entry = it->second;
if (entry.entity->getType() != entity->getType())
{
removeNoLock(id, notifications);
continue;
}
}
auto it2 = names.find({entity->getFullName(), entity->getType()});
if (it2 != names.end())
{
Entry & entry = *(it2->second);
if (entry.id != id)
removeNoLock(id, notifications);
}
}
/// Remove entities which are not used anymore.
for (const auto & id : not_used_ids)
removeNoLock(id, notifications);
/// Insert or update entities.
for (const auto & [id, entity] : all_entities)
{
auto it = entries.find(id);
if (it != entries.end())
{
if (*(it->second.entity) != *entity)
{
const AccessEntityPtr & changed_entity = entity;
updateNoLock(id, [&changed_entity](const AccessEntityPtr &) { return changed_entity; }, notifications);
}
}
else
insertNoLock(id, entity, false, notifications);
}
}
void MemoryAccessStorage::prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const
{
for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, entry.id, remove ? nullptr : entry.entity});
auto range = handlers_by_type.equal_range(entry.entity->getType());
for (auto it = range.first; it != range.second; ++it)
notifications.push_back({it->second, entry.id, remove ? nullptr : entry.entity});
}
IAccessStorage::SubscriptionPtr MemoryAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const
{
class SubscriptionImpl : public Subscription
{
public:
SubscriptionImpl(
const MemoryAccessStorage & storage_,
std::type_index type_,
const OnChangedHandler & handler_)
: storage_weak(storage_.shared_ptr_to_this)
{
std::lock_guard lock{storage_.mutex};
handler_it = storage_.handlers_by_type.emplace(type_, handler_);
}
~SubscriptionImpl() override
{
auto storage = storage_weak.lock();
if (storage)
{
std::lock_guard lock{(*storage)->mutex};
(*storage)->handlers_by_type.erase(handler_it);
}
}
private:
std::weak_ptr<const MemoryAccessStorage *> storage_weak;
std::unordered_multimap<std::type_index, OnChangedHandler>::iterator handler_it;
};
return std::make_unique<SubscriptionImpl>(*this, type, handler);
}
IAccessStorage::SubscriptionPtr MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
class SubscriptionImpl : public Subscription
{
public:
SubscriptionImpl(
const MemoryAccessStorage & storage_,
const UUID & id_,
const OnChangedHandler & handler_)
: storage_weak(storage_.shared_ptr_to_this),
id(id_)
{
std::lock_guard lock{storage_.mutex};
auto it = storage_.entries.find(id);
if (it == storage_.entries.end())
{
storage_weak.reset();
return;
}
const Entry & entry = it->second;
handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler_);
}
~SubscriptionImpl() override
{
auto storage = storage_weak.lock();
if (storage)
{
std::lock_guard lock{(*storage)->mutex};
auto it = (*storage)->entries.find(id);
if (it != (*storage)->entries.end())
{
const Entry & entry = it->second;
entry.handlers_by_id.erase(handler_it);
}
}
}
private:
std::weak_ptr<const MemoryAccessStorage *> storage_weak;
UUID id;
std::list<OnChangedHandler>::iterator handler_it;
};
return std::make_unique<SubscriptionImpl>(*this, id, handler);
}
bool MemoryAccessStorage::hasSubscriptionImpl(const UUID & id) const
{
auto it = entries.find(id);
if (it != entries.end())
{
const Entry & entry = it->second;
return !entry.handlers_by_id.empty();
}
return false;
}
bool MemoryAccessStorage::hasSubscriptionImpl(std::type_index type) const
{
auto range = handlers_by_type.equal_range(type);
return range.first != range.second;
}
}

View File

@ -0,0 +1,65 @@
#pragma once
#include <Access/IAccessStorage.h>
#include <list>
#include <memory>
#include <mutex>
#include <unordered_map>
namespace DB
{
/// Implementation of IAccessStorage which keeps all data in memory.
class MemoryAccessStorage : public IAccessStorage
{
public:
MemoryAccessStorage(const String & storage_name_ = "memory");
~MemoryAccessStorage() override;
/// Sets all entities at once.
void setAll(const std::vector<AccessEntityPtr> & all_entities);
void setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities);
private:
std::optional<UUID> findImpl(std::type_index type, const String & name) const override;
std::vector<UUID> findAllImpl(std::type_index type) const override;
bool existsImpl(const UUID & id) const override;
AccessEntityPtr readImpl(const UUID & id) const override;
String readNameImpl(const UUID & id) const override;
UUID insertImpl(const AccessEntityPtr & entity, bool replace_if_exists) override;
void removeImpl(const UUID & id) override;
void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
SubscriptionPtr subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
SubscriptionPtr subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override;
bool hasSubscriptionImpl(const UUID & id) const override;
bool hasSubscriptionImpl(std::type_index type) const override;
struct Entry
{
UUID id;
AccessEntityPtr entity;
mutable std::list<OnChangedHandler> handlers_by_id;
};
void insertNoLock(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, Notifications & notifications);
void removeNoLock(const UUID & id, Notifications & notifications);
void updateNoLock(const UUID & id, const UpdateFunc & update_func, Notifications & notifications);
void setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications);
void prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const;
using NameTypePair = std::pair<String, std::type_index>;
struct Hash
{
size_t operator()(const NameTypePair & key) const
{
return std::hash<String>{}(key.first) - std::hash<std::type_index>{}(key.second);
}
};
mutable std::mutex mutex;
std::unordered_map<UUID, Entry> entries; /// We want to search entries both by ID and by the pair of name and type.
std::unordered_map<NameTypePair, Entry *, Hash> names; /// and by the pair of name and type.
mutable std::unordered_multimap<std::type_index, OnChangedHandler> handlers_by_type;
std::shared_ptr<const MemoryAccessStorage *> shared_ptr_to_this; /// We need weak pointers to `this` to implement subscriptions.
};
}

View File

@ -0,0 +1,246 @@
#include <Access/MultipleAccessStorage.h>
#include <Common/Exception.h>
#include <Common/quoteString.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ACCESS_ENTITY_NOT_FOUND;
extern const int ACCESS_ENTITY_FOUND_DUPLICATES;
}
namespace
{
template <typename StoragePtrT>
String joinStorageNames(const std::vector<StoragePtrT> & storages)
{
String result;
for (const auto & storage : storages)
{
if (!result.empty())
result += ", ";
result += storage->getStorageName();
}
return result;
}
}
MultipleAccessStorage::MultipleAccessStorage(
std::vector<std::unique_ptr<Storage>> nested_storages_, size_t index_of_nested_storage_for_insertion_)
: IAccessStorage(joinStorageNames(nested_storages_))
, nested_storages(std::move(nested_storages_))
, nested_storage_for_insertion(nested_storages[index_of_nested_storage_for_insertion_].get())
, ids_cache(512 /* cache size */)
{
}
MultipleAccessStorage::~MultipleAccessStorage()
{
}
std::vector<UUID> MultipleAccessStorage::findMultiple(std::type_index type, const String & name) const
{
std::vector<UUID> ids;
for (const auto & nested_storage : nested_storages)
{
auto id = nested_storage->find(type, name);
if (id)
{
std::lock_guard lock{ids_cache_mutex};
ids_cache.set(*id, std::make_shared<Storage *>(nested_storage.get()));
ids.push_back(*id);
}
}
return ids;
}
std::optional<UUID> MultipleAccessStorage::findImpl(std::type_index type, const String & name) const
{
auto ids = findMultiple(type, name);
if (ids.empty())
return {};
if (ids.size() == 1)
return ids[0];
std::vector<const Storage *> storages_with_duplicates;
for (const auto & id : ids)
{
auto * storage = findStorage(id);
if (storage)
storages_with_duplicates.push_back(storage);
}
throw Exception(
"Found " + getTypeName(type) + " " + backQuote(name) + " in " + std::to_string(ids.size())
+ " storages: " + joinStorageNames(storages_with_duplicates),
ErrorCodes::ACCESS_ENTITY_FOUND_DUPLICATES);
}
std::vector<UUID> MultipleAccessStorage::findAllImpl(std::type_index type) const
{
std::vector<UUID> all_ids;
for (const auto & nested_storage : nested_storages)
{
auto ids = nested_storage->findAll(type);
all_ids.insert(all_ids.end(), std::make_move_iterator(ids.begin()), std::make_move_iterator(ids.end()));
}
return all_ids;
}
bool MultipleAccessStorage::existsImpl(const UUID & id) const
{
return findStorage(id) != nullptr;
}
IAccessStorage * MultipleAccessStorage::findStorage(const UUID & id)
{
{
std::lock_guard lock{ids_cache_mutex};
auto from_cache = ids_cache.get(id);
if (from_cache)
{
auto * storage = *from_cache;
if (storage->exists(id))
return storage;
}
}
for (const auto & nested_storage : nested_storages)
{
if (nested_storage->exists(id))
{
std::lock_guard lock{ids_cache_mutex};
ids_cache.set(id, std::make_shared<Storage *>(nested_storage.get()));
return nested_storage.get();
}
}
return nullptr;
}
const IAccessStorage * MultipleAccessStorage::findStorage(const UUID & id) const
{
return const_cast<MultipleAccessStorage *>(this)->findStorage(id);
}
IAccessStorage & MultipleAccessStorage::getStorage(const UUID & id)
{
auto * storage = findStorage(id);
if (storage)
return *storage;
throwNotFound(id);
}
const IAccessStorage & MultipleAccessStorage::getStorage(const UUID & id) const
{
return const_cast<MultipleAccessStorage *>(this)->getStorage(id);
}
AccessEntityPtr MultipleAccessStorage::readImpl(const UUID & id) const
{
return getStorage(id).read(id);
}
String MultipleAccessStorage::readNameImpl(const UUID & id) const
{
return getStorage(id).readName(id);
}
UUID MultipleAccessStorage::insertImpl(const AccessEntityPtr & entity, bool replace_if_exists)
{
auto id = replace_if_exists ? nested_storage_for_insertion->insertOrReplace(entity) : nested_storage_for_insertion->insert(entity);
std::lock_guard lock{ids_cache_mutex};
ids_cache.set(id, std::make_shared<Storage *>(nested_storage_for_insertion));
return id;
}
void MultipleAccessStorage::removeImpl(const UUID & id)
{
getStorage(id).remove(id);
}
void MultipleAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func)
{
getStorage(id).update(id, update_func);
}
IAccessStorage::SubscriptionPtr MultipleAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
auto storage = findStorage(id);
if (!storage)
return nullptr;
return storage->subscribeForChanges(id, handler);
}
IAccessStorage::SubscriptionPtr MultipleAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const
{
std::vector<SubscriptionPtr> subscriptions;
for (const auto & nested_storage : nested_storages)
{
auto subscription = nested_storage->subscribeForChanges(type, handler);
if (subscription)
subscriptions.emplace_back(std::move(subscription));
}
if (subscriptions.empty())
return nullptr;
if (subscriptions.size() == 1)
return std::move(subscriptions[0]);
class SubscriptionImpl : public Subscription
{
public:
SubscriptionImpl(std::vector<SubscriptionPtr> subscriptions_)
: subscriptions(std::move(subscriptions_)) {}
private:
std::vector<SubscriptionPtr> subscriptions;
};
return std::make_unique<SubscriptionImpl>(std::move(subscriptions));
}
bool MultipleAccessStorage::hasSubscriptionImpl(const UUID & id) const
{
for (const auto & nested_storage : nested_storages)
{
if (nested_storage->hasSubscription(id))
return true;
}
return false;
}
bool MultipleAccessStorage::hasSubscriptionImpl(std::type_index type) const
{
for (const auto & nested_storage : nested_storages)
{
if (nested_storage->hasSubscription(type))
return true;
}
return false;
}
}

View File

@ -0,0 +1,53 @@
#pragma once
#include <Access/IAccessStorage.h>
#include <Common/LRUCache.h>
#include <mutex>
namespace DB
{
/// Implementation of IAccessStorage which contains multiple nested storages.
class MultipleAccessStorage : public IAccessStorage
{
public:
using Storage = IAccessStorage;
MultipleAccessStorage(std::vector<std::unique_ptr<Storage>> nested_storages_, size_t index_of_nested_storage_for_insertion_ = 0);
~MultipleAccessStorage() override;
std::vector<UUID> findMultiple(std::type_index type, const String & name) const;
template <typename EntityType>
std::vector<UUID> findMultiple(const String & name) const { return findMultiple(EntityType::TYPE, name); }
const Storage * findStorage(const UUID & id) const;
Storage * findStorage(const UUID & id);
const Storage & getStorage(const UUID & id) const;
Storage & getStorage(const UUID & id);
Storage & getStorageByIndex(size_t i) { return *(nested_storages[i]); }
const Storage & getStorageByIndex(size_t i) const { return *(nested_storages[i]); }
protected:
std::optional<UUID> findImpl(std::type_index type, const String & name) const override;
std::vector<UUID> findAllImpl(std::type_index type) const override;
bool existsImpl(const UUID & id) const override;
AccessEntityPtr readImpl(const UUID & id) const override;
String readNameImpl(const UUID &id) const override;
UUID insertImpl(const AccessEntityPtr & entity, bool replace_if_exists) override;
void removeImpl(const UUID & id) override;
void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
SubscriptionPtr subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
SubscriptionPtr subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override;
bool hasSubscriptionImpl(const UUID & id) const override;
bool hasSubscriptionImpl(std::type_index type) const override;
private:
std::vector<std::unique_ptr<Storage>> nested_storages;
IAccessStorage * nested_storage_for_insertion;
mutable LRUCache<UUID, Storage *> ids_cache;
mutable std::mutex ids_cache_mutex;
};
}

46
dbms/src/Access/Quota.cpp Normal file
View File

@ -0,0 +1,46 @@
#include <Access/Quota.h>
#include <boost/range/algorithm/equal.hpp>
#include <boost/range/algorithm/fill.hpp>
namespace DB
{
Quota::Limits::Limits()
{
boost::range::fill(max, 0);
}
bool operator ==(const Quota::Limits & lhs, const Quota::Limits & rhs)
{
return boost::range::equal(lhs.max, rhs.max) && (lhs.duration == rhs.duration)
&& (lhs.randomize_interval == rhs.randomize_interval);
}
bool Quota::equal(const IAccessEntity & other) const
{
if (!IAccessEntity::equal(other))
return false;
const auto & other_quota = typeid_cast<const Quota &>(other);
return (all_limits == other_quota.all_limits) && (key_type == other_quota.key_type) && (roles == other_quota.roles)
&& (all_roles == other_quota.all_roles) && (except_roles == other_quota.except_roles);
}
const char * Quota::resourceTypeToColumnName(ResourceType resource_type)
{
switch (resource_type)
{
case Quota::QUERIES: return "queries";
case Quota::ERRORS: return "errors";
case Quota::RESULT_ROWS: return "result_rows";
case Quota::RESULT_BYTES: return "result_bytes";
case Quota::READ_ROWS: return "read_rows";
case Quota::READ_BYTES: return "read_bytes";
case Quota::EXECUTION_TIME: return "execution_time";
}
__builtin_unreachable();
}
}

141
dbms/src/Access/Quota.h Normal file
View File

@ -0,0 +1,141 @@
#pragma once
#include <Access/IAccessEntity.h>
#include <chrono>
namespace DB
{
/** Quota for resources consumption for specific interval.
* Used to limit resource usage by user.
* Quota is applied "softly" - could be slightly exceed, because it is checked usually only on each block of processed data.
* Accumulated values are not persisted and are lost on server restart.
* Quota is local to server,
* but for distributed queries, accumulated values for read rows and bytes
* are collected from all participating servers and accumulated locally.
*/
struct Quota : public IAccessEntity
{
enum ResourceType
{
QUERIES, /// Number of queries.
ERRORS, /// Number of queries with exceptions.
RESULT_ROWS, /// Number of rows returned as result.
RESULT_BYTES, /// Number of bytes returned as result.
READ_ROWS, /// Number of rows read from tables.
READ_BYTES, /// Number of bytes read from tables.
EXECUTION_TIME, /// Total amount of query execution time in nanoseconds.
};
static constexpr size_t MAX_RESOURCE_TYPE = 7;
using ResourceAmount = UInt64;
static constexpr ResourceAmount UNLIMITED = 0; /// 0 means unlimited.
/// Amount of resources available to consume for each duration.
struct Limits
{
ResourceAmount max[MAX_RESOURCE_TYPE];
std::chrono::seconds duration = std::chrono::seconds::zero();
/// Intervals can be randomized (to avoid DoS if intervals for many users end at one time).
bool randomize_interval = false;
Limits();
friend bool operator ==(const Limits & lhs, const Limits & rhs);
friend bool operator !=(const Limits & lhs, const Limits & rhs) { return !(lhs == rhs); }
};
std::vector<Limits> all_limits;
/// Key to share quota consumption.
/// Users with the same key share the same amount of resource.
enum class KeyType
{
NONE, /// All users share the same quota.
USER_NAME, /// Connections with the same user name share the same quota.
IP_ADDRESS, /// Connections from the same IP share the same quota.
CLIENT_KEY, /// Client should explicitly supply a key to use.
CLIENT_KEY_OR_USER_NAME, /// Same as CLIENT_KEY, but use USER_NAME if the client doesn't supply a key.
CLIENT_KEY_OR_IP_ADDRESS, /// Same as CLIENT_KEY, but use IP_ADDRESS if the client doesn't supply a key.
};
static constexpr size_t MAX_KEY_TYPE = 6;
KeyType key_type = KeyType::NONE;
/// Which roles or users should use this quota.
Strings roles;
bool all_roles = false;
Strings except_roles;
bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<Quota>(); }
static const char * getNameOfResourceType(ResourceType resource_type);
static const char * resourceTypeToKeyword(ResourceType resource_type);
static const char * resourceTypeToColumnName(ResourceType resource_type);
static const char * getNameOfKeyType(KeyType key_type);
static double executionTimeToSeconds(ResourceAmount ns);
static ResourceAmount secondsToExecutionTime(double s);
};
inline const char * Quota::getNameOfResourceType(ResourceType resource_type)
{
switch (resource_type)
{
case Quota::QUERIES: return "queries";
case Quota::ERRORS: return "errors";
case Quota::RESULT_ROWS: return "result rows";
case Quota::RESULT_BYTES: return "result bytes";
case Quota::READ_ROWS: return "read rows";
case Quota::READ_BYTES: return "read bytes";
case Quota::EXECUTION_TIME: return "execution time";
}
__builtin_unreachable();
}
inline const char * Quota::resourceTypeToKeyword(ResourceType resource_type)
{
switch (resource_type)
{
case Quota::QUERIES: return "QUERIES";
case Quota::ERRORS: return "ERRORS";
case Quota::RESULT_ROWS: return "RESULT ROWS";
case Quota::RESULT_BYTES: return "RESULT BYTES";
case Quota::READ_ROWS: return "READ ROWS";
case Quota::READ_BYTES: return "READ BYTES";
case Quota::EXECUTION_TIME: return "EXECUTION TIME";
}
__builtin_unreachable();
}
inline const char * Quota::getNameOfKeyType(KeyType key_type)
{
switch (key_type)
{
case KeyType::NONE: return "none";
case KeyType::USER_NAME: return "user name";
case KeyType::IP_ADDRESS: return "ip address";
case KeyType::CLIENT_KEY: return "client key";
case KeyType::CLIENT_KEY_OR_USER_NAME: return "client key or user name";
case KeyType::CLIENT_KEY_OR_IP_ADDRESS: return "client key or ip address";
}
__builtin_unreachable();
}
inline double Quota::executionTimeToSeconds(ResourceAmount ns)
{
return std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::nanoseconds{ns}).count();
}
inline Quota::ResourceAmount Quota::secondsToExecutionTime(double s)
{
return std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::duration<double>(s)).count();
}
using QuotaPtr = std::shared_ptr<const Quota>;
}

View File

@ -0,0 +1,264 @@
#include <Access/QuotaContext.h>
#include <Common/Exception.h>
#include <Common/quoteString.h>
#include <ext/chrono_io.h>
#include <ext/range.h>
#include <boost/range/algorithm/fill.hpp>
namespace DB
{
namespace ErrorCodes
{
extern const int QUOTA_EXPIRED;
}
struct QuotaContext::Impl
{
[[noreturn]] static void throwQuotaExceed(
const String & user_name,
const String & quota_name,
ResourceType resource_type,
ResourceAmount used,
ResourceAmount max,
std::chrono::seconds duration,
std::chrono::system_clock::time_point end_of_interval)
{
std::function<String(UInt64)> amount_to_string = [](UInt64 amount) { return std::to_string(amount); };
if (resource_type == Quota::EXECUTION_TIME)
amount_to_string = [&](UInt64 amount) { return ext::to_string(std::chrono::nanoseconds(amount)); };
throw Exception(
"Quota for user " + backQuote(user_name) + " for " + ext::to_string(duration) + " has been exceeded: "
+ Quota::getNameOfResourceType(resource_type) + " = " + amount_to_string(used) + "/" + amount_to_string(max) + ". "
+ "Interval will end at " + ext::to_string(end_of_interval) + ". " + "Name of quota template: " + backQuote(quota_name),
ErrorCodes::QUOTA_EXPIRED);
}
static std::chrono::system_clock::time_point getEndOfInterval(
const Interval & interval, std::chrono::system_clock::time_point current_time, bool * counters_were_reset = nullptr)
{
auto & end_of_interval = interval.end_of_interval;
auto end_loaded = end_of_interval.load();
auto end = std::chrono::system_clock::time_point{end_loaded};
if (current_time < end)
{
if (counters_were_reset)
*counters_were_reset = false;
return end;
}
const auto duration = interval.duration;
do
{
end = end + (current_time - end + duration) / duration * duration;
if (end_of_interval.compare_exchange_strong(end_loaded, end.time_since_epoch()))
{
boost::range::fill(interval.used, 0);
break;
}
end = std::chrono::system_clock::time_point{end_loaded};
}
while (current_time >= end);
if (counters_were_reset)
*counters_were_reset = true;
return end;
}
static void used(
const String & user_name,
const Intervals & intervals,
ResourceType resource_type,
ResourceAmount amount,
std::chrono::system_clock::time_point current_time,
bool check_exceeded)
{
for (const auto & interval : intervals.intervals)
{
ResourceAmount used = (interval.used[resource_type] += amount);
ResourceAmount max = interval.max[resource_type];
if (max == Quota::UNLIMITED)
continue;
if (used > max)
{
bool counters_were_reset = false;
auto end_of_interval = getEndOfInterval(interval, current_time, &counters_were_reset);
if (counters_were_reset)
{
used = (interval.used[resource_type] += amount);
if ((used > max) && check_exceeded)
throwQuotaExceed(user_name, intervals.quota_name, resource_type, used, max, interval.duration, end_of_interval);
}
else if (check_exceeded)
throwQuotaExceed(user_name, intervals.quota_name, resource_type, used, max, interval.duration, end_of_interval);
}
}
}
static void checkExceeded(
const String & user_name,
const Intervals & intervals,
ResourceType resource_type,
std::chrono::system_clock::time_point current_time)
{
for (const auto & interval : intervals.intervals)
{
ResourceAmount used = interval.used[resource_type];
ResourceAmount max = interval.max[resource_type];
if (max == Quota::UNLIMITED)
continue;
if (used > max)
{
bool used_counters_reset = false;
std::chrono::system_clock::time_point end_of_interval = getEndOfInterval(interval, current_time, &used_counters_reset);
if (!used_counters_reset)
throwQuotaExceed(user_name, intervals.quota_name, resource_type, used, max, interval.duration, end_of_interval);
}
}
}
static void checkExceeded(
const String & user_name,
const Intervals & intervals,
std::chrono::system_clock::time_point current_time)
{
for (auto resource_type : ext::range_with_static_cast<Quota::ResourceType>(Quota::MAX_RESOURCE_TYPE))
checkExceeded(user_name, intervals, resource_type, current_time);
}
};
QuotaContext::Interval & QuotaContext::Interval::operator =(const Interval & src)
{
randomize_interval = src.randomize_interval;
duration = src.duration;
end_of_interval.store(src.end_of_interval.load());
for (auto resource_type : ext::range(MAX_RESOURCE_TYPE))
{
max[resource_type] = src.max[resource_type];
used[resource_type].store(src.used[resource_type].load());
}
return *this;
}
QuotaUsageInfo QuotaContext::Intervals::getUsageInfo(std::chrono::system_clock::time_point current_time) const
{
QuotaUsageInfo info;
info.quota_id = quota_id;
info.quota_name = quota_name;
info.quota_key = quota_key;
info.intervals.reserve(intervals.size());
for (const auto & in : intervals)
{
info.intervals.push_back({});
auto & out = info.intervals.back();
out.duration = in.duration;
out.randomize_interval = in.randomize_interval;
out.end_of_interval = Impl::getEndOfInterval(in, current_time);
for (auto resource_type : ext::range(MAX_RESOURCE_TYPE))
{
out.max[resource_type] = in.max[resource_type];
out.used[resource_type] = in.used[resource_type];
}
}
return info;
}
QuotaContext::QuotaContext()
: atomic_intervals(std::make_shared<Intervals>()) /// Unlimited quota.
{
}
QuotaContext::QuotaContext(
const String & user_name_,
const Poco::Net::IPAddress & address_,
const String & client_key_)
: user_name(user_name_), address(address_), client_key(client_key_)
{
}
QuotaContext::~QuotaContext() = default;
void QuotaContext::used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded)
{
used({resource_type, amount}, check_exceeded);
}
void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource, bool check_exceeded)
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
auto current_time = std::chrono::system_clock::now();
Impl::used(user_name, *intervals_ptr, resource.first, resource.second, current_time, check_exceeded);
}
void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, bool check_exceeded)
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
auto current_time = std::chrono::system_clock::now();
Impl::used(user_name, *intervals_ptr, resource1.first, resource1.second, current_time, check_exceeded);
Impl::used(user_name, *intervals_ptr, resource2.first, resource2.second, current_time, check_exceeded);
}
void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, const std::pair<ResourceType, ResourceAmount> & resource3, bool check_exceeded)
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
auto current_time = std::chrono::system_clock::now();
Impl::used(user_name, *intervals_ptr, resource1.first, resource1.second, current_time, check_exceeded);
Impl::used(user_name, *intervals_ptr, resource2.first, resource2.second, current_time, check_exceeded);
Impl::used(user_name, *intervals_ptr, resource3.first, resource3.second, current_time, check_exceeded);
}
void QuotaContext::used(const std::vector<std::pair<ResourceType, ResourceAmount>> & resources, bool check_exceeded)
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
auto current_time = std::chrono::system_clock::now();
for (const auto & resource : resources)
Impl::used(user_name, *intervals_ptr, resource.first, resource.second, current_time, check_exceeded);
}
void QuotaContext::checkExceeded()
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
Impl::checkExceeded(user_name, *intervals_ptr, std::chrono::system_clock::now());
}
void QuotaContext::checkExceeded(ResourceType resource_type)
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
Impl::checkExceeded(user_name, *intervals_ptr, resource_type, std::chrono::system_clock::now());
}
QuotaUsageInfo QuotaContext::getUsageInfo() const
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
return intervals_ptr->getUsageInfo(std::chrono::system_clock::now());
}
QuotaUsageInfo::QuotaUsageInfo() : quota_id(UUID(UInt128(0)))
{
}
QuotaUsageInfo::Interval::Interval()
{
boost::range::fill(used, 0);
boost::range::fill(max, 0);
}
}

View File

@ -0,0 +1,110 @@
#pragma once
#include <Access/Quota.h>
#include <Core/UUID.h>
#include <Poco/Net/IPAddress.h>
#include <ext/shared_ptr_helper.h>
#include <boost/noncopyable.hpp>
#include <atomic>
#include <chrono>
#include <memory>
namespace DB
{
struct QuotaUsageInfo;
/// Instances of `QuotaContext` are used to track resource consumption.
class QuotaContext : public boost::noncopyable
{
public:
using ResourceType = Quota::ResourceType;
using ResourceAmount = Quota::ResourceAmount;
/// Default constructors makes an unlimited quota.
QuotaContext();
~QuotaContext();
/// Tracks resource consumption. If the quota exceeded and `check_exceeded == true`, throws an exception.
void used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded = true);
void used(const std::pair<ResourceType, ResourceAmount> & resource, bool check_exceeded = true);
void used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, bool check_exceeded = true);
void used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, const std::pair<ResourceType, ResourceAmount> & resource3, bool check_exceeded = true);
void used(const std::vector<std::pair<ResourceType, ResourceAmount>> & resources, bool check_exceeded = true);
/// Checks if the quota exceeded. If so, throws an exception.
void checkExceeded();
void checkExceeded(ResourceType resource_type);
/// Returns the information about this quota context.
QuotaUsageInfo getUsageInfo() const;
private:
friend class QuotaContextFactory;
friend struct ext::shared_ptr_helper<QuotaContext>;
/// Instances of this class are created by QuotaContextFactory.
QuotaContext(const String & user_name_, const Poco::Net::IPAddress & address_, const String & client_key_);
static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE;
struct Interval
{
mutable std::atomic<ResourceAmount> used[MAX_RESOURCE_TYPE];
ResourceAmount max[MAX_RESOURCE_TYPE];
std::chrono::seconds duration;
bool randomize_interval;
mutable std::atomic<std::chrono::system_clock::duration> end_of_interval;
Interval() {}
Interval(const Interval & src) { *this = src; }
Interval & operator =(const Interval & src);
};
struct Intervals
{
std::vector<Interval> intervals;
UUID quota_id;
String quota_name;
String quota_key;
QuotaUsageInfo getUsageInfo(std::chrono::system_clock::time_point current_time) const;
};
struct Impl;
const String user_name;
const Poco::Net::IPAddress address;
const String client_key;
std::shared_ptr<const Intervals> atomic_intervals; /// atomically changed by QuotaUsageManager
};
using QuotaContextPtr = std::shared_ptr<QuotaContext>;
/// The information about a quota context.
struct QuotaUsageInfo
{
using ResourceType = Quota::ResourceType;
using ResourceAmount = Quota::ResourceAmount;
static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE;
struct Interval
{
ResourceAmount used[MAX_RESOURCE_TYPE];
ResourceAmount max[MAX_RESOURCE_TYPE];
std::chrono::seconds duration = std::chrono::seconds::zero();
bool randomize_interval = false;
std::chrono::system_clock::time_point end_of_interval;
Interval();
};
std::vector<Interval> intervals;
UUID quota_id;
String quota_name;
String quota_key;
QuotaUsageInfo();
};
}

View File

@ -0,0 +1,299 @@
#include <Access/QuotaContext.h>
#include <Access/QuotaContextFactory.h>
#include <Access/AccessControlManager.h>
#include <Common/Exception.h>
#include <Common/thread_local_rng.h>
#include <ext/range.h>
#include <boost/range/adaptor/map.hpp>
#include <boost/range/algorithm/copy.hpp>
#include <boost/range/algorithm/lower_bound.hpp>
#include <boost/range/algorithm/stable_sort.hpp>
#include <boost/range/algorithm_ext/erase.hpp>
namespace DB
{
namespace ErrorCodes
{
extern const int QUOTA_REQUIRES_CLIENT_KEY;
}
namespace
{
std::chrono::system_clock::duration randomDuration(std::chrono::seconds max)
{
auto count = std::chrono::duration_cast<std::chrono::system_clock::duration>(max).count();
std::uniform_int_distribution<Int64> distribution{0, count - 1};
return std::chrono::system_clock::duration(distribution(thread_local_rng));
}
}
void QuotaContextFactory::QuotaInfo::setQuota(const QuotaPtr & quota_, const UUID & quota_id_)
{
quota = quota_;
quota_id = quota_id_;
boost::range::copy(quota->roles, std::inserter(roles, roles.end()));
all_roles = quota->all_roles;
boost::range::copy(quota->except_roles, std::inserter(except_roles, except_roles.end()));
rebuildAllIntervals();
}
bool QuotaContextFactory::QuotaInfo::canUseWithContext(const QuotaContext & context) const
{
if (roles.count(context.user_name))
return true;
if (all_roles && !except_roles.count(context.user_name))
return true;
return false;
}
String QuotaContextFactory::QuotaInfo::calculateKey(const QuotaContext & context) const
{
using KeyType = Quota::KeyType;
switch (quota->key_type)
{
case KeyType::NONE:
return "";
case KeyType::USER_NAME:
return context.user_name;
case KeyType::IP_ADDRESS:
return context.address.toString();
case KeyType::CLIENT_KEY:
{
if (!context.client_key.empty())
return context.client_key;
throw Exception(
"Quota " + quota->getName() + " (for user " + context.user_name + ") requires a client supplied key.",
ErrorCodes::QUOTA_REQUIRES_CLIENT_KEY);
}
case KeyType::CLIENT_KEY_OR_USER_NAME:
{
if (!context.client_key.empty())
return context.client_key;
return context.user_name;
}
case KeyType::CLIENT_KEY_OR_IP_ADDRESS:
{
if (!context.client_key.empty())
return context.client_key;
return context.address.toString();
}
}
__builtin_unreachable();
}
std::shared_ptr<const QuotaContext::Intervals> QuotaContextFactory::QuotaInfo::getOrBuildIntervals(const String & key)
{
auto it = key_to_intervals.find(key);
if (it != key_to_intervals.end())
return it->second;
return rebuildIntervals(key);
}
void QuotaContextFactory::QuotaInfo::rebuildAllIntervals()
{
for (const String & key : key_to_intervals | boost::adaptors::map_keys)
rebuildIntervals(key);
}
std::shared_ptr<const QuotaContext::Intervals> QuotaContextFactory::QuotaInfo::rebuildIntervals(const String & key)
{
auto new_intervals = std::make_shared<Intervals>();
new_intervals->quota_name = quota->getName();
new_intervals->quota_id = quota_id;
new_intervals->quota_key = key;
auto & intervals = new_intervals->intervals;
intervals.reserve(quota->all_limits.size());
constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE;
for (const auto & limits : quota->all_limits)
{
intervals.emplace_back();
auto & interval = intervals.back();
interval.duration = limits.duration;
std::chrono::system_clock::time_point end_of_interval{};
interval.randomize_interval = limits.randomize_interval;
if (limits.randomize_interval)
end_of_interval += randomDuration(limits.duration);
interval.end_of_interval = end_of_interval.time_since_epoch();
for (auto resource_type : ext::range(MAX_RESOURCE_TYPE))
{
interval.max[resource_type] = limits.max[resource_type];
interval.used[resource_type] = 0;
}
}
/// Order intervals by durations from largest to smallest.
/// To report first about largest interval on what quota was exceeded.
struct GreaterByDuration
{
bool operator()(const Interval & lhs, const Interval & rhs) const { return lhs.duration > rhs.duration; }
};
boost::range::stable_sort(intervals, GreaterByDuration{});
auto it = key_to_intervals.find(key);
if (it == key_to_intervals.end())
{
/// Just put new intervals into the map.
key_to_intervals.try_emplace(key, new_intervals);
}
else
{
/// We need to keep usage information from the old intervals.
const auto & old_intervals = it->second->intervals;
for (auto & new_interval : new_intervals->intervals)
{
/// Check if an interval with the same duration is already in use.
auto lower_bound = boost::range::lower_bound(old_intervals, new_interval, GreaterByDuration{});
if ((lower_bound == old_intervals.end()) || (lower_bound->duration != new_interval.duration))
continue;
/// Found an interval with the same duration, we need to copy its usage information to `result`.
auto & current_interval = *lower_bound;
for (auto resource_type : ext::range(MAX_RESOURCE_TYPE))
{
new_interval.used[resource_type].store(current_interval.used[resource_type].load());
new_interval.end_of_interval.store(current_interval.end_of_interval.load());
}
}
it->second = new_intervals;
}
return new_intervals;
}
QuotaContextFactory::QuotaContextFactory(const AccessControlManager & access_control_manager_)
: access_control_manager(access_control_manager_)
{
}
QuotaContextFactory::~QuotaContextFactory()
{
}
std::shared_ptr<QuotaContext> QuotaContextFactory::createContext(const String & user_name, const Poco::Net::IPAddress & address, const String & client_key)
{
std::lock_guard lock{mutex};
ensureAllQuotasRead();
auto context = ext::shared_ptr_helper<QuotaContext>::create(user_name, address, client_key);
contexts.push_back(context);
chooseQuotaForContext(context);
return context;
}
void QuotaContextFactory::ensureAllQuotasRead()
{
/// `mutex` is already locked.
if (all_quotas_read)
return;
all_quotas_read = true;
subscription = access_control_manager.subscribeForChanges<Quota>(
[&](const UUID & id, const AccessEntityPtr & entity)
{
if (entity)
quotaAddedOrChanged(id, typeid_cast<QuotaPtr>(entity));
else
quotaRemoved(id);
});
for (const UUID & quota_id : access_control_manager.findAll<Quota>())
{
auto quota = access_control_manager.tryRead<Quota>(quota_id);
if (quota)
all_quotas.emplace(quota_id, QuotaInfo(quota, quota_id));
}
}
void QuotaContextFactory::quotaAddedOrChanged(const UUID & quota_id, const std::shared_ptr<const Quota> & new_quota)
{
std::lock_guard lock{mutex};
auto it = all_quotas.find(quota_id);
if (it == all_quotas.end())
{
it = all_quotas.emplace(quota_id, QuotaInfo(new_quota, quota_id)).first;
}
else
{
if (it->second.quota == new_quota)
return;
}
auto & info = it->second;
info.setQuota(new_quota, quota_id);
chooseQuotaForAllContexts();
}
void QuotaContextFactory::quotaRemoved(const UUID & quota_id)
{
std::lock_guard lock{mutex};
all_quotas.erase(quota_id);
chooseQuotaForAllContexts();
}
void QuotaContextFactory::chooseQuotaForAllContexts()
{
/// `mutex` is already locked.
boost::range::remove_erase_if(
contexts,
[&](const std::weak_ptr<QuotaContext> & weak)
{
auto context = weak.lock();
if (!context)
return true; // remove from the `contexts` list.
chooseQuotaForContext(context);
return false; // keep in the `contexts` list.
});
}
void QuotaContextFactory::chooseQuotaForContext(const std::shared_ptr<QuotaContext> & context)
{
/// `mutex` is already locked.
std::shared_ptr<const Intervals> intervals;
for (auto & info : all_quotas | boost::adaptors::map_values)
{
if (info.canUseWithContext(*context))
{
String key = info.calculateKey(*context);
intervals = info.getOrBuildIntervals(key);
break;
}
}
if (!intervals)
intervals = std::make_shared<Intervals>(); /// No quota == no limits.
std::atomic_store(&context->atomic_intervals, intervals);
}
std::vector<QuotaUsageInfo> QuotaContextFactory::getUsageInfo() const
{
std::lock_guard lock{mutex};
std::vector<QuotaUsageInfo> all_infos;
auto current_time = std::chrono::system_clock::now();
for (const auto & info : all_quotas | boost::adaptors::map_values)
{
for (const auto & intervals : info.key_to_intervals | boost::adaptors::map_values)
all_infos.push_back(intervals->getUsageInfo(current_time));
}
return all_infos;
}
}

View File

@ -0,0 +1,62 @@
#pragma once
#include <Access/QuotaContext.h>
#include <Access/IAccessStorage.h>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
namespace DB
{
class AccessControlManager;
/// Stores information how much amount of resources have been consumed and how much are left.
class QuotaContextFactory
{
public:
QuotaContextFactory(const AccessControlManager & access_control_manager_);
~QuotaContextFactory();
QuotaContextPtr createContext(const String & user_name, const Poco::Net::IPAddress & address, const String & client_key);
std::vector<QuotaUsageInfo> getUsageInfo() const;
private:
using Interval = QuotaContext::Interval;
using Intervals = QuotaContext::Intervals;
struct QuotaInfo
{
QuotaInfo(const QuotaPtr & quota_, const UUID & quota_id_) { setQuota(quota_, quota_id_); }
void setQuota(const QuotaPtr & quota_, const UUID & quota_id_);
bool canUseWithContext(const QuotaContext & context) const;
String calculateKey(const QuotaContext & context) const;
std::shared_ptr<const Intervals> getOrBuildIntervals(const String & key);
std::shared_ptr<const Intervals> rebuildIntervals(const String & key);
void rebuildAllIntervals();
QuotaPtr quota;
UUID quota_id;
std::unordered_set<String> roles;
bool all_roles = false;
std::unordered_set<String> except_roles;
std::unordered_map<String /* quota key */, std::shared_ptr<const Intervals>> key_to_intervals;
};
void ensureAllQuotasRead();
void quotaAddedOrChanged(const UUID & quota_id, const std::shared_ptr<const Quota> & new_quota);
void quotaRemoved(const UUID & quota_id);
void chooseQuotaForAllContexts();
void chooseQuotaForContext(const std::shared_ptr<QuotaContext> & context);
const AccessControlManager & access_control_manager;
mutable std::mutex mutex;
std::unordered_map<UUID /* quota id */, QuotaInfo> all_quotas;
bool all_quotas_read = false;
IAccessStorage::SubscriptionPtr subscription;
std::vector<std::weak_ptr<QuotaContext>> contexts;
};
}

View File

@ -0,0 +1,207 @@
#include <Access/UsersConfigAccessStorage.h>
#include <Access/Quota.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/quoteString.h>
#include <Poco/Util/AbstractConfiguration.h>
#include <Poco/MD5Engine.h>
#include <cstring>
namespace DB
{
namespace
{
char getTypeChar(std::type_index type)
{
if (type == typeid(Quota))
return 'Q';
return 0;
}
UUID generateID(std::type_index type, const String & name)
{
Poco::MD5Engine md5;
md5.update(name);
char type_storage_chars[] = " USRSXML";
type_storage_chars[0] = getTypeChar(type);
md5.update(type_storage_chars, strlen(type_storage_chars));
UUID result;
memcpy(&result, md5.digest().data(), md5.digestLength());
return result;
}
UUID generateID(const IAccessEntity & entity) { return generateID(entity.getType(), entity.getFullName()); }
QuotaPtr parseQuota(const Poco::Util::AbstractConfiguration & config, const String & quota_name, const Strings & user_names)
{
auto quota = std::make_shared<Quota>();
quota->setName(quota_name);
using KeyType = Quota::KeyType;
String quota_config = "quotas." + quota_name;
if (config.has(quota_config + ".keyed_by_ip"))
quota->key_type = KeyType::IP_ADDRESS;
else if (config.has(quota_config + ".keyed"))
quota->key_type = KeyType::CLIENT_KEY_OR_USER_NAME;
else
quota->key_type = KeyType::USER_NAME;
Poco::Util::AbstractConfiguration::Keys interval_keys;
config.keys(quota_config, interval_keys);
for (const String & interval_key : interval_keys)
{
if (!startsWith(interval_key, "interval"))
continue;
String interval_config = quota_config + "." + interval_key;
std::chrono::seconds duration{config.getInt(interval_config + ".duration", 0)};
if (duration.count() <= 0) /// Skip quotas with non-positive duration.
continue;
quota->all_limits.emplace_back();
auto & limits = quota->all_limits.back();
limits.duration = duration;
limits.randomize_interval = config.getBool(interval_config + ".randomize", false);
using ResourceType = Quota::ResourceType;
limits.max[ResourceType::QUERIES] = config.getUInt64(interval_config + ".queries", Quota::UNLIMITED);
limits.max[ResourceType::ERRORS] = config.getUInt64(interval_config + ".errors", Quota::UNLIMITED);
limits.max[ResourceType::RESULT_ROWS] = config.getUInt64(interval_config + ".result_rows", Quota::UNLIMITED);
limits.max[ResourceType::RESULT_BYTES] = config.getUInt64(interval_config + ".result_bytes", Quota::UNLIMITED);
limits.max[ResourceType::READ_ROWS] = config.getUInt64(interval_config + ".read_rows", Quota::UNLIMITED);
limits.max[ResourceType::READ_BYTES] = config.getUInt64(interval_config + ".read_bytes", Quota::UNLIMITED);
limits.max[ResourceType::EXECUTION_TIME] = Quota::secondsToExecutionTime(config.getUInt64(interval_config + ".execution_time", Quota::UNLIMITED));
}
quota->roles = user_names;
return quota;
}
std::vector<AccessEntityPtr> parseQuotas(const Poco::Util::AbstractConfiguration & config, Poco::Logger * log)
{
Poco::Util::AbstractConfiguration::Keys user_names;
config.keys("users", user_names);
std::unordered_map<String, Strings> quota_to_user_names;
for (const auto & user_name : user_names)
{
if (config.has("users." + user_name + ".quota"))
quota_to_user_names[config.getString("users." + user_name + ".quota")].push_back(user_name);
}
Poco::Util::AbstractConfiguration::Keys quota_names;
config.keys("quotas", quota_names);
std::vector<AccessEntityPtr> quotas;
quotas.reserve(quota_names.size());
for (const auto & quota_name : quota_names)
{
try
{
auto it = quota_to_user_names.find(quota_name);
const Strings quota_users = (it != quota_to_user_names.end()) ? std::move(it->second) : Strings{};
quotas.push_back(parseQuota(config, quota_name, quota_users));
}
catch (...)
{
tryLogCurrentException(log, "Could not parse quota " + backQuote(quota_name));
}
}
return quotas;
}
}
UsersConfigAccessStorage::UsersConfigAccessStorage() : IAccessStorage("users.xml")
{
}
UsersConfigAccessStorage::~UsersConfigAccessStorage() {}
void UsersConfigAccessStorage::loadFromConfig(const Poco::Util::AbstractConfiguration & config)
{
std::vector<std::pair<UUID, AccessEntityPtr>> all_entities;
for (const auto & entity : parseQuotas(config, getLogger()))
all_entities.emplace_back(generateID(*entity), entity);
memory_storage.setAll(all_entities);
}
std::optional<UUID> UsersConfigAccessStorage::findImpl(std::type_index type, const String & name) const
{
return memory_storage.find(type, name);
}
std::vector<UUID> UsersConfigAccessStorage::findAllImpl(std::type_index type) const
{
return memory_storage.findAll(type);
}
bool UsersConfigAccessStorage::existsImpl(const UUID & id) const
{
return memory_storage.exists(id);
}
AccessEntityPtr UsersConfigAccessStorage::readImpl(const UUID & id) const
{
return memory_storage.read(id);
}
String UsersConfigAccessStorage::readNameImpl(const UUID & id) const
{
return memory_storage.readName(id);
}
UUID UsersConfigAccessStorage::insertImpl(const AccessEntityPtr & entity, bool)
{
throwReadonlyCannotInsert(entity->getType(), entity->getFullName());
}
void UsersConfigAccessStorage::removeImpl(const UUID & id)
{
auto entity = read(id);
throwReadonlyCannotRemove(entity->getType(), entity->getFullName());
}
void UsersConfigAccessStorage::updateImpl(const UUID & id, const UpdateFunc &)
{
auto entity = read(id);
throwReadonlyCannotUpdate(entity->getType(), entity->getFullName());
}
IAccessStorage::SubscriptionPtr UsersConfigAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
return memory_storage.subscribeForChanges(id, handler);
}
IAccessStorage::SubscriptionPtr UsersConfigAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const
{
return memory_storage.subscribeForChanges(type, handler);
}
bool UsersConfigAccessStorage::hasSubscriptionImpl(const UUID & id) const
{
return memory_storage.hasSubscription(id);
}
bool UsersConfigAccessStorage::hasSubscriptionImpl(std::type_index type) const
{
return memory_storage.hasSubscription(type);
}
}

View File

@ -0,0 +1,42 @@
#pragma once
#include <Access/MemoryAccessStorage.h>
namespace Poco
{
namespace Util
{
class AbstractConfiguration;
}
}
namespace DB
{
/// Implementation of IAccessStorage which loads all from users.xml periodically.
class UsersConfigAccessStorage : public IAccessStorage
{
public:
UsersConfigAccessStorage();
~UsersConfigAccessStorage() override;
void loadFromConfig(const Poco::Util::AbstractConfiguration & config);
private:
std::optional<UUID> findImpl(std::type_index type, const String & name) const override;
std::vector<UUID> findAllImpl(std::type_index type) const override;
bool existsImpl(const UUID & id) const override;
AccessEntityPtr readImpl(const UUID & id) const override;
String readNameImpl(const UUID & id) const override;
UUID insertImpl(const AccessEntityPtr & entity, bool replace_if_exists) override;
void removeImpl(const UUID & id) override;
void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
SubscriptionPtr subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
SubscriptionPtr subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override;
bool hasSubscriptionImpl(const UUID & id) const override;
bool hasSubscriptionImpl(std::type_index type) const override;
MemoryAccessStorage memory_storage;
};
}

View File

@ -439,6 +439,10 @@ void Connection::sendQuery(
void Connection::sendCancel()
{
/// If we already disconnected.
if (!out)
return;
//LOG_TRACE(log_wrapper.get(), "Sending cancel");
writeVarUInt(Protocol::Client::Cancel, *out);

View File

@ -4,6 +4,9 @@
#if USE_ICU
#include <unicode/ucol.h>
#include <unicode/unistr.h>
#include <unicode/locid.h>
#include <unicode/ucnv.h>
#else
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunused-private-field"
@ -14,6 +17,7 @@
#include <Common/Exception.h>
#include <IO/WriteHelpers.h>
#include <Poco/String.h>
#include <algorithm>
namespace DB
@ -26,16 +30,81 @@ namespace DB
}
}
Collator::Collator(const std::string & locale_) : locale(Poco::toLower(locale_))
AvailableCollationLocales::AvailableCollationLocales()
{
#if USE_ICU
static const size_t MAX_LANG_LENGTH = 128;
size_t available_locales_count = ucol_countAvailable();
for (size_t i = 0; i < available_locales_count; ++i)
{
std::string locale_name = ucol_getAvailable(i);
UChar lang_buffer[MAX_LANG_LENGTH];
char normal_buf[MAX_LANG_LENGTH];
UErrorCode status = U_ZERO_ERROR;
/// All names will be in English language
size_t lang_length = uloc_getDisplayLanguage(
locale_name.c_str(), "en", lang_buffer, MAX_LANG_LENGTH, &status);
std::optional<std::string> lang;
if (!U_FAILURE(status))
{
/// Convert language name from UChar array to normal char array.
/// We use English language for name, so all UChar's length is equal to sizeof(char)
u_UCharsToChars(lang_buffer, normal_buf, lang_length);
lang.emplace(std::string(normal_buf, lang_length));
}
locales_map.emplace(Poco::toLower(locale_name), LocaleAndLanguage{locale_name, lang});
}
#endif
}
const AvailableCollationLocales & AvailableCollationLocales::instance()
{
static AvailableCollationLocales instance;
return instance;
}
AvailableCollationLocales::LocalesVector AvailableCollationLocales::getAvailableCollations() const
{
LocalesVector result;
for (const auto & name_and_locale : locales_map)
result.push_back(name_and_locale.second);
auto comparator = [] (const LocaleAndLanguage & f, const LocaleAndLanguage & s)
{
return f.locale_name < s.locale_name;
};
std::sort(result.begin(), result.end(), comparator);
return result;
}
bool AvailableCollationLocales::isCollationSupported(const std::string & locale_name) const
{
/// We support locale names in any case, so we have to convert all to lower case
return locales_map.count(Poco::toLower(locale_name));
}
Collator::Collator(const std::string & locale_)
: locale(Poco::toLower(locale_))
{
#if USE_ICU
/// We check it here, because ucol_open will fallback to default locale for
/// almost all random names.
if (!AvailableCollationLocales::instance().isCollationSupported(locale))
throw DB::Exception("Unsupported collation locale: " + locale, DB::ErrorCodes::UNSUPPORTED_COLLATION_LOCALE);
UErrorCode status = U_ZERO_ERROR;
collator = ucol_open(locale.c_str(), &status);
if (status != U_ZERO_ERROR)
if (U_FAILURE(status))
{
ucol_close(collator);
throw DB::Exception("Unsupported collation locale: " + locale, DB::ErrorCodes::UNSUPPORTED_COLLATION_LOCALE);
throw DB::Exception("Failed to open locale: " + locale + " with error: " + u_errorName(status), DB::ErrorCodes::UNSUPPORTED_COLLATION_LOCALE);
}
#else
throw DB::Exception("Collations support is disabled, because ClickHouse was built without ICU library", DB::ErrorCodes::SUPPORT_IS_DISABLED);
@ -60,8 +129,8 @@ int Collator::compare(const char * str1, size_t length1, const char * str2, size
UErrorCode status = U_ZERO_ERROR;
UCollationResult compare_result = ucol_strcollIter(collator, &iter1, &iter2, &status);
if (status != U_ZERO_ERROR)
throw DB::Exception("ICU collation comparison failed with error code: " + DB::toString<int>(status),
if (U_FAILURE(status))
throw DB::Exception("ICU collation comparison failed with error code: " + std::string(u_errorName(status)),
DB::ErrorCodes::COLLATION_COMPARISON_FAILED);
/** Values of enum UCollationResult are equals to what exactly we need:
@ -83,14 +152,3 @@ const std::string & Collator::getLocale() const
{
return locale;
}
std::vector<std::string> Collator::getAvailableCollations()
{
std::vector<std::string> result;
#if USE_ICU
size_t available_locales_count = ucol_countAvailable();
for (size_t i = 0; i < available_locales_count; ++i)
result.push_back(ucol_getAvailable(i));
#endif
return result;
}

View File

@ -3,9 +3,39 @@
#include <string>
#include <vector>
#include <boost/noncopyable.hpp>
#include <unordered_map>
struct UCollator;
/// Class represents available locales for collations.
class AvailableCollationLocales : private boost::noncopyable
{
public:
struct LocaleAndLanguage
{
std::string locale_name; /// ISO locale code
std::optional<std::string> language; /// full language name in English
};
using AvailableLocalesMap = std::unordered_map<std::string, LocaleAndLanguage>;
using LocalesVector = std::vector<LocaleAndLanguage>;
static const AvailableCollationLocales & instance();
/// Get all collations with names in sorted order
LocalesVector getAvailableCollations() const;
/// Check that collation is supported
bool isCollationSupported(const std::string & locale_name) const;
private:
AvailableCollationLocales();
private:
AvailableLocalesMap locales_map;
};
class Collator : private boost::noncopyable
{
public:
@ -15,10 +45,8 @@ public:
int compare(const char * str1, size_t length1, const char * str2, size_t length2) const;
const std::string & getLocale() const;
static std::vector<std::string> getAvailableCollations();
private:
std::string locale;
UCollator * collator;
};

View File

@ -105,6 +105,11 @@ public:
return data->getFloat64(0);
}
Float32 getFloat32(size_t) const override
{
return data->getFloat32(0);
}
bool isNullAt(size_t) const override
{
return data->isNullAt(0);
@ -219,6 +224,7 @@ public:
Field getField() const { return getDataColumn()[0]; }
/// The constant value. It is valid even if the size of the column is 0.
template <typename T>
T getValue() const { return getField().safeGet<NearestFieldType<T>>(); }
};

View File

@ -96,6 +96,7 @@ public:
void insertFrom(const IColumn & src, size_t n) override { data.push_back(static_cast<const Self &>(src).getData()[n]); }
void insertData(const char * pos, size_t /*length*/) override;
void insertDefault() override { data.push_back(T()); }
virtual void insertManyDefaults(size_t length) override { data.resize_fill(data.size() + length); }
void insert(const Field & x) override { data.push_back(DB::get<NearestFieldType<T>>(x)); }
void insertRangeFrom(const IColumn & src, size_t start, size_t length) override;
@ -144,7 +145,7 @@ public:
}
void insert(const T value) { data.push_back(value); }
void insertValue(const T value) { data.push_back(value); }
Container & getData() { return data; }
const Container & getData() const { return data; }
const T & getElement(size_t n) const { return data[n]; }

View File

@ -92,6 +92,11 @@ public:
chars.resize_fill(chars.size() + n);
}
virtual void insertManyDefaults(size_t length) override
{
chars.resize_fill(chars.size() + n * length);
}
void popBack(size_t elems) override
{
chars.resize_assume_reserved(chars.size() - n * elems);

View File

@ -59,6 +59,7 @@ public:
UInt64 getUInt(size_t n) const override { return getDictionary().getUInt(getIndexes().getUInt(n)); }
Int64 getInt(size_t n) const override { return getDictionary().getInt(getIndexes().getUInt(n)); }
Float64 getFloat64(size_t n) const override { return getDictionary().getInt(getIndexes().getFloat64(n)); }
Float32 getFloat32(size_t n) const override { return getDictionary().getInt(getIndexes().getFloat32(n)); }
bool getBool(size_t n) const override { return getDictionary().getInt(getIndexes().getBool(n)); }
bool isNullAt(size_t n) const override { return getDictionary().isNullAt(getIndexes().getUInt(n)); }
ColumnPtr cut(size_t start, size_t length) const override

View File

@ -205,6 +205,13 @@ public:
offsets.push_back(offsets.back() + 1);
}
virtual void insertManyDefaults(size_t length) override
{
chars.resize_fill(chars.size() + length);
for (size_t i = 0; i < length; ++i)
offsets.push_back(offsets.back() + 1);
}
int compareAt(size_t n, size_t m, const IColumn & rhs_, int /*nan_direction_hint*/) const override
{
const ColumnString & rhs = assert_cast<const ColumnString &>(rhs_);

View File

@ -66,6 +66,7 @@ public:
UInt64 getUInt(size_t n) const override { return getNestedColumn()->getUInt(n); }
Int64 getInt(size_t n) const override { return getNestedColumn()->getInt(n); }
Float64 getFloat64(size_t n) const override { return getNestedColumn()->getFloat64(n); }
Float32 getFloat32(size_t n) const override { return getNestedColumn()->getFloat32(n); }
bool getBool(size_t n) const override { return getNestedColumn()->getBool(n); }
bool isNullAt(size_t n) const override { return is_nullable && n == getNullValueIndex(); }
StringRef serializeValueIntoArena(size_t n, Arena & arena, char const *& begin) const override;

View File

@ -222,6 +222,12 @@ Float64 ColumnVector<T>::getFloat64(size_t n) const
return static_cast<Float64>(data[n]);
}
template <typename T>
Float32 ColumnVector<T>::getFloat32(size_t n) const
{
return static_cast<Float32>(data[n]);
}
template <typename T>
void ColumnVector<T>::insertRangeFrom(const IColumn & src, size_t start, size_t length)
{

View File

@ -144,6 +144,11 @@ public:
data.push_back(T());
}
virtual void insertManyDefaults(size_t length) override
{
data.resize_fill(data.size() + length, T());
}
void popBack(size_t n) override
{
data.resize_assume_reserved(data.size() - n);
@ -205,6 +210,7 @@ public:
UInt64 get64(size_t n) const override;
Float64 getFloat64(size_t n) const override;
Float32 getFloat32(size_t n) const override;
UInt64 getUInt(size_t n) const override
{

View File

@ -100,6 +100,11 @@ public:
throw Exception("Method getFloat64 is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
virtual Float32 getFloat32(size_t /*n*/) const
{
throw Exception("Method getFloat32 is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
/** If column is numeric, return value of n-th element, casted to UInt64.
* For NULL values of Nullable column it is allowed to return arbitrary value.
* Otherwise throw an exception.

View File

@ -34,6 +34,7 @@ namespace ErrorCodes
extern const int CANNOT_STATVFS;
extern const int NOT_ENOUGH_SPACE;
extern const int NOT_IMPLEMENTED;
extern const int NO_SUCH_DATA_PART;
extern const int SYSTEM_ERROR;
extern const int UNKNOWN_ELEMENT_IN_CONFIG;
extern const int EXCESSIVE_ELEMENT_IN_CONFIG;

View File

@ -464,6 +464,15 @@ namespace ErrorCodes
extern const int CANNOT_GET_CREATE_DICTIONARY_QUERY = 487;
extern const int UNKNOWN_DICTIONARY = 488;
extern const int INCORRECT_DICTIONARY_DEFINITION = 489;
extern const int CANNOT_FORMAT_DATETIME = 490;
extern const int UNACCEPTABLE_URL = 491;
extern const int ACCESS_ENTITY_NOT_FOUND = 492;
extern const int ACCESS_ENTITY_ALREADY_EXISTS = 493;
extern const int ACCESS_ENTITY_FOUND_DUPLICATES = 494;
extern const int ACCESS_ENTITY_STORAGE_READONLY = 495;
extern const int QUOTA_REQUIRES_CLIENT_KEY = 496;
extern const int NOT_ENOUGH_PRIVILEGES = 497;
extern const int LIMIT_BY_WITH_TIES_IS_NOT_SUPPORTED = 498;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;

View File

@ -34,97 +34,23 @@ struct StaticVisitor
/// F is template parameter, to allow universal reference for field, that is useful for const and non-const values.
template <typename Visitor, typename F>
typename std::decay_t<Visitor>::ResultType applyVisitor(Visitor && visitor, F && field)
auto applyVisitor(Visitor && visitor, F && field)
{
switch (field.getType())
{
case Field::Types::Null: return visitor(field.template get<Null>());
case Field::Types::UInt64: return visitor(field.template get<UInt64>());
case Field::Types::UInt128: return visitor(field.template get<UInt128>());
case Field::Types::Int64: return visitor(field.template get<Int64>());
case Field::Types::Float64: return visitor(field.template get<Float64>());
case Field::Types::String: return visitor(field.template get<String>());
case Field::Types::Array: return visitor(field.template get<Array>());
case Field::Types::Tuple: return visitor(field.template get<Tuple>());
case Field::Types::Decimal32: return visitor(field.template get<DecimalField<Decimal32>>());
case Field::Types::Decimal64: return visitor(field.template get<DecimalField<Decimal64>>());
case Field::Types::Decimal128: return visitor(field.template get<DecimalField<Decimal128>>());
case Field::Types::AggregateFunctionState: return visitor(field.template get<AggregateFunctionStateData>());
default:
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
}
}
template <typename Visitor, typename F1, typename F2>
static typename std::decay_t<Visitor>::ResultType applyBinaryVisitorImpl(Visitor && visitor, F1 && field1, F2 && field2)
{
switch (field2.getType())
{
case Field::Types::Null: return visitor(field1, field2.template get<Null>());
case Field::Types::UInt64: return visitor(field1, field2.template get<UInt64>());
case Field::Types::UInt128: return visitor(field1, field2.template get<UInt128>());
case Field::Types::Int64: return visitor(field1, field2.template get<Int64>());
case Field::Types::Float64: return visitor(field1, field2.template get<Float64>());
case Field::Types::String: return visitor(field1, field2.template get<String>());
case Field::Types::Array: return visitor(field1, field2.template get<Array>());
case Field::Types::Tuple: return visitor(field1, field2.template get<Tuple>());
case Field::Types::Decimal32: return visitor(field1, field2.template get<DecimalField<Decimal32>>());
case Field::Types::Decimal64: return visitor(field1, field2.template get<DecimalField<Decimal64>>());
case Field::Types::Decimal128: return visitor(field1, field2.template get<DecimalField<Decimal128>>());
case Field::Types::AggregateFunctionState: return visitor(field1, field2.template get<AggregateFunctionStateData>());
default:
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
}
return Field::dispatch(visitor, field);
}
template <typename Visitor, typename F1, typename F2>
typename std::decay_t<Visitor>::ResultType applyVisitor(Visitor && visitor, F1 && field1, F2 && field2)
auto applyVisitor(Visitor && visitor, F1 && field1, F2 && field2)
{
switch (field1.getType())
{
case Field::Types::Null:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<Null>(), std::forward<F2>(field2));
case Field::Types::UInt64:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<UInt64>(), std::forward<F2>(field2));
case Field::Types::UInt128:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<UInt128>(), std::forward<F2>(field2));
case Field::Types::Int64:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<Int64>(), std::forward<F2>(field2));
case Field::Types::Float64:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<Float64>(), std::forward<F2>(field2));
case Field::Types::String:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<String>(), std::forward<F2>(field2));
case Field::Types::Array:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<Array>(), std::forward<F2>(field2));
case Field::Types::Tuple:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<Tuple>(), std::forward<F2>(field2));
case Field::Types::Decimal32:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<DecimalField<Decimal32>>(), std::forward<F2>(field2));
case Field::Types::Decimal64:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<DecimalField<Decimal64>>(), std::forward<F2>(field2));
case Field::Types::Decimal128:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<DecimalField<Decimal128>>(), std::forward<F2>(field2));
case Field::Types::AggregateFunctionState:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<AggregateFunctionStateData>(), std::forward<F2>(field2));
default:
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
}
return Field::dispatch([&](auto & field1_value)
{
return Field::dispatch([&](auto & field2_value)
{
return visitor(field1_value, field2_value);
},
field2);
},
field1);
}
@ -473,8 +399,14 @@ private:
public:
explicit FieldVisitorSum(const Field & rhs_) : rhs(rhs_) {}
bool operator() (UInt64 & x) const { x += get<UInt64>(rhs); return x != 0; }
bool operator() (Int64 & x) const { x += get<Int64>(rhs); return x != 0; }
// We can add all ints as unsigned regardless of their actual signedness.
bool operator() (Int64 & x) const { return this->operator()(reinterpret_cast<UInt64 &>(x)); }
bool operator() (UInt64 & x) const
{
x += rhs.reinterpret<UInt64>();
return x != 0;
}
bool operator() (Float64 & x) const { x += get<Float64>(rhs); return x != 0; }
bool operator() (Null &) const { throw Exception("Cannot sum Nulls", ErrorCodes::LOGICAL_ERROR); }

View File

@ -84,6 +84,23 @@ struct DefaultHash<T, std::enable_if_t<is_arithmetic_v<T>>>
}
};
template <typename T>
struct DefaultHash<T, std::enable_if_t<DB::IsDecimalNumber<T> && sizeof(T) <= 8>>
{
size_t operator() (T key) const
{
return DefaultHash64<typename T::NativeType>(key);
}
};
template <typename T>
struct DefaultHash<T, std::enable_if_t<DB::IsDecimalNumber<T> && sizeof(T) == 16>>
{
size_t operator() (T key) const
{
return DefaultHash64<Int64>(key >> 64) ^ DefaultHash64<Int64>(key);
}
};
template <typename T> struct HashCRC32;

View File

@ -0,0 +1,162 @@
#include <Common/IntervalKind.h>
#include <Common/Exception.h>
namespace DB
{
namespace ErrorCodes
{
extern const int SYNTAX_ERROR;
}
const char * IntervalKind::toString() const
{
switch (kind)
{
case IntervalKind::Second: return "Second";
case IntervalKind::Minute: return "Minute";
case IntervalKind::Hour: return "Hour";
case IntervalKind::Day: return "Day";
case IntervalKind::Week: return "Week";
case IntervalKind::Month: return "Month";
case IntervalKind::Quarter: return "Quarter";
case IntervalKind::Year: return "Year";
}
__builtin_unreachable();
}
Int32 IntervalKind::toAvgSeconds() const
{
switch (kind)
{
case IntervalKind::Second: return 1;
case IntervalKind::Minute: return 60;
case IntervalKind::Hour: return 3600;
case IntervalKind::Day: return 86400;
case IntervalKind::Week: return 604800;
case IntervalKind::Month: return 2629746; /// Exactly 1/12 of a year.
case IntervalKind::Quarter: return 7889238; /// Exactly 1/4 of a year.
case IntervalKind::Year: return 31556952; /// The average length of a Gregorian year is equal to 365.2425 days
}
__builtin_unreachable();
}
IntervalKind IntervalKind::fromAvgSeconds(Int64 num_seconds)
{
if (num_seconds)
{
if (!(num_seconds % 31556952))
return IntervalKind::Year;
if (!(num_seconds % 7889238))
return IntervalKind::Quarter;
if (!(num_seconds % 604800))
return IntervalKind::Week;
if (!(num_seconds % 2629746))
return IntervalKind::Month;
if (!(num_seconds % 86400))
return IntervalKind::Day;
if (!(num_seconds % 3600))
return IntervalKind::Hour;
if (!(num_seconds % 60))
return IntervalKind::Minute;
}
return IntervalKind::Second;
}
const char * IntervalKind::toKeyword() const
{
switch (kind)
{
case IntervalKind::Second: return "SECOND";
case IntervalKind::Minute: return "MINUTE";
case IntervalKind::Hour: return "HOUR";
case IntervalKind::Day: return "DAY";
case IntervalKind::Week: return "WEEK";
case IntervalKind::Month: return "MONTH";
case IntervalKind::Quarter: return "QUARTER";
case IntervalKind::Year: return "YEAR";
}
__builtin_unreachable();
}
const char * IntervalKind::toDateDiffUnit() const
{
switch (kind)
{
case IntervalKind::Second:
return "second";
case IntervalKind::Minute:
return "minute";
case IntervalKind::Hour:
return "hour";
case IntervalKind::Day:
return "day";
case IntervalKind::Week:
return "week";
case IntervalKind::Month:
return "month";
case IntervalKind::Quarter:
return "quarter";
case IntervalKind::Year:
return "year";
}
__builtin_unreachable();
}
const char * IntervalKind::toNameOfFunctionToIntervalDataType() const
{
switch (kind)
{
case IntervalKind::Second:
return "toIntervalSecond";
case IntervalKind::Minute:
return "toIntervalMinute";
case IntervalKind::Hour:
return "toIntervalHour";
case IntervalKind::Day:
return "toIntervalDay";
case IntervalKind::Week:
return "toIntervalWeek";
case IntervalKind::Month:
return "toIntervalMonth";
case IntervalKind::Quarter:
return "toIntervalQuarter";
case IntervalKind::Year:
return "toIntervalYear";
}
__builtin_unreachable();
}
const char * IntervalKind::toNameOfFunctionExtractTimePart() const
{
switch (kind)
{
case IntervalKind::Second:
return "toSecond";
case IntervalKind::Minute:
return "toMinute";
case IntervalKind::Hour:
return "toHour";
case IntervalKind::Day:
return "toDayOfMonth";
case IntervalKind::Week:
// TODO: SELECT toRelativeWeekNum(toDate('2017-06-15')) - toRelativeWeekNum(toStartOfYear(toDate('2017-06-15')))
// else if (ParserKeyword("WEEK").ignore(pos, expected))
// function_name = "toRelativeWeekNum";
throw Exception("The syntax 'EXTRACT(WEEK FROM date)' is not supported, cannot extract the number of a week", ErrorCodes::SYNTAX_ERROR);
case IntervalKind::Month:
return "toMonth";
case IntervalKind::Quarter:
return "toQuarter";
case IntervalKind::Year:
return "toYear";
}
__builtin_unreachable();
}
}

View File

@ -0,0 +1,54 @@
#pragma once
#include <Core/Types.h>
namespace DB
{
/// Kind of a temporal interval.
struct IntervalKind
{
enum Kind
{
Second,
Minute,
Hour,
Day,
Week,
Month,
Quarter,
Year,
};
Kind kind = Second;
IntervalKind(Kind kind_ = Second) : kind(kind_) {}
operator Kind() const { return kind; }
const char * toString() const;
/// Returns number of seconds in one interval.
/// For `Month`, `Quarter` and `Year` the function returns an average number of seconds.
Int32 toAvgSeconds() const;
/// Chooses an interval kind based on number of seconds.
/// For example, `IntervalKind::fromAvgSeconds(3600)` returns `IntervalKind::Hour`.
static IntervalKind fromAvgSeconds(Int64 num_seconds);
/// Returns an uppercased version of what `toString()` returns.
const char * toKeyword() const;
/// Returns the string which can be passed to the `unit` parameter of the dateDiff() function.
/// For example, `IntervalKind{IntervalKind::Day}.getDateDiffParameter()` returns "day".
const char * toDateDiffUnit() const;
/// Returns the name of the function converting a number to the interval data type.
/// For example, `IntervalKind{IntervalKind::Day}.getToIntervalDataTypeFunctionName()`
/// returns "toIntervalDay".
const char * toNameOfFunctionToIntervalDataType() const;
/// Returns the name of the function extracting time part from a date or a time.
/// For example, `IntervalKind{IntervalKind::Day}.getExtractTimePartFunctionName()`
/// returns "toDayOfMonth".
const char * toNameOfFunctionExtractTimePart() const;
};
}

View File

@ -0,0 +1,62 @@
#include <re2/re2.h>
#include <Common/RemoteHostFilter.h>
#include <Poco/URI.h>
#include <Formats/FormatFactory.h>
#include <Poco/Util/AbstractConfiguration.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/Exception.h>
#include <IO/WriteHelpers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int UNACCEPTABLE_URL;
}
void RemoteHostFilter::checkURL(const Poco::URI & uri) const
{
if (!checkForDirectEntry(uri.getHost()) &&
!checkForDirectEntry(uri.getHost() + ":" + toString(uri.getPort())))
throw Exception("URL \"" + uri.toString() + "\" is not allowed in config.xml", ErrorCodes::UNACCEPTABLE_URL);
}
void RemoteHostFilter::checkHostAndPort(const std::string & host, const std::string & port) const
{
if (!checkForDirectEntry(host) &&
!checkForDirectEntry(host + ":" + port))
throw Exception("URL \"" + host + ":" + port + "\" is not allowed in config.xml", ErrorCodes::UNACCEPTABLE_URL);
}
void RemoteHostFilter::setValuesFromConfig(const Poco::Util::AbstractConfiguration & config)
{
if (config.has("remote_url_allow_hosts"))
{
std::vector<std::string> keys;
config.keys("remote_url_allow_hosts", keys);
for (auto key : keys)
{
if (startsWith(key, "host_regexp"))
regexp_hosts.push_back(config.getString("remote_url_allow_hosts." + key));
else if (startsWith(key, "host"))
primary_hosts.insert(config.getString("remote_url_allow_hosts." + key));
}
}
}
bool RemoteHostFilter::checkForDirectEntry(const std::string & str) const
{
if (!primary_hosts.empty() || !regexp_hosts.empty())
{
if (primary_hosts.find(str) == primary_hosts.end())
{
for (size_t i = 0; i < regexp_hosts.size(); ++i)
if (re2::RE2::FullMatch(str, regexp_hosts[i]))
return true;
return false;
}
return true;
}
return true;
}
}

View File

@ -0,0 +1,30 @@
#pragma once
#include <vector>
#include <unordered_set>
#include <Poco/URI.h>
#include <Poco/Util/AbstractConfiguration.h>
namespace DB
{
class RemoteHostFilter
{
/**
* This class checks if url is allowed.
* If primary_hosts and regexp_hosts are empty all urls are allowed.
*/
public:
void checkURL(const Poco::URI & uri) const; /// If URL not allowed in config.xml throw UNACCEPTABLE_URL Exception
void setValuesFromConfig(const Poco::Util::AbstractConfiguration & config);
void checkHostAndPort(const std::string & host, const std::string & port) const; /// Does the same as checkURL, but for host and port.
private:
std::unordered_set<std::string> primary_hosts; /// Allowed primary (<host>) URL from config.xml
std::vector<std::string> regexp_hosts; /// Allowed regexp (<hots_regexp>) URL from config.xml
bool checkForDirectEntry(const std::string & str) const; /// Checks if the primary_hosts and regexp_hosts contain str. If primary_hosts and regexp_hosts are empty return true.
};
}

View File

@ -158,7 +158,7 @@ std::string signalToErrorMessage(int sig, const siginfo_t & info, const ucontext
break;
}
case SIGPROF:
case SIGTSTP:
{
error << "This is a signal used for debugging purposes by the user.";
break;

View File

@ -3,11 +3,18 @@
#include <cstdint>
#include <limits>
#include <Core/Defines.h>
// Also defined in Core/Defines.h
#if !defined(NO_SANITIZE_UNDEFINED)
#if defined(__clang__)
#define NO_SANITIZE_UNDEFINED __attribute__((__no_sanitize__("undefined")))
#else
#define NO_SANITIZE_UNDEFINED
#endif
#endif
/// On overlow, the function returns unspecified value.
inline NO_SANITIZE_UNDEFINED uint64_t intExp2(int x)
{
return 1ULL << x;

View File

@ -0,0 +1,40 @@
#if defined(OS_LINUX)
#include <stdlib.h>
/// Interposing these symbols explicitly. The idea works like this: malloc.cpp compiles to a
/// dedicated object (namely clickhouse_malloc.o), and it will show earlier in the link command
/// than malloc libs like libjemalloc.a. As a result, these symbols get picked in time right after.
extern "C"
{
void *malloc(size_t size);
void free(void *ptr);
void *calloc(size_t nmemb, size_t size);
void *realloc(void *ptr, size_t size);
int posix_memalign(void **memptr, size_t alignment, size_t size);
void *aligned_alloc(size_t alignment, size_t size);
void *valloc(size_t size);
void *memalign(size_t alignment, size_t size);
void *pvalloc(size_t size);
}
template<typename T>
inline void ignore(T x __attribute__((unused)))
{
}
static void dummyFunctionForInterposing() __attribute__((used));
static void dummyFunctionForInterposing()
{
void* dummy;
/// Suppression for PVS-Studio.
free(nullptr); // -V575
ignore(malloc(0)); // -V575
ignore(calloc(0, 0)); // -V575
ignore(realloc(nullptr, 0)); // -V575
ignore(posix_memalign(&dummy, 0, 0)); // -V575
ignore(aligned_alloc(0, 0)); // -V575
ignore(valloc(0)); // -V575
ignore(memalign(0, 0)); // -V575
ignore(pvalloc(0)); // -V575
}
#endif

View File

@ -3,8 +3,10 @@
#include <type_traits>
#include <typeinfo>
#include <typeindex>
#include <memory>
#include <string>
#include <ext/shared_ptr_helper.h>
#include <Common/Exception.h>
#include <common/demangle.h>
@ -27,7 +29,7 @@ std::enable_if_t<std::is_reference_v<To>, To> typeid_cast(From & from)
{
try
{
if (typeid(from) == typeid(To))
if ((typeid(From) == typeid(To)) || (typeid(from) == typeid(To)))
return static_cast<To>(from);
}
catch (const std::exception & e)
@ -39,12 +41,13 @@ std::enable_if_t<std::is_reference_v<To>, To> typeid_cast(From & from)
DB::ErrorCodes::BAD_CAST);
}
template <typename To, typename From>
To typeid_cast(From * from)
std::enable_if_t<std::is_pointer_v<To>, To> typeid_cast(From * from)
{
try
{
if (typeid(*from) == typeid(std::remove_pointer_t<To>))
if ((typeid(From) == typeid(std::remove_pointer_t<To>)) || (typeid(*from) == typeid(std::remove_pointer_t<To>)))
return static_cast<To>(from);
else
return nullptr;
@ -54,3 +57,20 @@ To typeid_cast(From * from)
throw DB::Exception(e.what(), DB::ErrorCodes::BAD_CAST);
}
}
template <typename To, typename From>
std::enable_if_t<ext::is_shared_ptr_v<To>, To> typeid_cast(const std::shared_ptr<From> & from)
{
try
{
if ((typeid(From) == typeid(typename To::element_type)) || (typeid(*from) == typeid(typename To::element_type)))
return std::static_pointer_cast<typename To::element_type>(from);
else
return nullptr;
}
catch (const std::exception & e)
{
throw DB::Exception(e.what(), DB::ErrorCodes::BAD_CAST);
}
}

View File

@ -88,9 +88,9 @@ public:
Shift shift;
if (scale_a < scale_b)
shift.a = DataTypeDecimal<B>(maxDecimalPrecision<B>(), scale_b).getScaleMultiplier(scale_b - scale_a);
shift.a = B::getScaleMultiplier(scale_b - scale_a);
if (scale_a > scale_b)
shift.b = DataTypeDecimal<A>(maxDecimalPrecision<A>(), scale_a).getScaleMultiplier(scale_a - scale_b);
shift.b = A::getScaleMultiplier(scale_a - scale_b);
return applyWithScale(a, b, shift);
}

View File

@ -151,8 +151,8 @@
#endif
/// Marks that extra information is sent to a shard. It could be any magic numbers.
#define DBMS_DISTRIBUTED_SIGNATURE_EXTRA_INFO 0xCAFEDACEull
#define DBMS_DISTRIBUTED_SIGNATURE_SETTINGS_OLD_FORMAT 0xCAFECABEull
#define DBMS_DISTRIBUTED_SIGNATURE_HEADER 0xCAFEDACEull
#define DBMS_DISTRIBUTED_SIGNATURE_HEADER_OLD_FORMAT 0xCAFECABEull
#if !__has_include(<sanitizer/asan_interface.h>)
# define ASAN_UNPOISON_MEMORY_REGION(a, b)

View File

@ -295,26 +295,11 @@ namespace DB
void writeFieldText(const Field & x, WriteBuffer & buf)
{
DB::String res = applyVisitor(DB::FieldVisitorToString(), x);
DB::String res = Field::dispatch(DB::FieldVisitorToString(), x);
buf.write(res.data(), res.size());
}
template <> Decimal32 DecimalField<Decimal32>::getScaleMultiplier() const
{
return DataTypeDecimal<Decimal32>::getScaleMultiplier(scale);
}
template <> Decimal64 DecimalField<Decimal64>::getScaleMultiplier() const
{
return DataTypeDecimal<Decimal64>::getScaleMultiplier(scale);
}
template <> Decimal128 DecimalField<Decimal128>::getScaleMultiplier() const
{
return DataTypeDecimal<Decimal128>::getScaleMultiplier(scale);
}
template <typename T>
static bool decEqual(T x, T y, UInt32 x_scale, UInt32 y_scale)
{

View File

@ -27,7 +27,7 @@ namespace ErrorCodes
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
template <typename T>
template <typename T, typename SFINAE = void>
struct NearestFieldTypeImpl;
template <typename T>
@ -102,7 +102,7 @@ public:
operator T() const { return dec; }
T getValue() const { return dec; }
T getScaleMultiplier() const;
T getScaleMultiplier() const { return T::getScaleMultiplier(scale); }
UInt32 getScale() const { return scale; }
template <typename U>
@ -151,6 +151,54 @@ private:
UInt32 scale;
};
/// char may be signed or unsigned, and behave identically to signed char or unsigned char,
/// but they are always three different types.
/// signedness of char is different in Linux on x86 and Linux on ARM.
template <> struct NearestFieldTypeImpl<char> { using Type = std::conditional_t<is_signed_v<char>, Int64, UInt64>; };
template <> struct NearestFieldTypeImpl<signed char> { using Type = Int64; };
template <> struct NearestFieldTypeImpl<unsigned char> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<UInt16> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<UInt32> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<DayNum> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<UInt128> { using Type = UInt128; };
template <> struct NearestFieldTypeImpl<UUID> { using Type = UInt128; };
template <> struct NearestFieldTypeImpl<Int16> { using Type = Int64; };
template <> struct NearestFieldTypeImpl<Int32> { using Type = Int64; };
/// long and long long are always different types that may behave identically or not.
/// This is different on Linux and Mac.
template <> struct NearestFieldTypeImpl<long> { using Type = Int64; };
template <> struct NearestFieldTypeImpl<long long> { using Type = Int64; };
template <> struct NearestFieldTypeImpl<unsigned long> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<unsigned long long> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<Int128> { using Type = Int128; };
template <> struct NearestFieldTypeImpl<Decimal32> { using Type = DecimalField<Decimal32>; };
template <> struct NearestFieldTypeImpl<Decimal64> { using Type = DecimalField<Decimal64>; };
template <> struct NearestFieldTypeImpl<Decimal128> { using Type = DecimalField<Decimal128>; };
template <> struct NearestFieldTypeImpl<DecimalField<Decimal32>> { using Type = DecimalField<Decimal32>; };
template <> struct NearestFieldTypeImpl<DecimalField<Decimal64>> { using Type = DecimalField<Decimal64>; };
template <> struct NearestFieldTypeImpl<DecimalField<Decimal128>> { using Type = DecimalField<Decimal128>; };
template <> struct NearestFieldTypeImpl<Float32> { using Type = Float64; };
template <> struct NearestFieldTypeImpl<Float64> { using Type = Float64; };
template <> struct NearestFieldTypeImpl<const char *> { using Type = String; };
template <> struct NearestFieldTypeImpl<String> { using Type = String; };
template <> struct NearestFieldTypeImpl<Array> { using Type = Array; };
template <> struct NearestFieldTypeImpl<Tuple> { using Type = Tuple; };
template <> struct NearestFieldTypeImpl<bool> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<Null> { using Type = Null; };
template <> struct NearestFieldTypeImpl<AggregateFunctionStateData> { using Type = AggregateFunctionStateData; };
// For enum types, use the field type that corresponds to their underlying type.
template <typename T>
struct NearestFieldTypeImpl<T, std::enable_if_t<std::is_enum_v<T>>>
{
using Type = NearestFieldType<std::underlying_type_t<T>>;
};
/** 32 is enough. Round number is used for alignment and for better arithmetic inside std::vector.
* NOTE: Actually, sizeof(std::string) is 32 when using libc++, so Field is 40 bytes.
*/
@ -314,18 +362,24 @@ public:
bool isNull() const { return which == Types::Null; }
template <typename T> T & get()
template <typename T>
T & get();
template <typename T>
const T & get() const
{
using TWithoutRef = std::remove_reference_t<T>;
TWithoutRef * MAY_ALIAS ptr = reinterpret_cast<TWithoutRef*>(&storage);
return *ptr;
auto mutable_this = const_cast<std::decay_t<decltype(*this)> *>(this);
return mutable_this->get<T>();
}
template <typename T> const T & get() const
template <typename T>
T & reinterpret();
template <typename T>
const T & reinterpret() const
{
using TWithoutRef = std::remove_reference_t<T>;
const TWithoutRef * MAY_ALIAS ptr = reinterpret_cast<const TWithoutRef*>(&storage);
return *ptr;
auto mutable_this = const_cast<std::decay_t<decltype(*this)> *>(this);
return mutable_this->reinterpret<T>();
}
template <typename T> bool tryGet(T & result)
@ -427,6 +481,8 @@ public:
return rhs <= *this;
}
// More like bitwise equality as opposed to semantic equality:
// Null equals Null and NaN equals NaN.
bool operator== (const Field & rhs) const
{
if (which != rhs.which)
@ -435,9 +491,13 @@ public:
switch (which)
{
case Types::Null: return true;
case Types::UInt64:
case Types::Int64:
case Types::Float64: return get<UInt64>() == rhs.get<UInt64>();
case Types::UInt64: return get<UInt64>() == rhs.get<UInt64>();
case Types::Int64: return get<Int64>() == rhs.get<Int64>();
case Types::Float64:
{
// Compare as UInt64 so that NaNs compare as equal.
return reinterpret<UInt64>() == rhs.reinterpret<UInt64>();
}
case Types::String: return get<String>() == rhs.get<String>();
case Types::Array: return get<Array>() == rhs.get<Array>();
case Types::Tuple: return get<Tuple>() == rhs.get<Tuple>();
@ -457,6 +517,42 @@ public:
return !(*this == rhs);
}
/// Field is template parameter, to allow universal reference for field,
/// that is useful for const and non-const .
template <typename F, typename FieldRef>
static auto dispatch(F && f, FieldRef && field)
{
switch (field.which)
{
case Types::Null: return f(field.template get<Null>());
case Types::UInt64: return f(field.template get<UInt64>());
case Types::UInt128: return f(field.template get<UInt128>());
case Types::Int64: return f(field.template get<Int64>());
case Types::Float64: return f(field.template get<Float64>());
case Types::String: return f(field.template get<String>());
case Types::Array: return f(field.template get<Array>());
case Types::Tuple: return f(field.template get<Tuple>());
case Types::Decimal32: return f(field.template get<DecimalField<Decimal32>>());
case Types::Decimal64: return f(field.template get<DecimalField<Decimal64>>());
case Types::Decimal128: return f(field.template get<DecimalField<Decimal128>>());
case Types::AggregateFunctionState: return f(field.template get<AggregateFunctionStateData>());
case Types::Int128:
// TODO: investigate where we need Int128 Fields. There are no
// field visitors that support them, and they only arise indirectly
// in some functions that use Decimal columns: they get the
// underlying Field value with get<Int128>(). Probably should be
// switched to DecimalField, but this is a whole endeavor in itself.
throw Exception("Unexpected Int128 in Field::dispatch()", ErrorCodes::LOGICAL_ERROR);
}
// GCC 9 complains that control reaches the end, despite that we handle
// all the cases above (maybe because of throw?). Return something to
// silence it.
Null null{};
return f(null);
}
private:
std::aligned_union_t<DBMS_MIN_FIELD_SIZE - sizeof(Types::Which),
Null, UInt64, UInt128, Int64, Int128, Float64, String, Array, Tuple,
@ -493,37 +589,6 @@ private:
}
template <typename F, typename Field> /// Field template parameter may be const or non-const Field.
static void dispatch(F && f, Field & field)
{
switch (field.which)
{
case Types::Null: f(field.template get<Null>()); return;
// gcc 7.3.0
#if !__clang__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
case Types::UInt64: f(field.template get<UInt64>()); return;
case Types::UInt128: f(field.template get<UInt128>()); return;
case Types::Int64: f(field.template get<Int64>()); return;
case Types::Int128: f(field.template get<Int128>()); return;
case Types::Float64: f(field.template get<Float64>()); return;
#if !__clang__
#pragma GCC diagnostic pop
#endif
case Types::String: f(field.template get<String>()); return;
case Types::Array: f(field.template get<Array>()); return;
case Types::Tuple: f(field.template get<Tuple>()); return;
case Types::Decimal32: f(field.template get<DecimalField<Decimal32>>()); return;
case Types::Decimal64: f(field.template get<DecimalField<Decimal64>>()); return;
case Types::Decimal128: f(field.template get<DecimalField<Decimal128>>()); return;
case Types::AggregateFunctionState: f(field.template get<AggregateFunctionStateData>()); return;
}
}
void create(const Field & x)
{
dispatch([this] (auto & value) { createConcrete(value); }, x);
@ -621,6 +686,22 @@ template <> struct Field::EnumToType<Field::Types::Decimal64> { using Type = Dec
template <> struct Field::EnumToType<Field::Types::Decimal128> { using Type = DecimalField<Decimal128>; };
template <> struct Field::EnumToType<Field::Types::AggregateFunctionState> { using Type = DecimalField<AggregateFunctionStateData>; };
template <typename T>
T & Field::get()
{
using ValueType = std::decay_t<T>;
//assert(TypeToEnum<NearestFieldType<ValueType>>::value == which);
ValueType * MAY_ALIAS ptr = reinterpret_cast<ValueType *>(&storage);
return *ptr;
}
template <typename T>
T & Field::reinterpret()
{
using ValueType = std::decay_t<T>;
ValueType * MAY_ALIAS ptr = reinterpret_cast<ValueType *>(&storage);
return *ptr;
}
template <typename T>
T get(const Field & field)
@ -651,49 +732,6 @@ template <> struct TypeName<Array> { static std::string get() { return "Array";
template <> struct TypeName<Tuple> { static std::string get() { return "Tuple"; } };
template <> struct TypeName<AggregateFunctionStateData> { static std::string get() { return "AggregateFunctionState"; } };
/// char may be signed or unsigned, and behave identically to signed char or unsigned char,
/// but they are always three different types.
/// signedness of char is different in Linux on x86 and Linux on ARM.
template <> struct NearestFieldTypeImpl<char> { using Type = std::conditional_t<is_signed_v<char>, Int64, UInt64>; };
template <> struct NearestFieldTypeImpl<signed char> { using Type = Int64; };
template <> struct NearestFieldTypeImpl<unsigned char> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<UInt16> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<UInt32> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<DayNum> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<UInt128> { using Type = UInt128; };
template <> struct NearestFieldTypeImpl<UUID> { using Type = UInt128; };
template <> struct NearestFieldTypeImpl<Int16> { using Type = Int64; };
template <> struct NearestFieldTypeImpl<Int32> { using Type = Int64; };
/// long and long long are always different types that may behave identically or not.
/// This is different on Linux and Mac.
template <> struct NearestFieldTypeImpl<long> { using Type = Int64; };
template <> struct NearestFieldTypeImpl<long long> { using Type = Int64; };
template <> struct NearestFieldTypeImpl<unsigned long> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<unsigned long long> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<Int128> { using Type = Int128; };
template <> struct NearestFieldTypeImpl<Decimal32> { using Type = DecimalField<Decimal32>; };
template <> struct NearestFieldTypeImpl<Decimal64> { using Type = DecimalField<Decimal64>; };
template <> struct NearestFieldTypeImpl<Decimal128> { using Type = DecimalField<Decimal128>; };
template <> struct NearestFieldTypeImpl<DecimalField<Decimal32>> { using Type = DecimalField<Decimal32>; };
template <> struct NearestFieldTypeImpl<DecimalField<Decimal64>> { using Type = DecimalField<Decimal64>; };
template <> struct NearestFieldTypeImpl<DecimalField<Decimal128>> { using Type = DecimalField<Decimal128>; };
template <> struct NearestFieldTypeImpl<Float32> { using Type = Float64; };
template <> struct NearestFieldTypeImpl<Float64> { using Type = Float64; };
template <> struct NearestFieldTypeImpl<const char *> { using Type = String; };
template <> struct NearestFieldTypeImpl<String> { using Type = String; };
template <> struct NearestFieldTypeImpl<Array> { using Type = Array; };
template <> struct NearestFieldTypeImpl<Tuple> { using Type = Tuple; };
template <> struct NearestFieldTypeImpl<bool> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<Null> { using Type = Null; };
template <> struct NearestFieldTypeImpl<AggregateFunctionStateData> { using Type = AggregateFunctionStateData; };
template <typename T>
decltype(auto) castToNearestFieldType(T && x)
{

View File

@ -100,4 +100,69 @@ size_t getLengthEncodedStringSize(const String & s)
return getLengthEncodedNumberSize(s.size()) + s.size();
}
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex type_index)
{
ColumnType column_type;
int flags = 0;
switch (type_index)
{
case TypeIndex::UInt8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::Int8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float32:
column_type = ColumnType::MYSQL_TYPE_FLOAT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float64:
column_type = ColumnType::MYSQL_TYPE_DOUBLE;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Date:
column_type = ColumnType::MYSQL_TYPE_DATE;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::DateTime:
column_type = ColumnType::MYSQL_TYPE_DATETIME;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::String:
case TypeIndex::FixedString:
column_type = ColumnType::MYSQL_TYPE_STRING;
break;
default:
column_type = ColumnType::MYSQL_TYPE_STRING;
break;
}
return ColumnDefinition(column_name, CharacterSet::binary, 0, column_type, flags, 0);
}
}

View File

@ -130,6 +130,14 @@ enum ColumnType
};
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html
enum ColumnDefinitionFlags
{
UNSIGNED_FLAG = 32,
BINARY_FLAG = 128
};
class ProtocolError : public DB::Exception
{
public:
@ -824,19 +832,40 @@ protected:
}
};
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex index);
namespace ProtocolText
{
class ResultsetRow : public WritePacket
{
std::vector<String> columns;
const Columns & columns;
int row_num;
size_t payload_size = 0;
std::vector<String> serialized;
public:
ResultsetRow() = default;
void appendColumn(String && value)
ResultsetRow(const DataTypes & data_types, const Columns & columns_, int row_num_)
: columns(columns_)
, row_num(row_num_)
{
payload_size += getLengthEncodedStringSize(value);
columns.emplace_back(std::move(value));
for (size_t i = 0; i < columns.size(); i++)
{
if (columns[i]->isNullAt(row_num))
{
payload_size += 1;
serialized.emplace_back("\xfb");
}
else
{
WriteBufferFromOwnString ostr;
data_types[i]->serializeAsText(*columns[i], row_num, ostr, FormatSettings());
payload_size += getLengthEncodedStringSize(ostr.str());
serialized.push_back(std::move(ostr.str()));
}
}
}
protected:
size_t getPayloadSize() const override
{
@ -845,11 +874,18 @@ protected:
void writePayloadImpl(WriteBuffer & buffer) const override
{
for (const String & column : columns)
writeLengthEncodedString(column, buffer);
for (size_t i = 0; i < columns.size(); i++)
{
if (columns[i]->isNullAt(row_num))
buffer.write(serialized[i].data(), 1);
else
writeLengthEncodedString(serialized[i], buffer);
}
}
};
}
namespace Authentication
{
@ -917,10 +953,7 @@ public:
auto user = context.getUser(user_name);
if (user->authentication.getType() != DB::Authentication::DOUBLE_SHA1_PASSWORD)
throw Exception("Cannot use " + getName() + " auth plugin for user " + user_name + " since its password isn't specified using double SHA1.", ErrorCodes::UNKNOWN_EXCEPTION);
Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordHashBinary();
Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordDoubleSHA1();
assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE);
Poco::SHA1Engine engine;

View File

@ -62,7 +62,7 @@ void SettingNumber<Type>::set(const Field & x)
template <typename Type>
void SettingNumber<Type>::set(const String & x)
{
set(parse<Type>(x));
set(completeParse<Type>(x));
}
template <>

View File

@ -5,6 +5,9 @@
namespace DB
{
using TypeListNumbers = TypeList<UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64>;
using TypeListNativeNumbers = TypeList<UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64>;
using TypeListDecimalNumbers = TypeList<Decimal32, Decimal64, Decimal128>;
using TypeListNumbers = TypeList<UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64,
Decimal32, Decimal64, Decimal128>;
}

View File

@ -4,6 +4,7 @@
#include <string>
#include <vector>
#include <common/Types.h>
#include <Common/intExp.h>
namespace DB
@ -145,6 +146,8 @@ struct Decimal
const Decimal<T> & operator /= (const T & x) { value /= x; return *this; }
const Decimal<T> & operator %= (const T & x) { value %= x; return *this; }
static T getScaleMultiplier(UInt32 scale);
T value;
};
@ -170,6 +173,10 @@ template <> struct NativeType<Decimal32> { using Type = Int32; };
template <> struct NativeType<Decimal64> { using Type = Int64; };
template <> struct NativeType<Decimal128> { using Type = Int128; };
template <> inline Int32 Decimal32::getScaleMultiplier(UInt32 scale) { return common::exp10_i32(scale); }
template <> inline Int64 Decimal64::getScaleMultiplier(UInt32 scale) { return common::exp10_i64(scale); }
template <> inline Int128 Decimal128::getScaleMultiplier(UInt32 scale) { return common::exp10_i128(scale); }
inline const char * getTypeName(TypeIndex idx)
{
switch (idx)

View File

@ -52,7 +52,7 @@ struct Less
{
for (auto it = left_columns.begin(), jt = right_columns.begin(); it != left_columns.end(); ++it, ++jt)
{
int res = it->second.direction * it->first->compareAt(a, b, *jt->first, it->second.nulls_direction);
int res = it->description.direction * it->column->compareAt(a, b, *jt->column, it->description.nulls_direction);
if (res < 0)
return true;
else if (res > 0)

View File

@ -2,7 +2,7 @@
#include <Core/Field.h>
#include <Interpreters/ProcessList.h>
#include <Interpreters/Quota.h>
#include <Access/QuotaContext.h>
#include <Common/CurrentThread.h>
#include <common/sleep.h>
@ -70,7 +70,7 @@ Block IBlockInputStream::read()
if (limits.mode == LIMITS_CURRENT && !limits.size_limits.check(info.rows, info.bytes, "result", ErrorCodes::TOO_MANY_ROWS_OR_BYTES))
limit_exceeded_need_break = true;
if (quota != nullptr)
if (quota)
checkQuota(res);
}
else
@ -240,12 +240,8 @@ void IBlockInputStream::checkQuota(Block & block)
case LIMITS_CURRENT:
{
time_t current_time = time(nullptr);
double total_elapsed = info.total_stopwatch.elapsedSeconds();
quota->checkAndAddResultRowsBytes(current_time, block.rows(), block.bytes());
quota->checkAndAddExecutionTime(current_time, Poco::Timespan((total_elapsed - prev_elapsed) * 1000000.0));
UInt64 total_elapsed = info.total_stopwatch.elapsedNanoseconds();
quota->used({Quota::RESULT_ROWS, block.rows()}, {Quota::RESULT_BYTES, block.bytes()}, {Quota::EXECUTION_TIME, total_elapsed - prev_elapsed});
prev_elapsed = total_elapsed;
break;
}
@ -291,10 +287,8 @@ void IBlockInputStream::progressImpl(const Progress & value)
limits.speed_limits.throttle(progress.read_rows, progress.read_bytes, total_rows, total_elapsed_microseconds);
if (quota != nullptr && limits.mode == LIMITS_TOTAL)
{
quota->checkAndAddReadRowsBytes(time(nullptr), value.read_rows, value.read_bytes);
}
if (quota && limits.mode == LIMITS_TOTAL)
quota->used({Quota::READ_ROWS, value.read_rows}, {Quota::READ_BYTES, value.read_bytes});
}
}

View File

@ -23,7 +23,7 @@ namespace ErrorCodes
}
class ProcessListElement;
class QuotaForIntervals;
class QuotaContext;
class QueryStatus;
struct SortColumnDescription;
using SortDescription = std::vector<SortColumnDescription>;
@ -220,9 +220,9 @@ public:
/** Set the quota. If you set a quota on the amount of raw data,
* then you should also set mode = LIMITS_TOTAL to LocalLimits with setLimits.
*/
virtual void setQuota(QuotaForIntervals & quota_)
virtual void setQuota(const std::shared_ptr<QuotaContext> & quota_)
{
quota = &quota_;
quota = quota_;
}
/// Enable calculation of minimums and maximums by the result columns.
@ -263,6 +263,11 @@ protected:
*/
bool checkTimeLimit();
#ifndef NDEBUG
bool read_prefix_is_called = false;
bool read_suffix_is_called = false;
#endif
private:
bool enabled_extremes = false;
@ -273,8 +278,8 @@ private:
LocalLimits limits;
QuotaForIntervals * quota = nullptr; /// If nullptr - the quota is not used.
double prev_elapsed = 0;
std::shared_ptr<QuotaContext> quota; /// If nullptr - the quota is not used.
UInt64 prev_elapsed = 0;
/// The approximate total number of rows to read. For progress bar.
size_t total_rows_approx = 0;
@ -315,10 +320,6 @@ private:
return;
}
#ifndef NDEBUG
bool read_prefix_is_called = false;
bool read_suffix_is_called = false;
#endif
};
}

View File

@ -57,6 +57,20 @@ NativeBlockInputStream::NativeBlockInputStream(ReadBuffer & istr_, UInt64 server
}
}
// also resets few vars from IBlockInputStream (I didn't want to propagate resetParser upthere)
void NativeBlockInputStream::resetParser()
{
istr_concrete = nullptr;
use_index = false;
#ifndef NDEBUG
read_prefix_is_called = false;
read_suffix_is_called = false;
#endif
is_cancelled.store(false);
is_killed.store(false);
}
void NativeBlockInputStream::readData(const IDataType & type, IColumn & column, ReadBuffer & istr, size_t rows, double avg_value_size_hint)
{

View File

@ -78,6 +78,9 @@ public:
Block getHeader() const override;
void resetParser();
protected:
Block readImpl() override;

View File

@ -1,5 +1,4 @@
#include <DataStreams/ParallelParsingBlockInputStream.h>
#include "ParallelParsingBlockInputStream.h"
namespace DB
{
@ -15,7 +14,7 @@ void ParallelParsingBlockInputStream::segmentatorThreadFunction()
auto & unit = processing_units[current_unit_number];
{
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
segmentator_condvar.wait(lock,
[&]{ return unit.status == READY_TO_INSERT || finished; });
}
@ -85,7 +84,7 @@ void ParallelParsingBlockInputStream::parserThreadFunction(size_t current_unit_n
// except at the end of file. Also see a matching assert in readImpl().
assert(unit.is_last || unit.block_ext.block.size() > 0);
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
unit.status = READY_TO_READ;
reader_condvar.notify_all();
}
@ -99,7 +98,7 @@ void ParallelParsingBlockInputStream::onBackgroundException()
{
tryLogCurrentException(__PRETTY_FUNCTION__);
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
if (!background_exception)
{
background_exception = std::current_exception();
@ -116,7 +115,7 @@ Block ParallelParsingBlockInputStream::readImpl()
/**
* Check for background exception and rethrow it before we return.
*/
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
if (background_exception)
{
lock.unlock();
@ -134,7 +133,7 @@ Block ParallelParsingBlockInputStream::readImpl()
{
// We have read out all the Blocks from the previous Processing Unit,
// wait for the current one to become ready.
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
reader_condvar.wait(lock, [&](){ return unit.status == READY_TO_READ || finished; });
if (finished)
@ -190,7 +189,7 @@ Block ParallelParsingBlockInputStream::readImpl()
else
{
// Pass the unit back to the segmentator.
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
unit.status = READY_TO_INSERT;
segmentator_condvar.notify_all();
}

View File

@ -227,7 +227,7 @@ private:
finished = true;
{
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
segmentator_condvar.notify_all();
reader_condvar.notify_all();
}
@ -255,4 +255,4 @@ private:
void onBackgroundException();
};
};
}

View File

@ -78,7 +78,9 @@ SummingSortedBlockInputStream::SummingSortedBlockInputStream(
else
{
bool is_agg_func = WhichDataType(column.type).isAggregateFunction();
if (!column.type->isSummable() && !is_agg_func)
/// There are special const columns for example after prewere sections.
if ((!column.type->isSummable() && !is_agg_func) || isColumnConst(*column.column))
{
column_numbers_not_to_aggregate.push_back(i);
continue;
@ -198,6 +200,10 @@ SummingSortedBlockInputStream::SummingSortedBlockInputStream(
void SummingSortedBlockInputStream::insertCurrentRowIfNeeded(MutableColumns & merged_columns)
{
/// We have nothing to aggregate. It means that it could be non-zero, because we have columns_not_to_aggregate.
if (columns_to_aggregate.empty())
current_row_is_zero = false;
for (auto & desc : columns_to_aggregate)
{
// Do not insert if the aggregation state hasn't been created

View File

@ -13,14 +13,14 @@ bool DataTypeInterval::equals(const IDataType & rhs) const
void registerDataTypeInterval(DataTypeFactory & factory)
{
factory.registerSimpleDataType("IntervalSecond", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Second)); });
factory.registerSimpleDataType("IntervalMinute", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Minute)); });
factory.registerSimpleDataType("IntervalHour", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Hour)); });
factory.registerSimpleDataType("IntervalDay", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Day)); });
factory.registerSimpleDataType("IntervalWeek", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Week)); });
factory.registerSimpleDataType("IntervalMonth", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Month)); });
factory.registerSimpleDataType("IntervalQuarter", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Quarter)); });
factory.registerSimpleDataType("IntervalYear", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Year)); });
factory.registerSimpleDataType("IntervalSecond", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Second)); });
factory.registerSimpleDataType("IntervalMinute", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Minute)); });
factory.registerSimpleDataType("IntervalHour", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Hour)); });
factory.registerSimpleDataType("IntervalDay", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Day)); });
factory.registerSimpleDataType("IntervalWeek", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Week)); });
factory.registerSimpleDataType("IntervalMonth", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Month)); });
factory.registerSimpleDataType("IntervalQuarter", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Quarter)); });
factory.registerSimpleDataType("IntervalYear", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Year)); });
}
}

View File

@ -1,6 +1,7 @@
#pragma once
#include <DataTypes/DataTypeNumberBase.h>
#include <Common/IntervalKind.h>
namespace DB
@ -16,47 +17,17 @@ namespace DB
*/
class DataTypeInterval final : public DataTypeNumberBase<Int64>
{
public:
enum Kind
{
Second,
Minute,
Hour,
Day,
Week,
Month,
Quarter,
Year
};
private:
Kind kind;
IntervalKind kind;
public:
static constexpr bool is_parametric = true;
Kind getKind() const { return kind; }
IntervalKind getKind() const { return kind; }
const char * kindToString() const
{
switch (kind)
{
case Second: return "Second";
case Minute: return "Minute";
case Hour: return "Hour";
case Day: return "Day";
case Week: return "Week";
case Month: return "Month";
case Quarter: return "Quarter";
case Year: return "Year";
}
DataTypeInterval(IntervalKind kind_) : kind(kind_) {}
__builtin_unreachable();
}
DataTypeInterval(Kind kind_) : kind(kind_) {}
std::string doGetName() const override { return std::string("Interval") + kindToString(); }
std::string doGetName() const override { return std::string("Interval") + kind.toString(); }
const char * getFamilyName() const override { return "Interval"; }
TypeIndex getTypeId() const override { return TypeIndex::Interval; }

View File

@ -894,7 +894,7 @@ MutableColumnUniquePtr DataTypeLowCardinality::createColumnUniqueImpl(const IDat
if (isColumnedAsNumber(type))
{
MutableColumnUniquePtr column;
TypeListNumbers::forEach(CreateColumnVector(column, *type, creator));
TypeListNativeNumbers::forEach(CreateColumnVector(column, *type, creator));
if (!column)
throw Exception("Unexpected numeric type: " + type->getName(), ErrorCodes::LOGICAL_ERROR);

View File

@ -58,7 +58,7 @@ bool DataTypeDecimal<T>::tryReadText(T & x, ReadBuffer & istr, UInt32 precision,
{
UInt32 unread_scale = scale;
bool done = tryReadDecimalText(istr, x, precision, unread_scale);
x *= getScaleMultiplier(unread_scale);
x *= T::getScaleMultiplier(unread_scale);
return done;
}
@ -70,7 +70,7 @@ void DataTypeDecimal<T>::readText(T & x, ReadBuffer & istr, UInt32 precision, UI
readCSVDecimalText(istr, x, precision, unread_scale);
else
readDecimalText(istr, x, precision, unread_scale);
x *= getScaleMultiplier(unread_scale);
x *= T::getScaleMultiplier(unread_scale);
}
template <typename T>
@ -96,7 +96,7 @@ T DataTypeDecimal<T>::parseFromString(const String & str) const
T x;
UInt32 unread_scale = scale;
readDecimalText(buf, x, precision, unread_scale, true);
x *= getScaleMultiplier(unread_scale);
x *= T::getScaleMultiplier(unread_scale);
return x;
}
@ -271,25 +271,6 @@ void registerDataTypeDecimal(DataTypeFactory & factory)
}
template <>
Decimal32 DataTypeDecimal<Decimal32>::getScaleMultiplier(UInt32 scale_)
{
return decimalScaleMultiplier<Int32>(scale_);
}
template <>
Decimal64 DataTypeDecimal<Decimal64>::getScaleMultiplier(UInt32 scale_)
{
return decimalScaleMultiplier<Int64>(scale_);
}
template <>
Decimal128 DataTypeDecimal<Decimal128>::getScaleMultiplier(UInt32 scale_)
{
return decimalScaleMultiplier<Int128>(scale_);
}
/// Explicit template instantiations.
template class DataTypeDecimal<Decimal32>;
template class DataTypeDecimal<Decimal64>;

View File

@ -130,7 +130,7 @@ public:
UInt32 getPrecision() const { return precision; }
UInt32 getScale() const { return scale; }
T getScaleMultiplier() const { return getScaleMultiplier(scale); }
T getScaleMultiplier() const { return T::getScaleMultiplier(scale); }
T wholePart(T x) const
{
@ -148,7 +148,7 @@ public:
return x % getScaleMultiplier();
}
T maxWholeValue() const { return getScaleMultiplier(maxPrecision() - scale) - T(1); }
T maxWholeValue() const { return T::getScaleMultiplier(maxPrecision() - scale) - T(1); }
bool canStoreWhole(T x) const
{
@ -165,7 +165,7 @@ public:
if (getScale() < x.getScale())
throw Exception("Decimal result's scale is less then argiment's one", ErrorCodes::ARGUMENT_OUT_OF_BOUND);
UInt32 scale_delta = getScale() - x.getScale(); /// scale_delta >= 0
return getScaleMultiplier(scale_delta);
return T::getScaleMultiplier(scale_delta);
}
template <typename U>
@ -181,7 +181,6 @@ public:
void readText(T & x, ReadBuffer & istr, bool csv = false) const { readText(x, istr, precision, scale, csv); }
static void readText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale, bool csv = false);
static bool tryReadText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale);
static T getScaleMultiplier(UInt32 scale);
private:
const UInt32 precision;
@ -264,12 +263,12 @@ convertDecimals(const typename FromDataType::FieldType & value, UInt32 scale_fro
MaxNativeType converted_value;
if (scale_to > scale_from)
{
converted_value = DataTypeDecimal<MaxFieldType>::getScaleMultiplier(scale_to - scale_from);
converted_value = MaxFieldType::getScaleMultiplier(scale_to - scale_from);
if (common::mulOverflow(static_cast<MaxNativeType>(value), converted_value, converted_value))
throw Exception("Decimal convert overflow", ErrorCodes::DECIMAL_OVERFLOW);
}
else
converted_value = value / DataTypeDecimal<MaxFieldType>::getScaleMultiplier(scale_from - scale_to);
converted_value = value / MaxFieldType::getScaleMultiplier(scale_from - scale_to);
if constexpr (sizeof(FromFieldType) > sizeof(ToFieldType))
{
@ -289,7 +288,7 @@ convertFromDecimal(const typename FromDataType::FieldType & value, UInt32 scale)
using ToFieldType = typename ToDataType::FieldType;
if constexpr (std::is_floating_point_v<ToFieldType>)
return static_cast<ToFieldType>(value) / FromDataType::getScaleMultiplier(scale);
return static_cast<ToFieldType>(value) / FromFieldType::getScaleMultiplier(scale);
else
{
FromFieldType converted_value = convertDecimals<FromDataType, FromDataType>(value, scale, 0);
@ -320,14 +319,15 @@ inline std::enable_if_t<IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDa
convertToDecimal(const typename FromDataType::FieldType & value, UInt32 scale)
{
using FromFieldType = typename FromDataType::FieldType;
using ToNativeType = typename ToDataType::FieldType::NativeType;
using ToFieldType = typename ToDataType::FieldType;
using ToNativeType = typename ToFieldType::NativeType;
if constexpr (std::is_floating_point_v<FromFieldType>)
{
if (!std::isfinite(value))
throw Exception("Decimal convert overflow. Cannot convert infinity or NaN to decimal", ErrorCodes::DECIMAL_OVERFLOW);
auto out = value * ToDataType::getScaleMultiplier(scale);
auto out = value * ToFieldType::getScaleMultiplier(scale);
if constexpr (std::is_same_v<ToNativeType, Int128>)
{
static constexpr __int128 min_int128 = __int128(0x8000000000000000ll) << 64;

View File

@ -48,7 +48,7 @@ public:
double getLoadFactor() const override { return static_cast<double>(element_count.load(std::memory_order_relaxed)) / size; }
bool isCached() const override { return true; }
bool supportUpdates() const override { return false; }
std::shared_ptr<const IExternalLoadable> clone() const override
{

View File

@ -3,8 +3,8 @@
#include <Columns/ColumnsNumber.h>
#include <Common/ProfilingScopedRWLock.h>
#include <Common/typeid_cast.h>
#include <common/DateLUT.h>
#include <DataStreams/IBlockInputStream.h>
#include <ext/chrono_io.h>
#include <ext/map.h>
#include <ext/range.h>
#include <ext/size.h>
@ -334,7 +334,7 @@ void CacheDictionary::update(
backoff_end_time = now + std::chrono::seconds(calculateDurationWithBackoff(rnd_engine, error_count));
tryLogException(last_exception, log, "Could not update cache dictionary '" + getName() +
"', next update is scheduled at " + DateLUT::instance().timeToString(std::chrono::system_clock::to_time_t(backoff_end_time)));
"', next update is scheduled at " + ext::to_string(backoff_end_time));
}
}

View File

@ -71,7 +71,7 @@ public:
double getLoadFactor() const override { return static_cast<double>(element_count.load(std::memory_order_relaxed)) / size; }
bool isCached() const override { return true; }
bool supportUpdates() const override { return false; }
std::shared_ptr<const IExternalLoadable> clone() const override
{

View File

@ -46,8 +46,6 @@ public:
double getLoadFactor() const override { return static_cast<double>(element_count) / bucket_count; }
bool isCached() const override { return false; }
std::shared_ptr<const IExternalLoadable> clone() const override
{
return std::make_shared<ComplexKeyHashedDictionary>(name, dict_struct, source_ptr->clone(), dict_lifetime, require_nonempty, saved_block);

View File

@ -43,8 +43,6 @@ public:
double getLoadFactor() const override { return static_cast<double>(element_count) / bucket_count; }
bool isCached() const override { return false; }
std::shared_ptr<const IExternalLoadable> clone() const override
{
return std::make_shared<FlatDictionary>(name, dict_struct, source_ptr->clone(), dict_lifetime, require_nonempty, saved_block);

View File

@ -5,6 +5,8 @@
#include <IO/ConnectionTimeouts.h>
#include <IO/ReadWriteBufferFromHTTP.h>
#include <IO/WriteBufferFromOStream.h>
#include <IO/WriteBufferFromString.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context.h>
#include <Poco/Net/HTTPRequest.h>
#include <common/logger_useful.h>
@ -82,12 +84,9 @@ void HTTPDictionarySource::getUpdateFieldAndDate(Poco::URI & uri)
auto tmp_time = update_time;
update_time = std::chrono::system_clock::now();
time_t hr_time = std::chrono::system_clock::to_time_t(tmp_time) - 1;
char buffer[80];
struct tm * timeinfo;
timeinfo = localtime(&hr_time);
strftime(buffer, 80, "%Y-%m-%d %H:%M:%S", timeinfo);
std::string str_time(buffer);
uri.addQueryParameter(update_field, str_time);
WriteBufferFromOwnString out;
writeDateTimeText(hr_time, out);
uri.addQueryParameter(update_field, out.str());
}
else
{

View File

@ -48,8 +48,6 @@ public:
double getLoadFactor() const override { return static_cast<double>(element_count) / bucket_count; }
bool isCached() const override { return false; }
std::shared_ptr<const IExternalLoadable> clone() const override
{
return std::make_shared<HashedDictionary>(name, dict_struct, source_ptr->clone(), dict_lifetime, require_nonempty, sparse, saved_block);

View File

@ -37,8 +37,6 @@ struct IDictionaryBase : public IExternalLoadable
virtual double getLoadFactor() const = 0;
virtual bool isCached() const = 0;
virtual const IDictionarySource * getSource() const = 0;
virtual const DictionaryStructure & getStructure() const = 0;
@ -47,7 +45,7 @@ struct IDictionaryBase : public IExternalLoadable
virtual BlockInputStreamPtr getBlockInputStream(const Names & column_names, size_t max_block_size) const = 0;
bool supportUpdates() const override { return !isCached(); }
bool supportUpdates() const override { return true; }
bool isModified() const override
{

View File

@ -38,8 +38,6 @@ public:
double getLoadFactor() const override { return static_cast<double>(element_count) / bucket_count; }
bool isCached() const override { return false; }
std::shared_ptr<const IExternalLoadable> clone() const override
{
return std::make_shared<RangeHashedDictionary>(dictionary_name, dict_struct, source_ptr->clone(), dict_lifetime, require_nonempty);

View File

@ -47,8 +47,6 @@ public:
double getLoadFactor() const override { return static_cast<double>(element_count) / bucket_count; }
bool isCached() const override { return false; }
std::shared_ptr<const IExternalLoadable> clone() const override
{
return std::make_shared<TrieDictionary>(name, dict_struct, source_ptr->clone(), dict_lifetime, require_nonempty);

View File

@ -176,7 +176,7 @@ void buildSingleAttribute(
AutoPtr<Element> null_value_element(doc->createElement("null_value"));
String null_value_str;
if (dict_attr->default_value)
null_value_str = queryToString(dict_attr->default_value);
null_value_str = getUnescapedFieldString(dict_attr->default_value->as<ASTLiteral>()->value);
AutoPtr<Text> null_value(doc->createTextNode(null_value_str));
null_value_element->appendChild(null_value);
attribute_element->appendChild(null_value_element);
@ -184,7 +184,19 @@ void buildSingleAttribute(
if (dict_attr->expression != nullptr)
{
AutoPtr<Element> expression_element(doc->createElement("expression"));
AutoPtr<Text> expression(doc->createTextNode(queryToString(dict_attr->expression)));
/// EXPRESSION PROPERTY should be expression or string
String expression_str;
if (const auto * literal = dict_attr->expression->as<ASTLiteral>();
literal && literal->value.getType() == Field::Types::String)
{
expression_str = getUnescapedFieldString(literal->value);
}
else
expression_str = queryToString(dict_attr->expression);
AutoPtr<Text> expression(doc->createTextNode(expression_str));
expression_element->appendChild(expression);
attribute_element->appendChild(expression_element);
}

View File

@ -281,6 +281,8 @@ void registerInputFormatProcessorTSKV(FormatFactory & factory);
void registerOutputFormatProcessorTSKV(FormatFactory & factory);
void registerInputFormatProcessorJSONEachRow(FormatFactory & factory);
void registerOutputFormatProcessorJSONEachRow(FormatFactory & factory);
void registerInputFormatProcessorJSONCompactEachRow(FormatFactory & factory);
void registerOutputFormatProcessorJSONCompactEachRow(FormatFactory & factory);
void registerInputFormatProcessorParquet(FormatFactory & factory);
void registerInputFormatProcessorORC(FormatFactory & factory);
void registerOutputFormatProcessorParquet(FormatFactory & factory);
@ -336,6 +338,8 @@ FormatFactory::FormatFactory()
registerOutputFormatProcessorTSKV(*this);
registerInputFormatProcessorJSONEachRow(*this);
registerOutputFormatProcessorJSONEachRow(*this);
registerInputFormatProcessorJSONCompactEachRow(*this);
registerOutputFormatProcessorJSONCompactEachRow(*this);
registerInputFormatProcessorProtobuf(*this);
registerOutputFormatProcessorProtobuf(*this);
registerInputFormatProcessorCapnProto(*this);

View File

@ -508,7 +508,7 @@ class FunctionBinaryArithmetic : public IFunction
}
std::stringstream function_name;
function_name << (function_is_plus ? "add" : "subtract") << interval_data_type->kindToString() << 's';
function_name << (function_is_plus ? "add" : "subtract") << interval_data_type->getKind().toString() << 's';
return FunctionFactory::instance().get(function_name.str(), context);
}

View File

@ -735,7 +735,7 @@ struct NameToDecimal128 { static constexpr auto name = "toDecimal128"; };
struct NameToInterval ## INTERVAL_KIND \
{ \
static constexpr auto name = "toInterval" #INTERVAL_KIND; \
static constexpr int kind = DataTypeInterval::INTERVAL_KIND; \
static constexpr auto kind = IntervalKind::INTERVAL_KIND; \
};
DEFINE_NAME_TO_INTERVAL(Second)
@ -786,7 +786,7 @@ public:
if constexpr (std::is_same_v<ToDataType, DataTypeInterval>)
{
return std::make_shared<DataTypeInterval>(DataTypeInterval::Kind(Name::kind));
return std::make_shared<DataTypeInterval>(Name::kind);
}
else if constexpr (to_decimal)
{

View File

@ -707,6 +707,20 @@ private:
ErrorCodes::ILLEGAL_COLUMN);
}
template <bool first>
void executeGeneric(const IColumn * column, typename ColumnVector<ToType>::Container & vec_to)
{
for (size_t i = 0, size = column->size(); i < size; ++i)
{
StringRef bytes = column->getDataAt(i);
const ToType h = Impl::apply(bytes.data, bytes.size);
if (first)
vec_to[i] = h;
else
vec_to[i] = Impl::combineHashes(vec_to[i], h);
}
}
template <bool first>
void executeString(const IColumn * column, typename ColumnVector<ToType>::Container & vec_to)
{
@ -843,8 +857,7 @@ private:
else if (which.isFixedString()) executeString<first>(icolumn, vec_to);
else if (which.isArray()) executeArray<first>(from_type, icolumn, vec_to);
else
throw Exception("Unexpected type " + from_type->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
executeGeneric<first>(icolumn, vec_to);
}
void executeForArgument(const IDataType * type, const IColumn * column, typename ColumnVector<ToType>::Container & vec_to, bool & is_first)

View File

@ -20,6 +20,7 @@ void registerFunctionsJSON(FunctionFactory & factory)
factory.registerFunction<FunctionJSON<NameJSONExtract, JSONExtractImpl>>();
factory.registerFunction<FunctionJSON<NameJSONExtractKeysAndValues, JSONExtractKeysAndValuesImpl>>();
factory.registerFunction<FunctionJSON<NameJSONExtractRaw, JSONExtractRawImpl>>();
factory.registerFunction<FunctionJSON<NameJSONExtractArrayRaw, JSONExtractArrayRawImpl>>();
}
}

View File

@ -291,6 +291,7 @@ struct NameJSONExtractString { static constexpr auto name{"JSONExtractString"};
struct NameJSONExtract { static constexpr auto name{"JSONExtract"}; };
struct NameJSONExtractKeysAndValues { static constexpr auto name{"JSONExtractKeysAndValues"}; };
struct NameJSONExtractRaw { static constexpr auto name{"JSONExtractRaw"}; };
struct NameJSONExtractArrayRaw { static constexpr auto name{"JSONExtractArrayRaw"}; };
template <typename JSONParser>
@ -1088,4 +1089,39 @@ private:
}
};
template <typename JSONParser>
class JSONExtractArrayRawImpl
{
public:
static DataTypePtr getType(const char *, const ColumnsWithTypeAndName &)
{
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeString>());
}
using Iterator = typename JSONParser::Iterator;
static bool addValueToColumn(IColumn & dest, const Iterator & it)
{
if (!JSONParser::isArray(it))
{
return false;
}
ColumnArray & col_res = assert_cast<ColumnArray &>(dest);
Iterator array_it = it;
size_t size = 0;
if (JSONParser::firstArrayElement(array_it))
{
do
{
JSONExtractRawImpl<JSONParser>::addValueToColumn(col_res.getData(), array_it);
++size;
} while (JSONParser::nextArrayElement(array_it));
}
col_res.getOffsets().push_back(col_res.getOffsets().back() + size);
return true;
}
static constexpr size_t num_extra_arguments = 0;
static void prepare(const char *, const Block &, const ColumnNumbers &, size_t) {}
};
}

View File

@ -1,5 +1,6 @@
#pragma once
#include <Core/Types.h>
#include <Common/FieldVisitors.h>
#include "Sources.h"
#include "Sinks.h"
@ -79,8 +80,16 @@ inline ALWAYS_INLINE void writeSlice(const NumericArraySlice<T> & slice, Generic
{
for (size_t i = 0; i < slice.size; ++i)
{
Field field = T(slice.data[i]);
sink.elements.insert(field);
if constexpr (IsDecimalNumber<T>)
{
DecimalField field(T(slice.data[i]), 0); /// TODO: Decimal scale
sink.elements.insert(field);
}
else
{
Field field = T(slice.data[i]);
sink.elements.insert(field);
}
}
sink.current_offset += slice.size;
}
@ -422,9 +431,18 @@ bool sliceHasImpl(const FirstSliceType & first, const SecondSliceType & second,
}
template <typename T, typename U>
bool sliceEqualElements(const NumericArraySlice<T> & first, const NumericArraySlice<U> & second, size_t first_ind, size_t second_ind)
bool sliceEqualElements(const NumericArraySlice<T> & first [[maybe_unused]],
const NumericArraySlice<U> & second [[maybe_unused]],
size_t first_ind [[maybe_unused]],
size_t second_ind [[maybe_unused]])
{
return accurate::equalsOp(first.data[first_ind], second.data[second_ind]);
/// TODO: Decimal scale
if constexpr (IsDecimalNumber<T> && IsDecimalNumber<U>)
return accurate::equalsOp(typename T::NativeType(first.data[first_ind]), typename U::NativeType(second.data[second_ind]));
else if constexpr (IsDecimalNumber<T> || IsDecimalNumber<U>)
return false;
else
return accurate::equalsOp(first.data[first_ind], second.data[second_ind]);
}
template <typename T>

Some files were not shown because too many files have changed in this diff Show More