Merge branch 'master' of github.com:clickhouse/ClickHouse

This commit is contained in:
Ivan Blinkov 2019-12-06 16:28:34 +03:00
commit d5a0833595
259 changed files with 9126 additions and 1673 deletions

View File

@ -14,5 +14,4 @@ ClickHouse is an open-source column-oriented database management system that all
## Upcoming Events
* [ClickHouse Meetup in San Francisco](https://www.eventbrite.com/e/clickhouse-december-meetup-registration-78642047481) on December 3.
* [ClickHouse Meetup in Moscow](https://yandex.ru/promo/clickhouse/moscow-december-2019) on December 11.

View File

@ -432,6 +432,8 @@ if (USE_JEMALLOC)
if(NOT MAKE_STATIC_LIBRARIES AND ${JEMALLOC_LIBRARIES} MATCHES "${CMAKE_STATIC_LIBRARY_SUFFIX}$")
# mallctl in dbms/src/Interpreters/AsynchronousMetrics.cpp
# Actually we link JEMALLOC to almost all libraries.
# This is just hotfix for some uninvestigated problem.
target_link_libraries(clickhouse_interpreters PRIVATE ${JEMALLOC_LIBRARIES})
endif()
endif ()

View File

@ -30,6 +30,11 @@ if (Poco_Data_FOUND)
set(CLICKHOUSE_ODBC_BRIDGE_LINK ${CLICKHOUSE_ODBC_BRIDGE_LINK} PRIVATE ${Poco_Data_LIBRARY})
set(CLICKHOUSE_ODBC_BRIDGE_INCLUDE ${CLICKHOUSE_ODBC_BRIDGE_INCLUDE} SYSTEM PRIVATE ${Poco_Data_INCLUDE_DIR})
endif ()
if (USE_JEMALLOC)
# We need to link jemalloc directly to odbc-bridge-library, because in other case
# we will build it with default malloc.
set(CLICKHOUSE_ODBC_BRIDGE_LINK ${CLICKHOUSE_ODBC_BRIDGE_LINK} PRIVATE ${JEMALLOC_LIBRARIES})
endif()
clickhouse_program_add_library(odbc-bridge)

View File

@ -34,7 +34,6 @@
#include <IO/WriteBufferFromTemporaryFile.h>
#include <DataStreams/IBlockInputStream.h>
#include <Interpreters/executeQuery.h>
#include <Interpreters/Quota.h>
#include <Common/typeid_cast.h>
#include <Poco/Net/HTTPStream.h>

View File

@ -243,6 +243,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
}
#endif
global_context->setRemoteHostFilter(config());
std::string path = getCanonicalPath(config().getString("path", DBMS_DEFAULT_PATH));
std::string default_database = config().getString("default_database", "default");

View File

@ -19,7 +19,6 @@
#include <DataStreams/NativeBlockInputStream.h>
#include <DataStreams/NativeBlockOutputStream.h>
#include <Interpreters/executeQuery.h>
#include <Interpreters/Quota.h>
#include <Interpreters/TablesStatus.h>
#include <Interpreters/InternalTextLogsQueue.h>
#include <Storages/StorageMemory.h>

View File

@ -3,6 +3,25 @@
NOTE: User and query level settings are set up in "users.xml" file.
-->
<yandex>
<!-- The list of hosts allowed to use in URL-related storage engines and table functions.
If this section is not present in configuration, all hosts are allowed.
-->
<remote_url_allow_hosts>
<!-- Host should be specified exactly as in URL. The name is checked before DNS resolution.
Example: "yandex.ru", "yandex.ru." and "www.yandex.ru" are different hosts.
If port is explicitly specified in URL, the host:port is checked as a whole.
If host specified here without port, any port with this host allowed.
"yandex.ru" -> "yandex.ru:443", "yandex.ru:80" etc. is allowed, but "yandex.ru:80" -> only "yandex.ru:80" is allowed.
If the host is specified as IP address, it is checked as specified in URL. Example: "[2a02:6b8:a::a]".
If there are redirects and support for redirects is enabled, every redirect (the Location field) is checked.
-->
<!-- Regular expression can be specified. RE2 engine is used for regexps.
Regexps are not aligned: don't forget to add ^ and $. Also don't forget to escape dot (.) metacharacter
(forgetting to do so is a common source of error).
-->
</remote_url_allow_hosts>
<logger>
<!-- Possible levels: https://github.com/pocoproject/poco/blob/develop/Foundation/include/Poco/Logger.h#L105 -->
<level>trace</level>
@ -15,7 +34,6 @@
<!--display_name>production</display_name--> <!-- It is the name that will be shown in the client -->
<http_port>8123</http_port>
<tcp_port>9000</tcp_port>
<!-- For HTTPS and SSL over native protocol. -->
<!--
<https_port>8443</https_port>

View File

@ -0,0 +1,52 @@
#include <Access/AccessControlManager.h>
#include <Access/MultipleAccessStorage.h>
#include <Access/MemoryAccessStorage.h>
#include <Access/UsersConfigAccessStorage.h>
#include <Access/QuotaContextFactory.h>
namespace DB
{
namespace
{
std::vector<std::unique_ptr<IAccessStorage>> createStorages()
{
std::vector<std::unique_ptr<IAccessStorage>> list;
list.emplace_back(std::make_unique<MemoryAccessStorage>());
list.emplace_back(std::make_unique<UsersConfigAccessStorage>());
return list;
}
}
AccessControlManager::AccessControlManager()
: MultipleAccessStorage(createStorages()),
quota_context_factory(std::make_unique<QuotaContextFactory>(*this))
{
}
AccessControlManager::~AccessControlManager()
{
}
void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguration & users_config)
{
auto & users_config_access_storage = dynamic_cast<UsersConfigAccessStorage &>(getStorageByIndex(1));
users_config_access_storage.loadFromConfig(users_config);
}
std::shared_ptr<QuotaContext> AccessControlManager::createQuotaContext(
const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key)
{
return quota_context_factory->createContext(user_name, address, custom_quota_key);
}
std::vector<QuotaUsageInfo> AccessControlManager::getQuotaUsageInfo() const
{
return quota_context_factory->getUsageInfo();
}
}

View File

@ -0,0 +1,45 @@
#pragma once
#include <Access/MultipleAccessStorage.h>
#include <Poco/AutoPtr.h>
#include <memory>
namespace Poco
{
namespace Net
{
class IPAddress;
}
namespace Util
{
class AbstractConfiguration;
}
}
namespace DB
{
class QuotaContext;
class QuotaContextFactory;
struct QuotaUsageInfo;
/// Manages access control entities.
class AccessControlManager : public MultipleAccessStorage
{
public:
AccessControlManager();
~AccessControlManager();
void loadFromConfig(const Poco::Util::AbstractConfiguration & users_config);
std::shared_ptr<QuotaContext>
createQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key);
std::vector<QuotaUsageInfo> getQuotaUsageInfo() const;
private:
std::unique_ptr<QuotaContextFactory> quota_context_factory;
};
}

View File

@ -0,0 +1,19 @@
#include <Access/IAccessEntity.h>
#include <Access/Quota.h>
#include <common/demangle.h>
namespace DB
{
String IAccessEntity::getTypeName(std::type_index type)
{
if (type == typeid(Quota))
return "Quota";
return demangle(type.name());
}
bool IAccessEntity::equal(const IAccessEntity & other) const
{
return (full_name == other.full_name) && (getType() == other.getType());
}
}

View File

@ -0,0 +1,49 @@
#pragma once
#include <Core/Types.h>
#include <Common/typeid_cast.h>
#include <memory>
#include <typeindex>
namespace DB
{
/// Access entity is a set of data which have a name and a type. Access entity control something related to the access control.
/// Entities can be stored to a file or another storage, see IAccessStorage.
struct IAccessEntity
{
IAccessEntity() = default;
IAccessEntity(const IAccessEntity &) = default;
virtual ~IAccessEntity() = default;
virtual std::shared_ptr<IAccessEntity> clone() const = 0;
std::type_index getType() const { return typeid(*this); }
static String getTypeName(std::type_index type);
const String getTypeName() const { return getTypeName(getType()); }
template <typename EntityType>
bool isTypeOf() const { return isTypeOf(typeid(EntityType)); }
bool isTypeOf(std::type_index type) const { return type == getType(); }
virtual void setName(const String & name_) { full_name = name_; }
virtual String getName() const { return full_name; }
String getFullName() const { return full_name; }
friend bool operator ==(const IAccessEntity & lhs, const IAccessEntity & rhs) { return lhs.equal(rhs); }
friend bool operator !=(const IAccessEntity & lhs, const IAccessEntity & rhs) { return !(lhs == rhs); }
protected:
String full_name;
virtual bool equal(const IAccessEntity & other) const;
/// Helper function to define clone() in the derived classes.
template <typename EntityType>
std::shared_ptr<IAccessEntity> cloneImpl() const
{
return std::make_shared<EntityType>(typeid_cast<const EntityType &>(*this));
}
};
using AccessEntityPtr = std::shared_ptr<const IAccessEntity>;
}

View File

@ -0,0 +1,450 @@
#include <Access/IAccessStorage.h>
#include <Common/Exception.h>
#include <Common/quoteString.h>
#include <IO/WriteHelpers.h>
#include <Poco/UUIDGenerator.h>
#include <Poco/Logger.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_CAST;
extern const int ACCESS_ENTITY_NOT_FOUND;
extern const int ACCESS_ENTITY_ALREADY_EXISTS;
extern const int ACCESS_ENTITY_FOUND_DUPLICATES;
extern const int ACCESS_ENTITY_STORAGE_READONLY;
}
std::vector<UUID> IAccessStorage::findAll(std::type_index type) const
{
return findAllImpl(type);
}
std::optional<UUID> IAccessStorage::find(std::type_index type, const String & name) const
{
return findImpl(type, name);
}
std::vector<UUID> IAccessStorage::find(std::type_index type, const Strings & names) const
{
std::vector<UUID> ids;
ids.reserve(names.size());
for (const String & name : names)
{
auto id = findImpl(type, name);
if (id)
ids.push_back(*id);
}
return ids;
}
UUID IAccessStorage::getID(std::type_index type, const String & name) const
{
auto id = findImpl(type, name);
if (id)
return *id;
throwNotFound(type, name);
}
std::vector<UUID> IAccessStorage::getIDs(std::type_index type, const Strings & names) const
{
std::vector<UUID> ids;
ids.reserve(names.size());
for (const String & name : names)
ids.push_back(getID(type, name));
return ids;
}
bool IAccessStorage::exists(const UUID & id) const
{
return existsImpl(id);
}
AccessEntityPtr IAccessStorage::tryReadBase(const UUID & id) const
{
try
{
return readImpl(id);
}
catch (Exception &)
{
return nullptr;
}
}
String IAccessStorage::readName(const UUID & id) const
{
return readNameImpl(id);
}
std::optional<String> IAccessStorage::tryReadName(const UUID & id) const
{
try
{
return readNameImpl(id);
}
catch (Exception &)
{
return {};
}
}
UUID IAccessStorage::insert(const AccessEntityPtr & entity)
{
return insertImpl(entity, false);
}
std::vector<UUID> IAccessStorage::insert(const std::vector<AccessEntityPtr> & multiple_entities)
{
std::vector<UUID> ids;
ids.reserve(multiple_entities.size());
String error_message;
for (const auto & entity : multiple_entities)
{
try
{
ids.push_back(insertImpl(entity, false));
}
catch (Exception & e)
{
if (e.code() != ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS)
throw;
error_message += (error_message.empty() ? "" : ". ") + e.message();
}
}
if (!error_message.empty())
throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
return ids;
}
std::optional<UUID> IAccessStorage::tryInsert(const AccessEntityPtr & entity)
{
try
{
return insertImpl(entity, false);
}
catch (Exception &)
{
return {};
}
}
std::vector<UUID> IAccessStorage::tryInsert(const std::vector<AccessEntityPtr> & multiple_entities)
{
std::vector<UUID> ids;
ids.reserve(multiple_entities.size());
for (const auto & entity : multiple_entities)
{
try
{
ids.push_back(insertImpl(entity, false));
}
catch (Exception &)
{
}
}
return ids;
}
UUID IAccessStorage::insertOrReplace(const AccessEntityPtr & entity)
{
return insertImpl(entity, true);
}
std::vector<UUID> IAccessStorage::insertOrReplace(const std::vector<AccessEntityPtr> & multiple_entities)
{
std::vector<UUID> ids;
ids.reserve(multiple_entities.size());
for (const auto & entity : multiple_entities)
ids.push_back(insertImpl(entity, true));
return ids;
}
void IAccessStorage::remove(const UUID & id)
{
removeImpl(id);
}
void IAccessStorage::remove(const std::vector<UUID> & ids)
{
String error_message;
for (const auto & id : ids)
{
try
{
removeImpl(id);
}
catch (Exception & e)
{
if (e.code() != ErrorCodes::ACCESS_ENTITY_NOT_FOUND)
throw;
error_message += (error_message.empty() ? "" : ". ") + e.message();
}
}
if (!error_message.empty())
throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
}
bool IAccessStorage::tryRemove(const UUID & id)
{
try
{
removeImpl(id);
return true;
}
catch (Exception &)
{
return false;
}
}
std::vector<UUID> IAccessStorage::tryRemove(const std::vector<UUID> & ids)
{
std::vector<UUID> removed;
removed.reserve(ids.size());
for (const auto & id : ids)
{
try
{
removeImpl(id);
removed.push_back(id);
}
catch (Exception &)
{
}
}
return removed;
}
void IAccessStorage::update(const UUID & id, const UpdateFunc & update_func)
{
updateImpl(id, update_func);
}
void IAccessStorage::update(const std::vector<UUID> & ids, const UpdateFunc & update_func)
{
String error_message;
for (const auto & id : ids)
{
try
{
updateImpl(id, update_func);
}
catch (Exception & e)
{
if (e.code() != ErrorCodes::ACCESS_ENTITY_NOT_FOUND)
throw;
error_message += (error_message.empty() ? "" : ". ") + e.message();
}
}
if (!error_message.empty())
throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
}
bool IAccessStorage::tryUpdate(const UUID & id, const UpdateFunc & update_func)
{
try
{
updateImpl(id, update_func);
return true;
}
catch (Exception &)
{
return false;
}
}
std::vector<UUID> IAccessStorage::tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func)
{
std::vector<UUID> updated;
updated.reserve(ids.size());
for (const auto & id : ids)
{
try
{
updateImpl(id, update_func);
updated.push_back(id);
}
catch (Exception &)
{
}
}
return updated;
}
IAccessStorage::SubscriptionPtr IAccessStorage::subscribeForChanges(std::type_index type, const OnChangedHandler & handler) const
{
return subscribeForChangesImpl(type, handler);
}
IAccessStorage::SubscriptionPtr IAccessStorage::subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const
{
return subscribeForChangesImpl(id, handler);
}
IAccessStorage::SubscriptionPtr IAccessStorage::subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const
{
if (ids.empty())
return nullptr;
if (ids.size() == 1)
return subscribeForChangesImpl(ids[0], handler);
std::vector<SubscriptionPtr> subscriptions;
subscriptions.reserve(ids.size());
for (const auto & id : ids)
{
auto subscription = subscribeForChangesImpl(id, handler);
if (subscription)
subscriptions.push_back(std::move(subscription));
}
class SubscriptionImpl : public Subscription
{
public:
SubscriptionImpl(std::vector<SubscriptionPtr> subscriptions_)
: subscriptions(std::move(subscriptions_)) {}
private:
std::vector<SubscriptionPtr> subscriptions;
};
return std::make_unique<SubscriptionImpl>(std::move(subscriptions));
}
bool IAccessStorage::hasSubscription(std::type_index type) const
{
return hasSubscriptionImpl(type);
}
bool IAccessStorage::hasSubscription(const UUID & id) const
{
return hasSubscriptionImpl(id);
}
void IAccessStorage::notify(const Notifications & notifications)
{
for (const auto & [fn, id, new_entity] : notifications)
fn(id, new_entity);
}
UUID IAccessStorage::generateRandomID()
{
static Poco::UUIDGenerator generator;
UUID id;
generator.createRandom().copyTo(reinterpret_cast<char *>(&id));
return id;
}
Poco::Logger * IAccessStorage::getLogger() const
{
Poco::Logger * ptr = log.load();
if (!ptr)
log.store(ptr = &Poco::Logger::get("Access(" + storage_name + ")"), std::memory_order_relaxed);
return ptr;
}
void IAccessStorage::throwNotFound(const UUID & id) const
{
throw Exception("ID {" + toString(id) + "} not found in " + getStorageName(), ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
}
void IAccessStorage::throwNotFound(std::type_index type, const String & name) const
{
throw Exception(
getTypeName(type) + " " + backQuote(name) + " not found in " + getStorageName(), ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
}
void IAccessStorage::throwBadCast(const UUID & id, std::type_index type, const String & name, std::type_index required_type) const
{
throw Exception(
"ID {" + toString(id) + "}: " + getTypeName(type) + backQuote(name) + " expected to be of type " + getTypeName(required_type),
ErrorCodes::BAD_CAST);
}
void IAccessStorage::throwIDCollisionCannotInsert(const UUID & id, std::type_index type, const String & name, std::type_index existing_type, const String & existing_name) const
{
throw Exception(
getTypeName(type) + " " + backQuote(name) + ": cannot insert because the ID {" + toString(id) + "} is already used by "
+ getTypeName(existing_type) + " " + backQuote(existing_name) + " in " + getStorageName(),
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
}
void IAccessStorage::throwNameCollisionCannotInsert(std::type_index type, const String & name) const
{
throw Exception(
getTypeName(type) + " " + backQuote(name) + ": cannot insert because " + getTypeName(type) + " " + backQuote(name)
+ " already exists in " + getStorageName(),
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
}
void IAccessStorage::throwNameCollisionCannotRename(std::type_index type, const String & old_name, const String & new_name) const
{
throw Exception(
getTypeName(type) + " " + backQuote(old_name) + ": cannot rename to " + backQuote(new_name) + " because " + getTypeName(type) + " "
+ backQuote(new_name) + " already exists in " + getStorageName(),
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
}
void IAccessStorage::throwReadonlyCannotInsert(std::type_index type, const String & name) const
{
throw Exception(
"Cannot insert " + getTypeName(type) + " " + backQuote(name) + " to " + getStorageName() + " because this storage is readonly",
ErrorCodes::ACCESS_ENTITY_STORAGE_READONLY);
}
void IAccessStorage::throwReadonlyCannotUpdate(std::type_index type, const String & name) const
{
throw Exception(
"Cannot update " + getTypeName(type) + " " + backQuote(name) + " in " + getStorageName() + " because this storage is readonly",
ErrorCodes::ACCESS_ENTITY_STORAGE_READONLY);
}
void IAccessStorage::throwReadonlyCannotRemove(std::type_index type, const String & name) const
{
throw Exception(
"Cannot remove " + getTypeName(type) + " " + backQuote(name) + " from " + getStorageName() + " because this storage is readonly",
ErrorCodes::ACCESS_ENTITY_STORAGE_READONLY);
}
}

View File

@ -0,0 +1,209 @@
#pragma once
#include <Access/IAccessEntity.h>
#include <Core/Types.h>
#include <Core/UUID.h>
#include <functional>
#include <optional>
#include <vector>
#include <atomic>
namespace Poco { class Logger; }
namespace DB
{
/// Contains entities, i.e. instances of classes derived from IAccessEntity.
/// The implementations of this class MUST be thread-safe.
class IAccessStorage
{
public:
IAccessStorage(const String & storage_name_) : storage_name(storage_name_) {}
virtual ~IAccessStorage() {}
/// Returns the name of this storage.
const String & getStorageName() const { return storage_name; }
/// Returns the identifiers of all the entities of a specified type contained in the storage.
std::vector<UUID> findAll(std::type_index type) const;
template <typename EntityType>
std::vector<UUID> findAll() const { return findAll(typeid(EntityType)); }
/// Searchs for an entity with specified type and name. Returns std::nullopt if not found.
std::optional<UUID> find(std::type_index type, const String & name) const;
template <typename EntityType>
std::optional<UUID> find(const String & name) const { return find(typeid(EntityType), name); }
std::vector<UUID> find(std::type_index type, const Strings & names) const;
template <typename EntityType>
std::vector<UUID> find(const Strings & names) const { return find(typeid(EntityType), names); }
/// Searchs for an entity with specified name and type. Throws an exception if not found.
UUID getID(std::type_index type, const String & name) const;
template <typename EntityType>
UUID getID(const String & name) const { return getID(typeid(EntityType), name); }
std::vector<UUID> getIDs(std::type_index type, const Strings & names) const;
template <typename EntityType>
std::vector<UUID> getIDs(const Strings & names) const { return getIDs(typeid(EntityType), names); }
/// Returns whether there is an entity with such identifier in the storage.
bool exists(const UUID & id) const;
/// Reads an entity. Throws an exception if not found.
template <typename EntityType = IAccessEntity>
std::shared_ptr<const EntityType> read(const UUID & id) const;
template <typename EntityType = IAccessEntity>
std::shared_ptr<const EntityType> read(const String & name) const;
/// Reads an entity. Returns nullptr if not found.
template <typename EntityType = IAccessEntity>
std::shared_ptr<const EntityType> tryRead(const UUID & id) const;
template <typename EntityType = IAccessEntity>
std::shared_ptr<const EntityType> tryRead(const String & name) const;
/// Reads only name of an entity.
String readName(const UUID & id) const;
std::optional<String> tryReadName(const UUID & id) const;
/// Inserts an entity to the storage. Returns ID of a new entry in the storage.
/// Throws an exception if the specified name already exists.
UUID insert(const AccessEntityPtr & entity);
std::vector<UUID> insert(const std::vector<AccessEntityPtr> & multiple_entities);
/// Inserts an entity to the storage. Returns ID of a new entry in the storage.
std::optional<UUID> tryInsert(const AccessEntityPtr & entity);
std::vector<UUID> tryInsert(const std::vector<AccessEntityPtr> & multiple_entities);
/// Inserts an entity to the storage. Return ID of a new entry in the storage.
/// Replaces an existing entry in the storage if the specified name already exists.
UUID insertOrReplace(const AccessEntityPtr & entity);
std::vector<UUID> insertOrReplace(const std::vector<AccessEntityPtr> & multiple_entities);
/// Removes an entity from the storage. Throws an exception if couldn't remove.
void remove(const UUID & id);
void remove(const std::vector<UUID> & ids);
/// Removes an entity from the storage. Returns false if couldn't remove.
bool tryRemove(const UUID & id);
/// Removes multiple entities from the storage. Returns the list of successfully dropped.
std::vector<UUID> tryRemove(const std::vector<UUID> & ids);
using UpdateFunc = std::function<AccessEntityPtr(const AccessEntityPtr &)>;
/// Updates an entity stored in the storage. Throws an exception if couldn't update.
void update(const UUID & id, const UpdateFunc & update_func);
void update(const std::vector<UUID> & ids, const UpdateFunc & update_func);
/// Updates an entity stored in the storage. Returns false if couldn't update.
bool tryUpdate(const UUID & id, const UpdateFunc & update_func);
/// Updates multiple entities in the storage. Returns the list of successfully updated.
std::vector<UUID> tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func);
class Subscription
{
public:
virtual ~Subscription() {}
};
using SubscriptionPtr = std::unique_ptr<Subscription>;
using OnChangedHandler = std::function<void(const UUID & /* id */, const AccessEntityPtr & /* new or changed entity, null if removed */)>;
/// Subscribes for all changes.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
SubscriptionPtr subscribeForChanges(std::type_index type, const OnChangedHandler & handler) const;
template <typename EntityType>
SubscriptionPtr subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(typeid(EntityType), handler); }
/// Subscribes for changes of a specific entry.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
SubscriptionPtr subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const;
SubscriptionPtr subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const;
bool hasSubscription(std::type_index type) const;
bool hasSubscription(const UUID & id) const;
protected:
virtual std::optional<UUID> findImpl(std::type_index type, const String & name) const = 0;
virtual std::vector<UUID> findAllImpl(std::type_index type) const = 0;
virtual bool existsImpl(const UUID & id) const = 0;
virtual AccessEntityPtr readImpl(const UUID & id) const = 0;
virtual String readNameImpl(const UUID & id) const = 0;
virtual UUID insertImpl(const AccessEntityPtr & entity, bool replace_if_exists) = 0;
virtual void removeImpl(const UUID & id) = 0;
virtual void updateImpl(const UUID & id, const UpdateFunc & update_func) = 0;
virtual SubscriptionPtr subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const = 0;
virtual SubscriptionPtr subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const = 0;
virtual bool hasSubscriptionImpl(const UUID & id) const = 0;
virtual bool hasSubscriptionImpl(std::type_index type) const = 0;
static UUID generateRandomID();
Poco::Logger * getLogger() const;
static String getTypeName(std::type_index type) { return IAccessEntity::getTypeName(type); }
[[noreturn]] void throwNotFound(const UUID & id) const;
[[noreturn]] void throwNotFound(std::type_index type, const String & name) const;
[[noreturn]] void throwBadCast(const UUID & id, std::type_index type, const String & name, std::type_index required_type) const;
[[noreturn]] void throwIDCollisionCannotInsert(const UUID & id, std::type_index type, const String & name, std::type_index existing_type, const String & existing_name) const;
[[noreturn]] void throwNameCollisionCannotInsert(std::type_index type, const String & name) const;
[[noreturn]] void throwNameCollisionCannotRename(std::type_index type, const String & old_name, const String & new_name) const;
[[noreturn]] void throwReadonlyCannotInsert(std::type_index type, const String & name) const;
[[noreturn]] void throwReadonlyCannotUpdate(std::type_index type, const String & name) const;
[[noreturn]] void throwReadonlyCannotRemove(std::type_index type, const String & name) const;
using Notification = std::tuple<OnChangedHandler, UUID, AccessEntityPtr>;
using Notifications = std::vector<Notification>;
static void notify(const Notifications & notifications);
private:
AccessEntityPtr tryReadBase(const UUID & id) const;
const String storage_name;
mutable std::atomic<Poco::Logger *> log = nullptr;
};
template <typename EntityType>
std::shared_ptr<const EntityType> IAccessStorage::read(const UUID & id) const
{
auto entity = readImpl(id);
auto ptr = typeid_cast<std::shared_ptr<const EntityType>>(entity);
if (ptr)
return ptr;
throwBadCast(id, entity->getType(), entity->getFullName(), typeid(EntityType));
}
template <typename EntityType>
std::shared_ptr<const EntityType> IAccessStorage::read(const String & name) const
{
return read<EntityType>(getID<EntityType>(name));
}
template <typename EntityType>
std::shared_ptr<const EntityType> IAccessStorage::tryRead(const UUID & id) const
{
auto entity = tryReadBase(id);
if (!entity)
return nullptr;
return typeid_cast<std::shared_ptr<const EntityType>>(entity);
}
template <typename EntityType>
std::shared_ptr<const EntityType> IAccessStorage::tryRead(const String & name) const
{
auto id = find<EntityType>(name);
return id ? tryRead<EntityType>(*id) : nullptr;
}
}

View File

@ -0,0 +1,358 @@
#include <Access/MemoryAccessStorage.h>
#include <ext/scope_guard.h>
#include <unordered_set>
namespace DB
{
MemoryAccessStorage::MemoryAccessStorage(const String & storage_name_)
: IAccessStorage(storage_name_), shared_ptr_to_this{std::make_shared<const MemoryAccessStorage *>(this)}
{
}
MemoryAccessStorage::~MemoryAccessStorage() {}
std::optional<UUID> MemoryAccessStorage::findImpl(std::type_index type, const String & name) const
{
std::lock_guard lock{mutex};
auto it = names.find({name, type});
if (it == names.end())
return {};
Entry & entry = *(it->second);
return entry.id;
}
std::vector<UUID> MemoryAccessStorage::findAllImpl(std::type_index type) const
{
std::lock_guard lock{mutex};
std::vector<UUID> result;
result.reserve(entries.size());
for (const auto & [id, entry] : entries)
if (entry.entity->isTypeOf(type))
result.emplace_back(id);
return result;
}
bool MemoryAccessStorage::existsImpl(const UUID & id) const
{
std::lock_guard lock{mutex};
return entries.count(id);
}
AccessEntityPtr MemoryAccessStorage::readImpl(const UUID & id) const
{
std::lock_guard lock{mutex};
auto it = entries.find(id);
if (it == entries.end())
throwNotFound(id);
const Entry & entry = it->second;
return entry.entity;
}
String MemoryAccessStorage::readNameImpl(const UUID & id) const
{
return readImpl(id)->getFullName();
}
UUID MemoryAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool replace_if_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
UUID id = generateRandomID();
std::lock_guard lock{mutex};
insertNoLock(generateRandomID(), new_entity, replace_if_exists, notifications);
return id;
}
void MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, Notifications & notifications)
{
const String & name = new_entity->getFullName();
std::type_index type = new_entity->getType();
/// Check that we can insert.
auto it = entries.find(id);
if (it != entries.end())
{
const auto & existing_entry = it->second;
throwIDCollisionCannotInsert(id, type, name, existing_entry.entity->getType(), existing_entry.entity->getFullName());
}
auto it2 = names.find({name, type});
if (it2 != names.end())
{
const auto & existing_entry = *(it2->second);
if (replace_if_exists)
removeNoLock(existing_entry.id, notifications);
else
throwNameCollisionCannotInsert(type, name);
}
/// Do insertion.
auto & entry = entries[id];
entry.id = id;
entry.entity = new_entity;
names[std::pair{name, type}] = &entry;
prepareNotifications(entry, false, notifications);
}
void MemoryAccessStorage::removeImpl(const UUID & id)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
removeNoLock(id, notifications);
}
void MemoryAccessStorage::removeNoLock(const UUID & id, Notifications & notifications)
{
auto it = entries.find(id);
if (it == entries.end())
throwNotFound(id);
Entry & entry = it->second;
const String & name = entry.entity->getFullName();
std::type_index type = entry.entity->getType();
prepareNotifications(entry, true, notifications);
/// Do removing.
names.erase({name, type});
entries.erase(it);
}
void MemoryAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
updateNoLock(id, update_func, notifications);
}
void MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, Notifications & notifications)
{
auto it = entries.find(id);
if (it == entries.end())
throwNotFound(id);
Entry & entry = it->second;
auto old_entity = entry.entity;
auto new_entity = update_func(old_entity);
if (*new_entity == *old_entity)
return;
entry.entity = new_entity;
if (new_entity->getFullName() != old_entity->getFullName())
{
auto it2 = names.find({new_entity->getFullName(), new_entity->getType()});
if (it2 != names.end())
throwNameCollisionCannotRename(old_entity->getType(), old_entity->getFullName(), new_entity->getFullName());
names.erase({old_entity->getFullName(), old_entity->getType()});
names[std::pair{new_entity->getFullName(), new_entity->getType()}] = &entry;
}
prepareNotifications(entry, false, notifications);
}
void MemoryAccessStorage::setAll(const std::vector<AccessEntityPtr> & all_entities)
{
std::vector<std::pair<UUID, AccessEntityPtr>> entities_with_ids;
entities_with_ids.reserve(all_entities.size());
for (const auto & entity : all_entities)
entities_with_ids.emplace_back(generateRandomID(), entity);
setAll(entities_with_ids);
}
void MemoryAccessStorage::setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
setAllNoLock(all_entities, notifications);
}
void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications)
{
/// Get list of the currently used IDs. Later we will remove those of them which are not used anymore.
std::unordered_set<UUID> not_used_ids;
for (const auto & id_and_entry : entries)
not_used_ids.emplace(id_and_entry.first);
/// Remove conflicting entities.
for (const auto & [id, entity] : all_entities)
{
auto it = entries.find(id);
if (it != entries.end())
{
not_used_ids.erase(id); /// ID is used.
Entry & entry = it->second;
if (entry.entity->getType() != entity->getType())
{
removeNoLock(id, notifications);
continue;
}
}
auto it2 = names.find({entity->getFullName(), entity->getType()});
if (it2 != names.end())
{
Entry & entry = *(it2->second);
if (entry.id != id)
removeNoLock(id, notifications);
}
}
/// Remove entities which are not used anymore.
for (const auto & id : not_used_ids)
removeNoLock(id, notifications);
/// Insert or update entities.
for (const auto & [id, entity] : all_entities)
{
auto it = entries.find(id);
if (it != entries.end())
{
if (*(it->second.entity) != *entity)
{
const AccessEntityPtr & changed_entity = entity;
updateNoLock(id, [&changed_entity](const AccessEntityPtr &) { return changed_entity; }, notifications);
}
}
else
insertNoLock(id, entity, false, notifications);
}
}
void MemoryAccessStorage::prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const
{
for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, entry.id, remove ? nullptr : entry.entity});
auto range = handlers_by_type.equal_range(entry.entity->getType());
for (auto it = range.first; it != range.second; ++it)
notifications.push_back({it->second, entry.id, remove ? nullptr : entry.entity});
}
IAccessStorage::SubscriptionPtr MemoryAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const
{
class SubscriptionImpl : public Subscription
{
public:
SubscriptionImpl(
const MemoryAccessStorage & storage_,
std::type_index type_,
const OnChangedHandler & handler_)
: storage_weak(storage_.shared_ptr_to_this)
{
std::lock_guard lock{storage_.mutex};
handler_it = storage_.handlers_by_type.emplace(type_, handler_);
}
~SubscriptionImpl() override
{
auto storage = storage_weak.lock();
if (storage)
{
std::lock_guard lock{(*storage)->mutex};
(*storage)->handlers_by_type.erase(handler_it);
}
}
private:
std::weak_ptr<const MemoryAccessStorage *> storage_weak;
std::unordered_multimap<std::type_index, OnChangedHandler>::iterator handler_it;
};
return std::make_unique<SubscriptionImpl>(*this, type, handler);
}
IAccessStorage::SubscriptionPtr MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
class SubscriptionImpl : public Subscription
{
public:
SubscriptionImpl(
const MemoryAccessStorage & storage_,
const UUID & id_,
const OnChangedHandler & handler_)
: storage_weak(storage_.shared_ptr_to_this),
id(id_)
{
std::lock_guard lock{storage_.mutex};
auto it = storage_.entries.find(id);
if (it == storage_.entries.end())
{
storage_weak.reset();
return;
}
const Entry & entry = it->second;
handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler_);
}
~SubscriptionImpl() override
{
auto storage = storage_weak.lock();
if (storage)
{
std::lock_guard lock{(*storage)->mutex};
auto it = (*storage)->entries.find(id);
if (it != (*storage)->entries.end())
{
const Entry & entry = it->second;
entry.handlers_by_id.erase(handler_it);
}
}
}
private:
std::weak_ptr<const MemoryAccessStorage *> storage_weak;
UUID id;
std::list<OnChangedHandler>::iterator handler_it;
};
return std::make_unique<SubscriptionImpl>(*this, id, handler);
}
bool MemoryAccessStorage::hasSubscriptionImpl(const UUID & id) const
{
auto it = entries.find(id);
if (it != entries.end())
{
const Entry & entry = it->second;
return !entry.handlers_by_id.empty();
}
return false;
}
bool MemoryAccessStorage::hasSubscriptionImpl(std::type_index type) const
{
auto range = handlers_by_type.equal_range(type);
return range.first != range.second;
}
}

View File

@ -0,0 +1,65 @@
#pragma once
#include <Access/IAccessStorage.h>
#include <list>
#include <memory>
#include <mutex>
#include <unordered_map>
namespace DB
{
/// Implementation of IAccessStorage which keeps all data in memory.
class MemoryAccessStorage : public IAccessStorage
{
public:
MemoryAccessStorage(const String & storage_name_ = "memory");
~MemoryAccessStorage() override;
/// Sets all entities at once.
void setAll(const std::vector<AccessEntityPtr> & all_entities);
void setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities);
private:
std::optional<UUID> findImpl(std::type_index type, const String & name) const override;
std::vector<UUID> findAllImpl(std::type_index type) const override;
bool existsImpl(const UUID & id) const override;
AccessEntityPtr readImpl(const UUID & id) const override;
String readNameImpl(const UUID & id) const override;
UUID insertImpl(const AccessEntityPtr & entity, bool replace_if_exists) override;
void removeImpl(const UUID & id) override;
void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
SubscriptionPtr subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
SubscriptionPtr subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override;
bool hasSubscriptionImpl(const UUID & id) const override;
bool hasSubscriptionImpl(std::type_index type) const override;
struct Entry
{
UUID id;
AccessEntityPtr entity;
mutable std::list<OnChangedHandler> handlers_by_id;
};
void insertNoLock(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, Notifications & notifications);
void removeNoLock(const UUID & id, Notifications & notifications);
void updateNoLock(const UUID & id, const UpdateFunc & update_func, Notifications & notifications);
void setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications);
void prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const;
using NameTypePair = std::pair<String, std::type_index>;
struct Hash
{
size_t operator()(const NameTypePair & key) const
{
return std::hash<String>{}(key.first) - std::hash<std::type_index>{}(key.second);
}
};
mutable std::mutex mutex;
std::unordered_map<UUID, Entry> entries; /// We want to search entries both by ID and by the pair of name and type.
std::unordered_map<NameTypePair, Entry *, Hash> names; /// and by the pair of name and type.
mutable std::unordered_multimap<std::type_index, OnChangedHandler> handlers_by_type;
std::shared_ptr<const MemoryAccessStorage *> shared_ptr_to_this; /// We need weak pointers to `this` to implement subscriptions.
};
}

View File

@ -0,0 +1,246 @@
#include <Access/MultipleAccessStorage.h>
#include <Common/Exception.h>
#include <Common/quoteString.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ACCESS_ENTITY_NOT_FOUND;
extern const int ACCESS_ENTITY_FOUND_DUPLICATES;
}
namespace
{
template <typename StoragePtrT>
String joinStorageNames(const std::vector<StoragePtrT> & storages)
{
String result;
for (const auto & storage : storages)
{
if (!result.empty())
result += ", ";
result += storage->getStorageName();
}
return result;
}
}
MultipleAccessStorage::MultipleAccessStorage(
std::vector<std::unique_ptr<Storage>> nested_storages_, size_t index_of_nested_storage_for_insertion_)
: IAccessStorage(joinStorageNames(nested_storages_))
, nested_storages(std::move(nested_storages_))
, nested_storage_for_insertion(nested_storages[index_of_nested_storage_for_insertion_].get())
, ids_cache(512 /* cache size */)
{
}
MultipleAccessStorage::~MultipleAccessStorage()
{
}
std::vector<UUID> MultipleAccessStorage::findMultiple(std::type_index type, const String & name) const
{
std::vector<UUID> ids;
for (const auto & nested_storage : nested_storages)
{
auto id = nested_storage->find(type, name);
if (id)
{
std::lock_guard lock{ids_cache_mutex};
ids_cache.set(*id, std::make_shared<Storage *>(nested_storage.get()));
ids.push_back(*id);
}
}
return ids;
}
std::optional<UUID> MultipleAccessStorage::findImpl(std::type_index type, const String & name) const
{
auto ids = findMultiple(type, name);
if (ids.empty())
return {};
if (ids.size() == 1)
return ids[0];
std::vector<const Storage *> storages_with_duplicates;
for (const auto & id : ids)
{
auto * storage = findStorage(id);
if (storage)
storages_with_duplicates.push_back(storage);
}
throw Exception(
"Found " + getTypeName(type) + " " + backQuote(name) + " in " + std::to_string(ids.size())
+ " storages: " + joinStorageNames(storages_with_duplicates),
ErrorCodes::ACCESS_ENTITY_FOUND_DUPLICATES);
}
std::vector<UUID> MultipleAccessStorage::findAllImpl(std::type_index type) const
{
std::vector<UUID> all_ids;
for (const auto & nested_storage : nested_storages)
{
auto ids = nested_storage->findAll(type);
all_ids.insert(all_ids.end(), std::make_move_iterator(ids.begin()), std::make_move_iterator(ids.end()));
}
return all_ids;
}
bool MultipleAccessStorage::existsImpl(const UUID & id) const
{
return findStorage(id) != nullptr;
}
IAccessStorage * MultipleAccessStorage::findStorage(const UUID & id)
{
{
std::lock_guard lock{ids_cache_mutex};
auto from_cache = ids_cache.get(id);
if (from_cache)
{
auto * storage = *from_cache;
if (storage->exists(id))
return storage;
}
}
for (const auto & nested_storage : nested_storages)
{
if (nested_storage->exists(id))
{
std::lock_guard lock{ids_cache_mutex};
ids_cache.set(id, std::make_shared<Storage *>(nested_storage.get()));
return nested_storage.get();
}
}
return nullptr;
}
const IAccessStorage * MultipleAccessStorage::findStorage(const UUID & id) const
{
return const_cast<MultipleAccessStorage *>(this)->findStorage(id);
}
IAccessStorage & MultipleAccessStorage::getStorage(const UUID & id)
{
auto * storage = findStorage(id);
if (storage)
return *storage;
throwNotFound(id);
}
const IAccessStorage & MultipleAccessStorage::getStorage(const UUID & id) const
{
return const_cast<MultipleAccessStorage *>(this)->getStorage(id);
}
AccessEntityPtr MultipleAccessStorage::readImpl(const UUID & id) const
{
return getStorage(id).read(id);
}
String MultipleAccessStorage::readNameImpl(const UUID & id) const
{
return getStorage(id).readName(id);
}
UUID MultipleAccessStorage::insertImpl(const AccessEntityPtr & entity, bool replace_if_exists)
{
auto id = replace_if_exists ? nested_storage_for_insertion->insertOrReplace(entity) : nested_storage_for_insertion->insert(entity);
std::lock_guard lock{ids_cache_mutex};
ids_cache.set(id, std::make_shared<Storage *>(nested_storage_for_insertion));
return id;
}
void MultipleAccessStorage::removeImpl(const UUID & id)
{
getStorage(id).remove(id);
}
void MultipleAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func)
{
getStorage(id).update(id, update_func);
}
IAccessStorage::SubscriptionPtr MultipleAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
auto storage = findStorage(id);
if (!storage)
return nullptr;
return storage->subscribeForChanges(id, handler);
}
IAccessStorage::SubscriptionPtr MultipleAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const
{
std::vector<SubscriptionPtr> subscriptions;
for (const auto & nested_storage : nested_storages)
{
auto subscription = nested_storage->subscribeForChanges(type, handler);
if (subscription)
subscriptions.emplace_back(std::move(subscription));
}
if (subscriptions.empty())
return nullptr;
if (subscriptions.size() == 1)
return std::move(subscriptions[0]);
class SubscriptionImpl : public Subscription
{
public:
SubscriptionImpl(std::vector<SubscriptionPtr> subscriptions_)
: subscriptions(std::move(subscriptions_)) {}
private:
std::vector<SubscriptionPtr> subscriptions;
};
return std::make_unique<SubscriptionImpl>(std::move(subscriptions));
}
bool MultipleAccessStorage::hasSubscriptionImpl(const UUID & id) const
{
for (const auto & nested_storage : nested_storages)
{
if (nested_storage->hasSubscription(id))
return true;
}
return false;
}
bool MultipleAccessStorage::hasSubscriptionImpl(std::type_index type) const
{
for (const auto & nested_storage : nested_storages)
{
if (nested_storage->hasSubscription(type))
return true;
}
return false;
}
}

View File

@ -0,0 +1,53 @@
#pragma once
#include <Access/IAccessStorage.h>
#include <Common/LRUCache.h>
#include <mutex>
namespace DB
{
/// Implementation of IAccessStorage which contains multiple nested storages.
class MultipleAccessStorage : public IAccessStorage
{
public:
using Storage = IAccessStorage;
MultipleAccessStorage(std::vector<std::unique_ptr<Storage>> nested_storages_, size_t index_of_nested_storage_for_insertion_ = 0);
~MultipleAccessStorage() override;
std::vector<UUID> findMultiple(std::type_index type, const String & name) const;
template <typename EntityType>
std::vector<UUID> findMultiple(const String & name) const { return findMultiple(EntityType::TYPE, name); }
const Storage * findStorage(const UUID & id) const;
Storage * findStorage(const UUID & id);
const Storage & getStorage(const UUID & id) const;
Storage & getStorage(const UUID & id);
Storage & getStorageByIndex(size_t i) { return *(nested_storages[i]); }
const Storage & getStorageByIndex(size_t i) const { return *(nested_storages[i]); }
protected:
std::optional<UUID> findImpl(std::type_index type, const String & name) const override;
std::vector<UUID> findAllImpl(std::type_index type) const override;
bool existsImpl(const UUID & id) const override;
AccessEntityPtr readImpl(const UUID & id) const override;
String readNameImpl(const UUID &id) const override;
UUID insertImpl(const AccessEntityPtr & entity, bool replace_if_exists) override;
void removeImpl(const UUID & id) override;
void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
SubscriptionPtr subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
SubscriptionPtr subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override;
bool hasSubscriptionImpl(const UUID & id) const override;
bool hasSubscriptionImpl(std::type_index type) const override;
private:
std::vector<std::unique_ptr<Storage>> nested_storages;
IAccessStorage * nested_storage_for_insertion;
mutable LRUCache<UUID, Storage *> ids_cache;
mutable std::mutex ids_cache_mutex;
};
}

46
dbms/src/Access/Quota.cpp Normal file
View File

@ -0,0 +1,46 @@
#include <Access/Quota.h>
#include <boost/range/algorithm/equal.hpp>
#include <boost/range/algorithm/fill.hpp>
namespace DB
{
Quota::Limits::Limits()
{
boost::range::fill(max, 0);
}
bool operator ==(const Quota::Limits & lhs, const Quota::Limits & rhs)
{
return boost::range::equal(lhs.max, rhs.max) && (lhs.duration == rhs.duration)
&& (lhs.randomize_interval == rhs.randomize_interval);
}
bool Quota::equal(const IAccessEntity & other) const
{
if (!IAccessEntity::equal(other))
return false;
const auto & other_quota = typeid_cast<const Quota &>(other);
return (all_limits == other_quota.all_limits) && (key_type == other_quota.key_type) && (roles == other_quota.roles)
&& (all_roles == other_quota.all_roles) && (except_roles == other_quota.except_roles);
}
const char * Quota::resourceTypeToColumnName(ResourceType resource_type)
{
switch (resource_type)
{
case Quota::QUERIES: return "queries";
case Quota::ERRORS: return "errors";
case Quota::RESULT_ROWS: return "result_rows";
case Quota::RESULT_BYTES: return "result_bytes";
case Quota::READ_ROWS: return "read_rows";
case Quota::READ_BYTES: return "read_bytes";
case Quota::EXECUTION_TIME: return "execution_time";
}
__builtin_unreachable();
}
}

141
dbms/src/Access/Quota.h Normal file
View File

@ -0,0 +1,141 @@
#pragma once
#include <Access/IAccessEntity.h>
#include <chrono>
namespace DB
{
/** Quota for resources consumption for specific interval.
* Used to limit resource usage by user.
* Quota is applied "softly" - could be slightly exceed, because it is checked usually only on each block of processed data.
* Accumulated values are not persisted and are lost on server restart.
* Quota is local to server,
* but for distributed queries, accumulated values for read rows and bytes
* are collected from all participating servers and accumulated locally.
*/
struct Quota : public IAccessEntity
{
enum ResourceType
{
QUERIES, /// Number of queries.
ERRORS, /// Number of queries with exceptions.
RESULT_ROWS, /// Number of rows returned as result.
RESULT_BYTES, /// Number of bytes returned as result.
READ_ROWS, /// Number of rows read from tables.
READ_BYTES, /// Number of bytes read from tables.
EXECUTION_TIME, /// Total amount of query execution time in nanoseconds.
};
static constexpr size_t MAX_RESOURCE_TYPE = 7;
using ResourceAmount = UInt64;
static constexpr ResourceAmount UNLIMITED = 0; /// 0 means unlimited.
/// Amount of resources available to consume for each duration.
struct Limits
{
ResourceAmount max[MAX_RESOURCE_TYPE];
std::chrono::seconds duration = std::chrono::seconds::zero();
/// Intervals can be randomized (to avoid DoS if intervals for many users end at one time).
bool randomize_interval = false;
Limits();
friend bool operator ==(const Limits & lhs, const Limits & rhs);
friend bool operator !=(const Limits & lhs, const Limits & rhs) { return !(lhs == rhs); }
};
std::vector<Limits> all_limits;
/// Key to share quota consumption.
/// Users with the same key share the same amount of resource.
enum class KeyType
{
NONE, /// All users share the same quota.
USER_NAME, /// Connections with the same user name share the same quota.
IP_ADDRESS, /// Connections from the same IP share the same quota.
CLIENT_KEY, /// Client should explicitly supply a key to use.
CLIENT_KEY_OR_USER_NAME, /// Same as CLIENT_KEY, but use USER_NAME if the client doesn't supply a key.
CLIENT_KEY_OR_IP_ADDRESS, /// Same as CLIENT_KEY, but use IP_ADDRESS if the client doesn't supply a key.
};
static constexpr size_t MAX_KEY_TYPE = 6;
KeyType key_type = KeyType::NONE;
/// Which roles or users should use this quota.
Strings roles;
bool all_roles = false;
Strings except_roles;
bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<Quota>(); }
static const char * getNameOfResourceType(ResourceType resource_type);
static const char * resourceTypeToKeyword(ResourceType resource_type);
static const char * resourceTypeToColumnName(ResourceType resource_type);
static const char * getNameOfKeyType(KeyType key_type);
static double executionTimeToSeconds(ResourceAmount ns);
static ResourceAmount secondsToExecutionTime(double s);
};
inline const char * Quota::getNameOfResourceType(ResourceType resource_type)
{
switch (resource_type)
{
case Quota::QUERIES: return "queries";
case Quota::ERRORS: return "errors";
case Quota::RESULT_ROWS: return "result rows";
case Quota::RESULT_BYTES: return "result bytes";
case Quota::READ_ROWS: return "read rows";
case Quota::READ_BYTES: return "read bytes";
case Quota::EXECUTION_TIME: return "execution time";
}
__builtin_unreachable();
}
inline const char * Quota::resourceTypeToKeyword(ResourceType resource_type)
{
switch (resource_type)
{
case Quota::QUERIES: return "QUERIES";
case Quota::ERRORS: return "ERRORS";
case Quota::RESULT_ROWS: return "RESULT ROWS";
case Quota::RESULT_BYTES: return "RESULT BYTES";
case Quota::READ_ROWS: return "READ ROWS";
case Quota::READ_BYTES: return "READ BYTES";
case Quota::EXECUTION_TIME: return "EXECUTION TIME";
}
__builtin_unreachable();
}
inline const char * Quota::getNameOfKeyType(KeyType key_type)
{
switch (key_type)
{
case KeyType::NONE: return "none";
case KeyType::USER_NAME: return "user name";
case KeyType::IP_ADDRESS: return "ip address";
case KeyType::CLIENT_KEY: return "client key";
case KeyType::CLIENT_KEY_OR_USER_NAME: return "client key or user name";
case KeyType::CLIENT_KEY_OR_IP_ADDRESS: return "client key or ip address";
}
__builtin_unreachable();
}
inline double Quota::executionTimeToSeconds(ResourceAmount ns)
{
return std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::nanoseconds{ns}).count();
}
inline Quota::ResourceAmount Quota::secondsToExecutionTime(double s)
{
return std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::duration<double>(s)).count();
}
using QuotaPtr = std::shared_ptr<const Quota>;
}

View File

@ -0,0 +1,264 @@
#include <Access/QuotaContext.h>
#include <Common/Exception.h>
#include <Common/quoteString.h>
#include <ext/chrono_io.h>
#include <ext/range.h>
#include <boost/range/algorithm/fill.hpp>
namespace DB
{
namespace ErrorCodes
{
extern const int QUOTA_EXPIRED;
}
struct QuotaContext::Impl
{
[[noreturn]] static void throwQuotaExceed(
const String & user_name,
const String & quota_name,
ResourceType resource_type,
ResourceAmount used,
ResourceAmount max,
std::chrono::seconds duration,
std::chrono::system_clock::time_point end_of_interval)
{
std::function<String(UInt64)> amount_to_string = [](UInt64 amount) { return std::to_string(amount); };
if (resource_type == Quota::EXECUTION_TIME)
amount_to_string = [&](UInt64 amount) { return ext::to_string(std::chrono::nanoseconds(amount)); };
throw Exception(
"Quota for user " + backQuote(user_name) + " for " + ext::to_string(duration) + " has been exceeded: "
+ Quota::getNameOfResourceType(resource_type) + " = " + amount_to_string(used) + "/" + amount_to_string(max) + ". "
+ "Interval will end at " + ext::to_string(end_of_interval) + ". " + "Name of quota template: " + backQuote(quota_name),
ErrorCodes::QUOTA_EXPIRED);
}
static std::chrono::system_clock::time_point getEndOfInterval(
const Interval & interval, std::chrono::system_clock::time_point current_time, bool * counters_were_reset = nullptr)
{
auto & end_of_interval = interval.end_of_interval;
auto end_loaded = end_of_interval.load();
auto end = std::chrono::system_clock::time_point{end_loaded};
if (current_time < end)
{
if (counters_were_reset)
*counters_were_reset = false;
return end;
}
const auto duration = interval.duration;
do
{
end = end + (current_time - end + duration) / duration * duration;
if (end_of_interval.compare_exchange_strong(end_loaded, end.time_since_epoch()))
{
boost::range::fill(interval.used, 0);
break;
}
end = std::chrono::system_clock::time_point{end_loaded};
}
while (current_time >= end);
if (counters_were_reset)
*counters_were_reset = true;
return end;
}
static void used(
const String & user_name,
const Intervals & intervals,
ResourceType resource_type,
ResourceAmount amount,
std::chrono::system_clock::time_point current_time,
bool check_exceeded)
{
for (const auto & interval : intervals.intervals)
{
ResourceAmount used = (interval.used[resource_type] += amount);
ResourceAmount max = interval.max[resource_type];
if (max == Quota::UNLIMITED)
continue;
if (used > max)
{
bool counters_were_reset = false;
auto end_of_interval = getEndOfInterval(interval, current_time, &counters_were_reset);
if (counters_were_reset)
{
used = (interval.used[resource_type] += amount);
if ((used > max) && check_exceeded)
throwQuotaExceed(user_name, intervals.quota_name, resource_type, used, max, interval.duration, end_of_interval);
}
else if (check_exceeded)
throwQuotaExceed(user_name, intervals.quota_name, resource_type, used, max, interval.duration, end_of_interval);
}
}
}
static void checkExceeded(
const String & user_name,
const Intervals & intervals,
ResourceType resource_type,
std::chrono::system_clock::time_point current_time)
{
for (const auto & interval : intervals.intervals)
{
ResourceAmount used = interval.used[resource_type];
ResourceAmount max = interval.max[resource_type];
if (max == Quota::UNLIMITED)
continue;
if (used > max)
{
bool used_counters_reset = false;
std::chrono::system_clock::time_point end_of_interval = getEndOfInterval(interval, current_time, &used_counters_reset);
if (!used_counters_reset)
throwQuotaExceed(user_name, intervals.quota_name, resource_type, used, max, interval.duration, end_of_interval);
}
}
}
static void checkExceeded(
const String & user_name,
const Intervals & intervals,
std::chrono::system_clock::time_point current_time)
{
for (auto resource_type : ext::range_with_static_cast<Quota::ResourceType>(Quota::MAX_RESOURCE_TYPE))
checkExceeded(user_name, intervals, resource_type, current_time);
}
};
QuotaContext::Interval & QuotaContext::Interval::operator =(const Interval & src)
{
randomize_interval = src.randomize_interval;
duration = src.duration;
end_of_interval.store(src.end_of_interval.load());
for (auto resource_type : ext::range(MAX_RESOURCE_TYPE))
{
max[resource_type] = src.max[resource_type];
used[resource_type].store(src.used[resource_type].load());
}
return *this;
}
QuotaUsageInfo QuotaContext::Intervals::getUsageInfo(std::chrono::system_clock::time_point current_time) const
{
QuotaUsageInfo info;
info.quota_id = quota_id;
info.quota_name = quota_name;
info.quota_key = quota_key;
info.intervals.reserve(intervals.size());
for (const auto & in : intervals)
{
info.intervals.push_back({});
auto & out = info.intervals.back();
out.duration = in.duration;
out.randomize_interval = in.randomize_interval;
out.end_of_interval = Impl::getEndOfInterval(in, current_time);
for (auto resource_type : ext::range(MAX_RESOURCE_TYPE))
{
out.max[resource_type] = in.max[resource_type];
out.used[resource_type] = in.used[resource_type];
}
}
return info;
}
QuotaContext::QuotaContext()
: atomic_intervals(std::make_shared<Intervals>()) /// Unlimited quota.
{
}
QuotaContext::QuotaContext(
const String & user_name_,
const Poco::Net::IPAddress & address_,
const String & client_key_)
: user_name(user_name_), address(address_), client_key(client_key_)
{
}
QuotaContext::~QuotaContext() = default;
void QuotaContext::used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded)
{
used({resource_type, amount}, check_exceeded);
}
void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource, bool check_exceeded)
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
auto current_time = std::chrono::system_clock::now();
Impl::used(user_name, *intervals_ptr, resource.first, resource.second, current_time, check_exceeded);
}
void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, bool check_exceeded)
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
auto current_time = std::chrono::system_clock::now();
Impl::used(user_name, *intervals_ptr, resource1.first, resource1.second, current_time, check_exceeded);
Impl::used(user_name, *intervals_ptr, resource2.first, resource2.second, current_time, check_exceeded);
}
void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, const std::pair<ResourceType, ResourceAmount> & resource3, bool check_exceeded)
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
auto current_time = std::chrono::system_clock::now();
Impl::used(user_name, *intervals_ptr, resource1.first, resource1.second, current_time, check_exceeded);
Impl::used(user_name, *intervals_ptr, resource2.first, resource2.second, current_time, check_exceeded);
Impl::used(user_name, *intervals_ptr, resource3.first, resource3.second, current_time, check_exceeded);
}
void QuotaContext::used(const std::vector<std::pair<ResourceType, ResourceAmount>> & resources, bool check_exceeded)
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
auto current_time = std::chrono::system_clock::now();
for (const auto & resource : resources)
Impl::used(user_name, *intervals_ptr, resource.first, resource.second, current_time, check_exceeded);
}
void QuotaContext::checkExceeded()
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
Impl::checkExceeded(user_name, *intervals_ptr, std::chrono::system_clock::now());
}
void QuotaContext::checkExceeded(ResourceType resource_type)
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
Impl::checkExceeded(user_name, *intervals_ptr, resource_type, std::chrono::system_clock::now());
}
QuotaUsageInfo QuotaContext::getUsageInfo() const
{
auto intervals_ptr = std::atomic_load(&atomic_intervals);
return intervals_ptr->getUsageInfo(std::chrono::system_clock::now());
}
QuotaUsageInfo::QuotaUsageInfo() : quota_id(UUID(UInt128(0)))
{
}
QuotaUsageInfo::Interval::Interval()
{
boost::range::fill(used, 0);
boost::range::fill(max, 0);
}
}

View File

@ -0,0 +1,110 @@
#pragma once
#include <Access/Quota.h>
#include <Core/UUID.h>
#include <Poco/Net/IPAddress.h>
#include <ext/shared_ptr_helper.h>
#include <boost/noncopyable.hpp>
#include <atomic>
#include <chrono>
#include <memory>
namespace DB
{
struct QuotaUsageInfo;
/// Instances of `QuotaContext` are used to track resource consumption.
class QuotaContext : public boost::noncopyable
{
public:
using ResourceType = Quota::ResourceType;
using ResourceAmount = Quota::ResourceAmount;
/// Default constructors makes an unlimited quota.
QuotaContext();
~QuotaContext();
/// Tracks resource consumption. If the quota exceeded and `check_exceeded == true`, throws an exception.
void used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded = true);
void used(const std::pair<ResourceType, ResourceAmount> & resource, bool check_exceeded = true);
void used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, bool check_exceeded = true);
void used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, const std::pair<ResourceType, ResourceAmount> & resource3, bool check_exceeded = true);
void used(const std::vector<std::pair<ResourceType, ResourceAmount>> & resources, bool check_exceeded = true);
/// Checks if the quota exceeded. If so, throws an exception.
void checkExceeded();
void checkExceeded(ResourceType resource_type);
/// Returns the information about this quota context.
QuotaUsageInfo getUsageInfo() const;
private:
friend class QuotaContextFactory;
friend struct ext::shared_ptr_helper<QuotaContext>;
/// Instances of this class are created by QuotaContextFactory.
QuotaContext(const String & user_name_, const Poco::Net::IPAddress & address_, const String & client_key_);
static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE;
struct Interval
{
mutable std::atomic<ResourceAmount> used[MAX_RESOURCE_TYPE];
ResourceAmount max[MAX_RESOURCE_TYPE];
std::chrono::seconds duration;
bool randomize_interval;
mutable std::atomic<std::chrono::system_clock::duration> end_of_interval;
Interval() {}
Interval(const Interval & src) { *this = src; }
Interval & operator =(const Interval & src);
};
struct Intervals
{
std::vector<Interval> intervals;
UUID quota_id;
String quota_name;
String quota_key;
QuotaUsageInfo getUsageInfo(std::chrono::system_clock::time_point current_time) const;
};
struct Impl;
const String user_name;
const Poco::Net::IPAddress address;
const String client_key;
std::shared_ptr<const Intervals> atomic_intervals; /// atomically changed by QuotaUsageManager
};
using QuotaContextPtr = std::shared_ptr<QuotaContext>;
/// The information about a quota context.
struct QuotaUsageInfo
{
using ResourceType = Quota::ResourceType;
using ResourceAmount = Quota::ResourceAmount;
static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE;
struct Interval
{
ResourceAmount used[MAX_RESOURCE_TYPE];
ResourceAmount max[MAX_RESOURCE_TYPE];
std::chrono::seconds duration = std::chrono::seconds::zero();
bool randomize_interval = false;
std::chrono::system_clock::time_point end_of_interval;
Interval();
};
std::vector<Interval> intervals;
UUID quota_id;
String quota_name;
String quota_key;
QuotaUsageInfo();
};
}

View File

@ -0,0 +1,299 @@
#include <Access/QuotaContext.h>
#include <Access/QuotaContextFactory.h>
#include <Access/AccessControlManager.h>
#include <Common/Exception.h>
#include <Common/thread_local_rng.h>
#include <ext/range.h>
#include <boost/range/adaptor/map.hpp>
#include <boost/range/algorithm/copy.hpp>
#include <boost/range/algorithm/lower_bound.hpp>
#include <boost/range/algorithm/stable_sort.hpp>
#include <boost/range/algorithm_ext/erase.hpp>
namespace DB
{
namespace ErrorCodes
{
extern const int QUOTA_REQUIRES_CLIENT_KEY;
}
namespace
{
std::chrono::system_clock::duration randomDuration(std::chrono::seconds max)
{
auto count = std::chrono::duration_cast<std::chrono::system_clock::duration>(max).count();
std::uniform_int_distribution<Int64> distribution{0, count - 1};
return std::chrono::system_clock::duration(distribution(thread_local_rng));
}
}
void QuotaContextFactory::QuotaInfo::setQuota(const QuotaPtr & quota_, const UUID & quota_id_)
{
quota = quota_;
quota_id = quota_id_;
boost::range::copy(quota->roles, std::inserter(roles, roles.end()));
all_roles = quota->all_roles;
boost::range::copy(quota->except_roles, std::inserter(except_roles, except_roles.end()));
rebuildAllIntervals();
}
bool QuotaContextFactory::QuotaInfo::canUseWithContext(const QuotaContext & context) const
{
if (roles.count(context.user_name))
return true;
if (all_roles && !except_roles.count(context.user_name))
return true;
return false;
}
String QuotaContextFactory::QuotaInfo::calculateKey(const QuotaContext & context) const
{
using KeyType = Quota::KeyType;
switch (quota->key_type)
{
case KeyType::NONE:
return "";
case KeyType::USER_NAME:
return context.user_name;
case KeyType::IP_ADDRESS:
return context.address.toString();
case KeyType::CLIENT_KEY:
{
if (!context.client_key.empty())
return context.client_key;
throw Exception(
"Quota " + quota->getName() + " (for user " + context.user_name + ") requires a client supplied key.",
ErrorCodes::QUOTA_REQUIRES_CLIENT_KEY);
}
case KeyType::CLIENT_KEY_OR_USER_NAME:
{
if (!context.client_key.empty())
return context.client_key;
return context.user_name;
}
case KeyType::CLIENT_KEY_OR_IP_ADDRESS:
{
if (!context.client_key.empty())
return context.client_key;
return context.address.toString();
}
}
__builtin_unreachable();
}
std::shared_ptr<const QuotaContext::Intervals> QuotaContextFactory::QuotaInfo::getOrBuildIntervals(const String & key)
{
auto it = key_to_intervals.find(key);
if (it != key_to_intervals.end())
return it->second;
return rebuildIntervals(key);
}
void QuotaContextFactory::QuotaInfo::rebuildAllIntervals()
{
for (const String & key : key_to_intervals | boost::adaptors::map_keys)
rebuildIntervals(key);
}
std::shared_ptr<const QuotaContext::Intervals> QuotaContextFactory::QuotaInfo::rebuildIntervals(const String & key)
{
auto new_intervals = std::make_shared<Intervals>();
new_intervals->quota_name = quota->getName();
new_intervals->quota_id = quota_id;
new_intervals->quota_key = key;
auto & intervals = new_intervals->intervals;
intervals.reserve(quota->all_limits.size());
constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE;
for (const auto & limits : quota->all_limits)
{
intervals.emplace_back();
auto & interval = intervals.back();
interval.duration = limits.duration;
std::chrono::system_clock::time_point end_of_interval{};
interval.randomize_interval = limits.randomize_interval;
if (limits.randomize_interval)
end_of_interval += randomDuration(limits.duration);
interval.end_of_interval = end_of_interval.time_since_epoch();
for (auto resource_type : ext::range(MAX_RESOURCE_TYPE))
{
interval.max[resource_type] = limits.max[resource_type];
interval.used[resource_type] = 0;
}
}
/// Order intervals by durations from largest to smallest.
/// To report first about largest interval on what quota was exceeded.
struct GreaterByDuration
{
bool operator()(const Interval & lhs, const Interval & rhs) const { return lhs.duration > rhs.duration; }
};
boost::range::stable_sort(intervals, GreaterByDuration{});
auto it = key_to_intervals.find(key);
if (it == key_to_intervals.end())
{
/// Just put new intervals into the map.
key_to_intervals.try_emplace(key, new_intervals);
}
else
{
/// We need to keep usage information from the old intervals.
const auto & old_intervals = it->second->intervals;
for (auto & new_interval : new_intervals->intervals)
{
/// Check if an interval with the same duration is already in use.
auto lower_bound = boost::range::lower_bound(old_intervals, new_interval, GreaterByDuration{});
if ((lower_bound == old_intervals.end()) || (lower_bound->duration != new_interval.duration))
continue;
/// Found an interval with the same duration, we need to copy its usage information to `result`.
auto & current_interval = *lower_bound;
for (auto resource_type : ext::range(MAX_RESOURCE_TYPE))
{
new_interval.used[resource_type].store(current_interval.used[resource_type].load());
new_interval.end_of_interval.store(current_interval.end_of_interval.load());
}
}
it->second = new_intervals;
}
return new_intervals;
}
QuotaContextFactory::QuotaContextFactory(const AccessControlManager & access_control_manager_)
: access_control_manager(access_control_manager_)
{
}
QuotaContextFactory::~QuotaContextFactory()
{
}
std::shared_ptr<QuotaContext> QuotaContextFactory::createContext(const String & user_name, const Poco::Net::IPAddress & address, const String & client_key)
{
std::lock_guard lock{mutex};
ensureAllQuotasRead();
auto context = ext::shared_ptr_helper<QuotaContext>::create(user_name, address, client_key);
contexts.push_back(context);
chooseQuotaForContext(context);
return context;
}
void QuotaContextFactory::ensureAllQuotasRead()
{
/// `mutex` is already locked.
if (all_quotas_read)
return;
all_quotas_read = true;
subscription = access_control_manager.subscribeForChanges<Quota>(
[&](const UUID & id, const AccessEntityPtr & entity)
{
if (entity)
quotaAddedOrChanged(id, typeid_cast<QuotaPtr>(entity));
else
quotaRemoved(id);
});
for (const UUID & quota_id : access_control_manager.findAll<Quota>())
{
auto quota = access_control_manager.tryRead<Quota>(quota_id);
if (quota)
all_quotas.emplace(quota_id, QuotaInfo(quota, quota_id));
}
}
void QuotaContextFactory::quotaAddedOrChanged(const UUID & quota_id, const std::shared_ptr<const Quota> & new_quota)
{
std::lock_guard lock{mutex};
auto it = all_quotas.find(quota_id);
if (it == all_quotas.end())
{
it = all_quotas.emplace(quota_id, QuotaInfo(new_quota, quota_id)).first;
}
else
{
if (it->second.quota == new_quota)
return;
}
auto & info = it->second;
info.setQuota(new_quota, quota_id);
chooseQuotaForAllContexts();
}
void QuotaContextFactory::quotaRemoved(const UUID & quota_id)
{
std::lock_guard lock{mutex};
all_quotas.erase(quota_id);
chooseQuotaForAllContexts();
}
void QuotaContextFactory::chooseQuotaForAllContexts()
{
/// `mutex` is already locked.
boost::range::remove_erase_if(
contexts,
[&](const std::weak_ptr<QuotaContext> & weak)
{
auto context = weak.lock();
if (!context)
return true; // remove from the `contexts` list.
chooseQuotaForContext(context);
return false; // keep in the `contexts` list.
});
}
void QuotaContextFactory::chooseQuotaForContext(const std::shared_ptr<QuotaContext> & context)
{
/// `mutex` is already locked.
std::shared_ptr<const Intervals> intervals;
for (auto & info : all_quotas | boost::adaptors::map_values)
{
if (info.canUseWithContext(*context))
{
String key = info.calculateKey(*context);
intervals = info.getOrBuildIntervals(key);
break;
}
}
if (!intervals)
intervals = std::make_shared<Intervals>(); /// No quota == no limits.
std::atomic_store(&context->atomic_intervals, intervals);
}
std::vector<QuotaUsageInfo> QuotaContextFactory::getUsageInfo() const
{
std::lock_guard lock{mutex};
std::vector<QuotaUsageInfo> all_infos;
auto current_time = std::chrono::system_clock::now();
for (const auto & info : all_quotas | boost::adaptors::map_values)
{
for (const auto & intervals : info.key_to_intervals | boost::adaptors::map_values)
all_infos.push_back(intervals->getUsageInfo(current_time));
}
return all_infos;
}
}

View File

@ -0,0 +1,62 @@
#pragma once
#include <Access/QuotaContext.h>
#include <Access/IAccessStorage.h>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
namespace DB
{
class AccessControlManager;
/// Stores information how much amount of resources have been consumed and how much are left.
class QuotaContextFactory
{
public:
QuotaContextFactory(const AccessControlManager & access_control_manager_);
~QuotaContextFactory();
QuotaContextPtr createContext(const String & user_name, const Poco::Net::IPAddress & address, const String & client_key);
std::vector<QuotaUsageInfo> getUsageInfo() const;
private:
using Interval = QuotaContext::Interval;
using Intervals = QuotaContext::Intervals;
struct QuotaInfo
{
QuotaInfo(const QuotaPtr & quota_, const UUID & quota_id_) { setQuota(quota_, quota_id_); }
void setQuota(const QuotaPtr & quota_, const UUID & quota_id_);
bool canUseWithContext(const QuotaContext & context) const;
String calculateKey(const QuotaContext & context) const;
std::shared_ptr<const Intervals> getOrBuildIntervals(const String & key);
std::shared_ptr<const Intervals> rebuildIntervals(const String & key);
void rebuildAllIntervals();
QuotaPtr quota;
UUID quota_id;
std::unordered_set<String> roles;
bool all_roles = false;
std::unordered_set<String> except_roles;
std::unordered_map<String /* quota key */, std::shared_ptr<const Intervals>> key_to_intervals;
};
void ensureAllQuotasRead();
void quotaAddedOrChanged(const UUID & quota_id, const std::shared_ptr<const Quota> & new_quota);
void quotaRemoved(const UUID & quota_id);
void chooseQuotaForAllContexts();
void chooseQuotaForContext(const std::shared_ptr<QuotaContext> & context);
const AccessControlManager & access_control_manager;
mutable std::mutex mutex;
std::unordered_map<UUID /* quota id */, QuotaInfo> all_quotas;
bool all_quotas_read = false;
IAccessStorage::SubscriptionPtr subscription;
std::vector<std::weak_ptr<QuotaContext>> contexts;
};
}

View File

@ -0,0 +1,207 @@
#include <Access/UsersConfigAccessStorage.h>
#include <Access/Quota.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/quoteString.h>
#include <Poco/Util/AbstractConfiguration.h>
#include <Poco/MD5Engine.h>
#include <cstring>
namespace DB
{
namespace
{
char getTypeChar(std::type_index type)
{
if (type == typeid(Quota))
return 'Q';
return 0;
}
UUID generateID(std::type_index type, const String & name)
{
Poco::MD5Engine md5;
md5.update(name);
char type_storage_chars[] = " USRSXML";
type_storage_chars[0] = getTypeChar(type);
md5.update(type_storage_chars, strlen(type_storage_chars));
UUID result;
memcpy(&result, md5.digest().data(), md5.digestLength());
return result;
}
UUID generateID(const IAccessEntity & entity) { return generateID(entity.getType(), entity.getFullName()); }
QuotaPtr parseQuota(const Poco::Util::AbstractConfiguration & config, const String & quota_name, const Strings & user_names)
{
auto quota = std::make_shared<Quota>();
quota->setName(quota_name);
using KeyType = Quota::KeyType;
String quota_config = "quotas." + quota_name;
if (config.has(quota_config + ".keyed_by_ip"))
quota->key_type = KeyType::IP_ADDRESS;
else if (config.has(quota_config + ".keyed"))
quota->key_type = KeyType::CLIENT_KEY_OR_USER_NAME;
else
quota->key_type = KeyType::USER_NAME;
Poco::Util::AbstractConfiguration::Keys interval_keys;
config.keys(quota_config, interval_keys);
for (const String & interval_key : interval_keys)
{
if (!startsWith(interval_key, "interval"))
continue;
String interval_config = quota_config + "." + interval_key;
std::chrono::seconds duration{config.getInt(interval_config + ".duration", 0)};
if (duration.count() <= 0) /// Skip quotas with non-positive duration.
continue;
quota->all_limits.emplace_back();
auto & limits = quota->all_limits.back();
limits.duration = duration;
limits.randomize_interval = config.getBool(interval_config + ".randomize", false);
using ResourceType = Quota::ResourceType;
limits.max[ResourceType::QUERIES] = config.getUInt64(interval_config + ".queries", Quota::UNLIMITED);
limits.max[ResourceType::ERRORS] = config.getUInt64(interval_config + ".errors", Quota::UNLIMITED);
limits.max[ResourceType::RESULT_ROWS] = config.getUInt64(interval_config + ".result_rows", Quota::UNLIMITED);
limits.max[ResourceType::RESULT_BYTES] = config.getUInt64(interval_config + ".result_bytes", Quota::UNLIMITED);
limits.max[ResourceType::READ_ROWS] = config.getUInt64(interval_config + ".read_rows", Quota::UNLIMITED);
limits.max[ResourceType::READ_BYTES] = config.getUInt64(interval_config + ".read_bytes", Quota::UNLIMITED);
limits.max[ResourceType::EXECUTION_TIME] = Quota::secondsToExecutionTime(config.getUInt64(interval_config + ".execution_time", Quota::UNLIMITED));
}
quota->roles = user_names;
return quota;
}
std::vector<AccessEntityPtr> parseQuotas(const Poco::Util::AbstractConfiguration & config, Poco::Logger * log)
{
Poco::Util::AbstractConfiguration::Keys user_names;
config.keys("users", user_names);
std::unordered_map<String, Strings> quota_to_user_names;
for (const auto & user_name : user_names)
{
if (config.has("users." + user_name + ".quota"))
quota_to_user_names[config.getString("users." + user_name + ".quota")].push_back(user_name);
}
Poco::Util::AbstractConfiguration::Keys quota_names;
config.keys("quotas", quota_names);
std::vector<AccessEntityPtr> quotas;
quotas.reserve(quota_names.size());
for (const auto & quota_name : quota_names)
{
try
{
auto it = quota_to_user_names.find(quota_name);
const Strings quota_users = (it != quota_to_user_names.end()) ? std::move(it->second) : Strings{};
quotas.push_back(parseQuota(config, quota_name, quota_users));
}
catch (...)
{
tryLogCurrentException(log, "Could not parse quota " + backQuote(quota_name));
}
}
return quotas;
}
}
UsersConfigAccessStorage::UsersConfigAccessStorage() : IAccessStorage("users.xml")
{
}
UsersConfigAccessStorage::~UsersConfigAccessStorage() {}
void UsersConfigAccessStorage::loadFromConfig(const Poco::Util::AbstractConfiguration & config)
{
std::vector<std::pair<UUID, AccessEntityPtr>> all_entities;
for (const auto & entity : parseQuotas(config, getLogger()))
all_entities.emplace_back(generateID(*entity), entity);
memory_storage.setAll(all_entities);
}
std::optional<UUID> UsersConfigAccessStorage::findImpl(std::type_index type, const String & name) const
{
return memory_storage.find(type, name);
}
std::vector<UUID> UsersConfigAccessStorage::findAllImpl(std::type_index type) const
{
return memory_storage.findAll(type);
}
bool UsersConfigAccessStorage::existsImpl(const UUID & id) const
{
return memory_storage.exists(id);
}
AccessEntityPtr UsersConfigAccessStorage::readImpl(const UUID & id) const
{
return memory_storage.read(id);
}
String UsersConfigAccessStorage::readNameImpl(const UUID & id) const
{
return memory_storage.readName(id);
}
UUID UsersConfigAccessStorage::insertImpl(const AccessEntityPtr & entity, bool)
{
throwReadonlyCannotInsert(entity->getType(), entity->getFullName());
}
void UsersConfigAccessStorage::removeImpl(const UUID & id)
{
auto entity = read(id);
throwReadonlyCannotRemove(entity->getType(), entity->getFullName());
}
void UsersConfigAccessStorage::updateImpl(const UUID & id, const UpdateFunc &)
{
auto entity = read(id);
throwReadonlyCannotUpdate(entity->getType(), entity->getFullName());
}
IAccessStorage::SubscriptionPtr UsersConfigAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
return memory_storage.subscribeForChanges(id, handler);
}
IAccessStorage::SubscriptionPtr UsersConfigAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const
{
return memory_storage.subscribeForChanges(type, handler);
}
bool UsersConfigAccessStorage::hasSubscriptionImpl(const UUID & id) const
{
return memory_storage.hasSubscription(id);
}
bool UsersConfigAccessStorage::hasSubscriptionImpl(std::type_index type) const
{
return memory_storage.hasSubscription(type);
}
}

View File

@ -0,0 +1,42 @@
#pragma once
#include <Access/MemoryAccessStorage.h>
namespace Poco
{
namespace Util
{
class AbstractConfiguration;
}
}
namespace DB
{
/// Implementation of IAccessStorage which loads all from users.xml periodically.
class UsersConfigAccessStorage : public IAccessStorage
{
public:
UsersConfigAccessStorage();
~UsersConfigAccessStorage() override;
void loadFromConfig(const Poco::Util::AbstractConfiguration & config);
private:
std::optional<UUID> findImpl(std::type_index type, const String & name) const override;
std::vector<UUID> findAllImpl(std::type_index type) const override;
bool existsImpl(const UUID & id) const override;
AccessEntityPtr readImpl(const UUID & id) const override;
String readNameImpl(const UUID & id) const override;
UUID insertImpl(const AccessEntityPtr & entity, bool replace_if_exists) override;
void removeImpl(const UUID & id) override;
void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
SubscriptionPtr subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
SubscriptionPtr subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override;
bool hasSubscriptionImpl(const UUID & id) const override;
bool hasSubscriptionImpl(std::type_index type) const override;
MemoryAccessStorage memory_storage;
};
}

View File

@ -96,6 +96,7 @@ public:
void insertFrom(const IColumn & src, size_t n) override { data.push_back(static_cast<const Self &>(src).getData()[n]); }
void insertData(const char * pos, size_t /*length*/) override;
void insertDefault() override { data.push_back(T()); }
virtual void insertManyDefaults(size_t length) override { data.resize_fill(data.size() + length); }
void insert(const Field & x) override { data.push_back(DB::get<NearestFieldType<T>>(x)); }
void insertRangeFrom(const IColumn & src, size_t start, size_t length) override;

View File

@ -92,6 +92,11 @@ public:
chars.resize_fill(chars.size() + n);
}
virtual void insertManyDefaults(size_t length) override
{
chars.resize_fill(chars.size() + n * length);
}
void popBack(size_t elems) override
{
chars.resize_assume_reserved(chars.size() - n * elems);

View File

@ -205,6 +205,13 @@ public:
offsets.push_back(offsets.back() + 1);
}
virtual void insertManyDefaults(size_t length) override
{
chars.resize_fill(chars.size() + length);
for (size_t i = 0; i < length; ++i)
offsets.push_back(offsets.back() + 1);
}
int compareAt(size_t n, size_t m, const IColumn & rhs_, int /*nan_direction_hint*/) const override
{
const ColumnString & rhs = assert_cast<const ColumnString &>(rhs_);

View File

@ -144,6 +144,11 @@ public:
data.push_back(T());
}
virtual void insertManyDefaults(size_t length) override
{
data.resize_fill(data.size() + length, T());
}
void popBack(size_t n) override
{
data.resize_assume_reserved(data.size() - n);

View File

@ -465,6 +465,14 @@ namespace ErrorCodes
extern const int UNKNOWN_DICTIONARY = 488;
extern const int INCORRECT_DICTIONARY_DEFINITION = 489;
extern const int CANNOT_FORMAT_DATETIME = 490;
extern const int UNACCEPTABLE_URL = 491;
extern const int ACCESS_ENTITY_NOT_FOUND = 492;
extern const int ACCESS_ENTITY_ALREADY_EXISTS = 493;
extern const int ACCESS_ENTITY_FOUND_DUPLICATES = 494;
extern const int ACCESS_ENTITY_STORAGE_READONLY = 495;
extern const int QUOTA_REQUIRES_CLIENT_KEY = 496;
extern const int NOT_ENOUGH_PRIVILEGES = 497;
extern const int LIMIT_BY_WITH_TIES_IS_NOT_SUPPORTED = 498;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;

View File

@ -0,0 +1,162 @@
#include <Common/IntervalKind.h>
#include <Common/Exception.h>
namespace DB
{
namespace ErrorCodes
{
extern const int SYNTAX_ERROR;
}
const char * IntervalKind::toString() const
{
switch (kind)
{
case IntervalKind::Second: return "Second";
case IntervalKind::Minute: return "Minute";
case IntervalKind::Hour: return "Hour";
case IntervalKind::Day: return "Day";
case IntervalKind::Week: return "Week";
case IntervalKind::Month: return "Month";
case IntervalKind::Quarter: return "Quarter";
case IntervalKind::Year: return "Year";
}
__builtin_unreachable();
}
Int32 IntervalKind::toAvgSeconds() const
{
switch (kind)
{
case IntervalKind::Second: return 1;
case IntervalKind::Minute: return 60;
case IntervalKind::Hour: return 3600;
case IntervalKind::Day: return 86400;
case IntervalKind::Week: return 604800;
case IntervalKind::Month: return 2629746; /// Exactly 1/12 of a year.
case IntervalKind::Quarter: return 7889238; /// Exactly 1/4 of a year.
case IntervalKind::Year: return 31556952; /// The average length of a Gregorian year is equal to 365.2425 days
}
__builtin_unreachable();
}
IntervalKind IntervalKind::fromAvgSeconds(Int64 num_seconds)
{
if (num_seconds)
{
if (!(num_seconds % 31556952))
return IntervalKind::Year;
if (!(num_seconds % 7889238))
return IntervalKind::Quarter;
if (!(num_seconds % 604800))
return IntervalKind::Week;
if (!(num_seconds % 2629746))
return IntervalKind::Month;
if (!(num_seconds % 86400))
return IntervalKind::Day;
if (!(num_seconds % 3600))
return IntervalKind::Hour;
if (!(num_seconds % 60))
return IntervalKind::Minute;
}
return IntervalKind::Second;
}
const char * IntervalKind::toKeyword() const
{
switch (kind)
{
case IntervalKind::Second: return "SECOND";
case IntervalKind::Minute: return "MINUTE";
case IntervalKind::Hour: return "HOUR";
case IntervalKind::Day: return "DAY";
case IntervalKind::Week: return "WEEK";
case IntervalKind::Month: return "MONTH";
case IntervalKind::Quarter: return "QUARTER";
case IntervalKind::Year: return "YEAR";
}
__builtin_unreachable();
}
const char * IntervalKind::toDateDiffUnit() const
{
switch (kind)
{
case IntervalKind::Second:
return "second";
case IntervalKind::Minute:
return "minute";
case IntervalKind::Hour:
return "hour";
case IntervalKind::Day:
return "day";
case IntervalKind::Week:
return "week";
case IntervalKind::Month:
return "month";
case IntervalKind::Quarter:
return "quarter";
case IntervalKind::Year:
return "year";
}
__builtin_unreachable();
}
const char * IntervalKind::toNameOfFunctionToIntervalDataType() const
{
switch (kind)
{
case IntervalKind::Second:
return "toIntervalSecond";
case IntervalKind::Minute:
return "toIntervalMinute";
case IntervalKind::Hour:
return "toIntervalHour";
case IntervalKind::Day:
return "toIntervalDay";
case IntervalKind::Week:
return "toIntervalWeek";
case IntervalKind::Month:
return "toIntervalMonth";
case IntervalKind::Quarter:
return "toIntervalQuarter";
case IntervalKind::Year:
return "toIntervalYear";
}
__builtin_unreachable();
}
const char * IntervalKind::toNameOfFunctionExtractTimePart() const
{
switch (kind)
{
case IntervalKind::Second:
return "toSecond";
case IntervalKind::Minute:
return "toMinute";
case IntervalKind::Hour:
return "toHour";
case IntervalKind::Day:
return "toDayOfMonth";
case IntervalKind::Week:
// TODO: SELECT toRelativeWeekNum(toDate('2017-06-15')) - toRelativeWeekNum(toStartOfYear(toDate('2017-06-15')))
// else if (ParserKeyword("WEEK").ignore(pos, expected))
// function_name = "toRelativeWeekNum";
throw Exception("The syntax 'EXTRACT(WEEK FROM date)' is not supported, cannot extract the number of a week", ErrorCodes::SYNTAX_ERROR);
case IntervalKind::Month:
return "toMonth";
case IntervalKind::Quarter:
return "toQuarter";
case IntervalKind::Year:
return "toYear";
}
__builtin_unreachable();
}
}

View File

@ -0,0 +1,54 @@
#pragma once
#include <Core/Types.h>
namespace DB
{
/// Kind of a temporal interval.
struct IntervalKind
{
enum Kind
{
Second,
Minute,
Hour,
Day,
Week,
Month,
Quarter,
Year,
};
Kind kind = Second;
IntervalKind(Kind kind_ = Second) : kind(kind_) {}
operator Kind() const { return kind; }
const char * toString() const;
/// Returns number of seconds in one interval.
/// For `Month`, `Quarter` and `Year` the function returns an average number of seconds.
Int32 toAvgSeconds() const;
/// Chooses an interval kind based on number of seconds.
/// For example, `IntervalKind::fromAvgSeconds(3600)` returns `IntervalKind::Hour`.
static IntervalKind fromAvgSeconds(Int64 num_seconds);
/// Returns an uppercased version of what `toString()` returns.
const char * toKeyword() const;
/// Returns the string which can be passed to the `unit` parameter of the dateDiff() function.
/// For example, `IntervalKind{IntervalKind::Day}.getDateDiffParameter()` returns "day".
const char * toDateDiffUnit() const;
/// Returns the name of the function converting a number to the interval data type.
/// For example, `IntervalKind{IntervalKind::Day}.getToIntervalDataTypeFunctionName()`
/// returns "toIntervalDay".
const char * toNameOfFunctionToIntervalDataType() const;
/// Returns the name of the function extracting time part from a date or a time.
/// For example, `IntervalKind{IntervalKind::Day}.getExtractTimePartFunctionName()`
/// returns "toDayOfMonth".
const char * toNameOfFunctionExtractTimePart() const;
};
}

View File

@ -0,0 +1,62 @@
#include <re2/re2.h>
#include <Common/RemoteHostFilter.h>
#include <Poco/URI.h>
#include <Formats/FormatFactory.h>
#include <Poco/Util/AbstractConfiguration.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/Exception.h>
#include <IO/WriteHelpers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int UNACCEPTABLE_URL;
}
void RemoteHostFilter::checkURL(const Poco::URI & uri) const
{
if (!checkForDirectEntry(uri.getHost()) &&
!checkForDirectEntry(uri.getHost() + ":" + toString(uri.getPort())))
throw Exception("URL \"" + uri.toString() + "\" is not allowed in config.xml", ErrorCodes::UNACCEPTABLE_URL);
}
void RemoteHostFilter::checkHostAndPort(const std::string & host, const std::string & port) const
{
if (!checkForDirectEntry(host) &&
!checkForDirectEntry(host + ":" + port))
throw Exception("URL \"" + host + ":" + port + "\" is not allowed in config.xml", ErrorCodes::UNACCEPTABLE_URL);
}
void RemoteHostFilter::setValuesFromConfig(const Poco::Util::AbstractConfiguration & config)
{
if (config.has("remote_url_allow_hosts"))
{
std::vector<std::string> keys;
config.keys("remote_url_allow_hosts", keys);
for (auto key : keys)
{
if (startsWith(key, "host_regexp"))
regexp_hosts.push_back(config.getString("remote_url_allow_hosts." + key));
else if (startsWith(key, "host"))
primary_hosts.insert(config.getString("remote_url_allow_hosts." + key));
}
}
}
bool RemoteHostFilter::checkForDirectEntry(const std::string & str) const
{
if (!primary_hosts.empty() || !regexp_hosts.empty())
{
if (primary_hosts.find(str) == primary_hosts.end())
{
for (size_t i = 0; i < regexp_hosts.size(); ++i)
if (re2::RE2::FullMatch(str, regexp_hosts[i]))
return true;
return false;
}
return true;
}
return true;
}
}

View File

@ -0,0 +1,30 @@
#pragma once
#include <vector>
#include <unordered_set>
#include <Poco/URI.h>
#include <Poco/Util/AbstractConfiguration.h>
namespace DB
{
class RemoteHostFilter
{
/**
* This class checks if url is allowed.
* If primary_hosts and regexp_hosts are empty all urls are allowed.
*/
public:
void checkURL(const Poco::URI & uri) const; /// If URL not allowed in config.xml throw UNACCEPTABLE_URL Exception
void setValuesFromConfig(const Poco::Util::AbstractConfiguration & config);
void checkHostAndPort(const std::string & host, const std::string & port) const; /// Does the same as checkURL, but for host and port.
private:
std::unordered_set<std::string> primary_hosts; /// Allowed primary (<host>) URL from config.xml
std::vector<std::string> regexp_hosts; /// Allowed regexp (<hots_regexp>) URL from config.xml
bool checkForDirectEntry(const std::string & str) const; /// Checks if the primary_hosts and regexp_hosts contain str. If primary_hosts and regexp_hosts are empty return true.
};
}

View File

@ -3,11 +3,18 @@
#include <cstdint>
#include <limits>
#include <Core/Defines.h>
// Also defined in Core/Defines.h
#if !defined(NO_SANITIZE_UNDEFINED)
#if defined(__clang__)
#define NO_SANITIZE_UNDEFINED __attribute__((__no_sanitize__("undefined")))
#else
#define NO_SANITIZE_UNDEFINED
#endif
#endif
/// On overlow, the function returns unspecified value.
inline NO_SANITIZE_UNDEFINED uint64_t intExp2(int x)
{
return 1ULL << x;

View File

@ -3,8 +3,10 @@
#include <type_traits>
#include <typeinfo>
#include <typeindex>
#include <memory>
#include <string>
#include <ext/shared_ptr_helper.h>
#include <Common/Exception.h>
#include <common/demangle.h>
@ -27,7 +29,7 @@ std::enable_if_t<std::is_reference_v<To>, To> typeid_cast(From & from)
{
try
{
if (typeid(from) == typeid(To))
if ((typeid(From) == typeid(To)) || (typeid(from) == typeid(To)))
return static_cast<To>(from);
}
catch (const std::exception & e)
@ -39,12 +41,13 @@ std::enable_if_t<std::is_reference_v<To>, To> typeid_cast(From & from)
DB::ErrorCodes::BAD_CAST);
}
template <typename To, typename From>
To typeid_cast(From * from)
std::enable_if_t<std::is_pointer_v<To>, To> typeid_cast(From * from)
{
try
{
if (typeid(*from) == typeid(std::remove_pointer_t<To>))
if ((typeid(From) == typeid(std::remove_pointer_t<To>)) || (typeid(*from) == typeid(std::remove_pointer_t<To>)))
return static_cast<To>(from);
else
return nullptr;
@ -54,3 +57,20 @@ To typeid_cast(From * from)
throw DB::Exception(e.what(), DB::ErrorCodes::BAD_CAST);
}
}
template <typename To, typename From>
std::enable_if_t<ext::is_shared_ptr_v<To>, To> typeid_cast(const std::shared_ptr<From> & from)
{
try
{
if ((typeid(From) == typeid(typename To::element_type)) || (typeid(*from) == typeid(typename To::element_type)))
return std::static_pointer_cast<typename To::element_type>(from);
else
return nullptr;
}
catch (const std::exception & e)
{
throw DB::Exception(e.what(), DB::ErrorCodes::BAD_CAST);
}
}

View File

@ -88,9 +88,9 @@ public:
Shift shift;
if (scale_a < scale_b)
shift.a = DataTypeDecimal<B>(maxDecimalPrecision<B>(), scale_b).getScaleMultiplier(scale_b - scale_a);
shift.a = B::getScaleMultiplier(scale_b - scale_a);
if (scale_a > scale_b)
shift.b = DataTypeDecimal<A>(maxDecimalPrecision<A>(), scale_a).getScaleMultiplier(scale_a - scale_b);
shift.b = A::getScaleMultiplier(scale_a - scale_b);
return applyWithScale(a, b, shift);
}

View File

@ -151,8 +151,8 @@
#endif
/// Marks that extra information is sent to a shard. It could be any magic numbers.
#define DBMS_DISTRIBUTED_SIGNATURE_EXTRA_INFO 0xCAFEDACEull
#define DBMS_DISTRIBUTED_SIGNATURE_SETTINGS_OLD_FORMAT 0xCAFECABEull
#define DBMS_DISTRIBUTED_SIGNATURE_HEADER 0xCAFEDACEull
#define DBMS_DISTRIBUTED_SIGNATURE_HEADER_OLD_FORMAT 0xCAFECABEull
#if !__has_include(<sanitizer/asan_interface.h>)
# define ASAN_UNPOISON_MEMORY_REGION(a, b)

View File

@ -300,21 +300,6 @@ namespace DB
}
template <> Decimal32 DecimalField<Decimal32>::getScaleMultiplier() const
{
return DataTypeDecimal<Decimal32>::getScaleMultiplier(scale);
}
template <> Decimal64 DecimalField<Decimal64>::getScaleMultiplier() const
{
return DataTypeDecimal<Decimal64>::getScaleMultiplier(scale);
}
template <> Decimal128 DecimalField<Decimal128>::getScaleMultiplier() const
{
return DataTypeDecimal<Decimal128>::getScaleMultiplier(scale);
}
template <typename T>
static bool decEqual(T x, T y, UInt32 x_scale, UInt32 y_scale)
{

View File

@ -102,7 +102,7 @@ public:
operator T() const { return dec; }
T getValue() const { return dec; }
T getScaleMultiplier() const;
T getScaleMultiplier() const { return T::getScaleMultiplier(scale); }
UInt32 getScale() const { return scale; }
template <typename U>

View File

@ -62,7 +62,7 @@ void SettingNumber<Type>::set(const Field & x)
template <typename Type>
void SettingNumber<Type>::set(const String & x)
{
set(parse<Type>(x));
set(completeParse<Type>(x));
}
template <>

View File

@ -4,6 +4,7 @@
#include <string>
#include <vector>
#include <common/Types.h>
#include <Common/intExp.h>
namespace DB
@ -145,6 +146,8 @@ struct Decimal
const Decimal<T> & operator /= (const T & x) { value /= x; return *this; }
const Decimal<T> & operator %= (const T & x) { value %= x; return *this; }
static T getScaleMultiplier(UInt32 scale);
T value;
};
@ -170,6 +173,10 @@ template <> struct NativeType<Decimal32> { using Type = Int32; };
template <> struct NativeType<Decimal64> { using Type = Int64; };
template <> struct NativeType<Decimal128> { using Type = Int128; };
template <> inline Int32 Decimal32::getScaleMultiplier(UInt32 scale) { return common::exp10_i32(scale); }
template <> inline Int64 Decimal64::getScaleMultiplier(UInt32 scale) { return common::exp10_i64(scale); }
template <> inline Int128 Decimal128::getScaleMultiplier(UInt32 scale) { return common::exp10_i128(scale); }
inline const char * getTypeName(TypeIndex idx)
{
switch (idx)

View File

@ -2,7 +2,7 @@
#include <Core/Field.h>
#include <Interpreters/ProcessList.h>
#include <Interpreters/Quota.h>
#include <Access/QuotaContext.h>
#include <Common/CurrentThread.h>
#include <common/sleep.h>
@ -70,7 +70,7 @@ Block IBlockInputStream::read()
if (limits.mode == LIMITS_CURRENT && !limits.size_limits.check(info.rows, info.bytes, "result", ErrorCodes::TOO_MANY_ROWS_OR_BYTES))
limit_exceeded_need_break = true;
if (quota != nullptr)
if (quota)
checkQuota(res);
}
else
@ -240,12 +240,8 @@ void IBlockInputStream::checkQuota(Block & block)
case LIMITS_CURRENT:
{
time_t current_time = time(nullptr);
double total_elapsed = info.total_stopwatch.elapsedSeconds();
quota->checkAndAddResultRowsBytes(current_time, block.rows(), block.bytes());
quota->checkAndAddExecutionTime(current_time, Poco::Timespan((total_elapsed - prev_elapsed) * 1000000.0));
UInt64 total_elapsed = info.total_stopwatch.elapsedNanoseconds();
quota->used({Quota::RESULT_ROWS, block.rows()}, {Quota::RESULT_BYTES, block.bytes()}, {Quota::EXECUTION_TIME, total_elapsed - prev_elapsed});
prev_elapsed = total_elapsed;
break;
}
@ -291,10 +287,8 @@ void IBlockInputStream::progressImpl(const Progress & value)
limits.speed_limits.throttle(progress.read_rows, progress.read_bytes, total_rows, total_elapsed_microseconds);
if (quota != nullptr && limits.mode == LIMITS_TOTAL)
{
quota->checkAndAddReadRowsBytes(time(nullptr), value.read_rows, value.read_bytes);
}
if (quota && limits.mode == LIMITS_TOTAL)
quota->used({Quota::READ_ROWS, value.read_rows}, {Quota::READ_BYTES, value.read_bytes});
}
}

View File

@ -23,7 +23,7 @@ namespace ErrorCodes
}
class ProcessListElement;
class QuotaForIntervals;
class QuotaContext;
class QueryStatus;
struct SortColumnDescription;
using SortDescription = std::vector<SortColumnDescription>;
@ -220,9 +220,9 @@ public:
/** Set the quota. If you set a quota on the amount of raw data,
* then you should also set mode = LIMITS_TOTAL to LocalLimits with setLimits.
*/
virtual void setQuota(QuotaForIntervals & quota_)
virtual void setQuota(const std::shared_ptr<QuotaContext> & quota_)
{
quota = &quota_;
quota = quota_;
}
/// Enable calculation of minimums and maximums by the result columns.
@ -273,8 +273,8 @@ private:
LocalLimits limits;
QuotaForIntervals * quota = nullptr; /// If nullptr - the quota is not used.
double prev_elapsed = 0;
std::shared_ptr<QuotaContext> quota; /// If nullptr - the quota is not used.
UInt64 prev_elapsed = 0;
/// The approximate total number of rows to read. For progress bar.
size_t total_rows_approx = 0;

View File

@ -1,5 +1,4 @@
#include <DataStreams/ParallelParsingBlockInputStream.h>
#include "ParallelParsingBlockInputStream.h"
namespace DB
{
@ -15,7 +14,7 @@ void ParallelParsingBlockInputStream::segmentatorThreadFunction()
auto & unit = processing_units[current_unit_number];
{
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
segmentator_condvar.wait(lock,
[&]{ return unit.status == READY_TO_INSERT || finished; });
}
@ -85,7 +84,7 @@ void ParallelParsingBlockInputStream::parserThreadFunction(size_t current_unit_n
// except at the end of file. Also see a matching assert in readImpl().
assert(unit.is_last || unit.block_ext.block.size() > 0);
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
unit.status = READY_TO_READ;
reader_condvar.notify_all();
}
@ -99,7 +98,7 @@ void ParallelParsingBlockInputStream::onBackgroundException()
{
tryLogCurrentException(__PRETTY_FUNCTION__);
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
if (!background_exception)
{
background_exception = std::current_exception();
@ -116,7 +115,7 @@ Block ParallelParsingBlockInputStream::readImpl()
/**
* Check for background exception and rethrow it before we return.
*/
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
if (background_exception)
{
lock.unlock();
@ -134,7 +133,7 @@ Block ParallelParsingBlockInputStream::readImpl()
{
// We have read out all the Blocks from the previous Processing Unit,
// wait for the current one to become ready.
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
reader_condvar.wait(lock, [&](){ return unit.status == READY_TO_READ || finished; });
if (finished)
@ -190,7 +189,7 @@ Block ParallelParsingBlockInputStream::readImpl()
else
{
// Pass the unit back to the segmentator.
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
unit.status = READY_TO_INSERT;
segmentator_condvar.notify_all();
}

View File

@ -227,7 +227,7 @@ private:
finished = true;
{
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
segmentator_condvar.notify_all();
reader_condvar.notify_all();
}
@ -255,4 +255,4 @@ private:
void onBackgroundException();
};
};
}

View File

@ -78,7 +78,9 @@ SummingSortedBlockInputStream::SummingSortedBlockInputStream(
else
{
bool is_agg_func = WhichDataType(column.type).isAggregateFunction();
if (!column.type->isSummable() && !is_agg_func)
/// There are special const columns for example after prewere sections.
if ((!column.type->isSummable() && !is_agg_func) || isColumnConst(*column.column))
{
column_numbers_not_to_aggregate.push_back(i);
continue;
@ -198,6 +200,10 @@ SummingSortedBlockInputStream::SummingSortedBlockInputStream(
void SummingSortedBlockInputStream::insertCurrentRowIfNeeded(MutableColumns & merged_columns)
{
/// We have nothing to aggregate. It means that it could be non-zero, because we have columns_not_to_aggregate.
if (columns_to_aggregate.empty())
current_row_is_zero = false;
for (auto & desc : columns_to_aggregate)
{
// Do not insert if the aggregation state hasn't been created

View File

@ -13,14 +13,14 @@ bool DataTypeInterval::equals(const IDataType & rhs) const
void registerDataTypeInterval(DataTypeFactory & factory)
{
factory.registerSimpleDataType("IntervalSecond", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Second)); });
factory.registerSimpleDataType("IntervalMinute", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Minute)); });
factory.registerSimpleDataType("IntervalHour", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Hour)); });
factory.registerSimpleDataType("IntervalDay", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Day)); });
factory.registerSimpleDataType("IntervalWeek", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Week)); });
factory.registerSimpleDataType("IntervalMonth", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Month)); });
factory.registerSimpleDataType("IntervalQuarter", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Quarter)); });
factory.registerSimpleDataType("IntervalYear", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(DataTypeInterval::Year)); });
factory.registerSimpleDataType("IntervalSecond", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Second)); });
factory.registerSimpleDataType("IntervalMinute", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Minute)); });
factory.registerSimpleDataType("IntervalHour", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Hour)); });
factory.registerSimpleDataType("IntervalDay", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Day)); });
factory.registerSimpleDataType("IntervalWeek", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Week)); });
factory.registerSimpleDataType("IntervalMonth", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Month)); });
factory.registerSimpleDataType("IntervalQuarter", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Quarter)); });
factory.registerSimpleDataType("IntervalYear", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Year)); });
}
}

View File

@ -1,6 +1,7 @@
#pragma once
#include <DataTypes/DataTypeNumberBase.h>
#include <Common/IntervalKind.h>
namespace DB
@ -16,47 +17,17 @@ namespace DB
*/
class DataTypeInterval final : public DataTypeNumberBase<Int64>
{
public:
enum Kind
{
Second,
Minute,
Hour,
Day,
Week,
Month,
Quarter,
Year
};
private:
Kind kind;
IntervalKind kind;
public:
static constexpr bool is_parametric = true;
Kind getKind() const { return kind; }
IntervalKind getKind() const { return kind; }
const char * kindToString() const
{
switch (kind)
{
case Second: return "Second";
case Minute: return "Minute";
case Hour: return "Hour";
case Day: return "Day";
case Week: return "Week";
case Month: return "Month";
case Quarter: return "Quarter";
case Year: return "Year";
}
DataTypeInterval(IntervalKind kind_) : kind(kind_) {}
__builtin_unreachable();
}
DataTypeInterval(Kind kind_) : kind(kind_) {}
std::string doGetName() const override { return std::string("Interval") + kindToString(); }
std::string doGetName() const override { return std::string("Interval") + kind.toString(); }
const char * getFamilyName() const override { return "Interval"; }
TypeIndex getTypeId() const override { return TypeIndex::Interval; }

View File

@ -58,7 +58,7 @@ bool DataTypeDecimal<T>::tryReadText(T & x, ReadBuffer & istr, UInt32 precision,
{
UInt32 unread_scale = scale;
bool done = tryReadDecimalText(istr, x, precision, unread_scale);
x *= getScaleMultiplier(unread_scale);
x *= T::getScaleMultiplier(unread_scale);
return done;
}
@ -70,7 +70,7 @@ void DataTypeDecimal<T>::readText(T & x, ReadBuffer & istr, UInt32 precision, UI
readCSVDecimalText(istr, x, precision, unread_scale);
else
readDecimalText(istr, x, precision, unread_scale);
x *= getScaleMultiplier(unread_scale);
x *= T::getScaleMultiplier(unread_scale);
}
template <typename T>
@ -96,7 +96,7 @@ T DataTypeDecimal<T>::parseFromString(const String & str) const
T x;
UInt32 unread_scale = scale;
readDecimalText(buf, x, precision, unread_scale, true);
x *= getScaleMultiplier(unread_scale);
x *= T::getScaleMultiplier(unread_scale);
return x;
}
@ -271,25 +271,6 @@ void registerDataTypeDecimal(DataTypeFactory & factory)
}
template <>
Decimal32 DataTypeDecimal<Decimal32>::getScaleMultiplier(UInt32 scale_)
{
return decimalScaleMultiplier<Int32>(scale_);
}
template <>
Decimal64 DataTypeDecimal<Decimal64>::getScaleMultiplier(UInt32 scale_)
{
return decimalScaleMultiplier<Int64>(scale_);
}
template <>
Decimal128 DataTypeDecimal<Decimal128>::getScaleMultiplier(UInt32 scale_)
{
return decimalScaleMultiplier<Int128>(scale_);
}
/// Explicit template instantiations.
template class DataTypeDecimal<Decimal32>;
template class DataTypeDecimal<Decimal64>;

View File

@ -130,7 +130,7 @@ public:
UInt32 getPrecision() const { return precision; }
UInt32 getScale() const { return scale; }
T getScaleMultiplier() const { return getScaleMultiplier(scale); }
T getScaleMultiplier() const { return T::getScaleMultiplier(scale); }
T wholePart(T x) const
{
@ -148,7 +148,7 @@ public:
return x % getScaleMultiplier();
}
T maxWholeValue() const { return getScaleMultiplier(maxPrecision() - scale) - T(1); }
T maxWholeValue() const { return T::getScaleMultiplier(maxPrecision() - scale) - T(1); }
bool canStoreWhole(T x) const
{
@ -165,7 +165,7 @@ public:
if (getScale() < x.getScale())
throw Exception("Decimal result's scale is less then argiment's one", ErrorCodes::ARGUMENT_OUT_OF_BOUND);
UInt32 scale_delta = getScale() - x.getScale(); /// scale_delta >= 0
return getScaleMultiplier(scale_delta);
return T::getScaleMultiplier(scale_delta);
}
template <typename U>
@ -181,7 +181,6 @@ public:
void readText(T & x, ReadBuffer & istr, bool csv = false) const { readText(x, istr, precision, scale, csv); }
static void readText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale, bool csv = false);
static bool tryReadText(T & x, ReadBuffer & istr, UInt32 precision, UInt32 scale);
static T getScaleMultiplier(UInt32 scale);
private:
const UInt32 precision;
@ -264,12 +263,12 @@ convertDecimals(const typename FromDataType::FieldType & value, UInt32 scale_fro
MaxNativeType converted_value;
if (scale_to > scale_from)
{
converted_value = DataTypeDecimal<MaxFieldType>::getScaleMultiplier(scale_to - scale_from);
converted_value = MaxFieldType::getScaleMultiplier(scale_to - scale_from);
if (common::mulOverflow(static_cast<MaxNativeType>(value), converted_value, converted_value))
throw Exception("Decimal convert overflow", ErrorCodes::DECIMAL_OVERFLOW);
}
else
converted_value = value / DataTypeDecimal<MaxFieldType>::getScaleMultiplier(scale_from - scale_to);
converted_value = value / MaxFieldType::getScaleMultiplier(scale_from - scale_to);
if constexpr (sizeof(FromFieldType) > sizeof(ToFieldType))
{
@ -289,7 +288,7 @@ convertFromDecimal(const typename FromDataType::FieldType & value, UInt32 scale)
using ToFieldType = typename ToDataType::FieldType;
if constexpr (std::is_floating_point_v<ToFieldType>)
return static_cast<ToFieldType>(value) / FromDataType::getScaleMultiplier(scale);
return static_cast<ToFieldType>(value) / FromFieldType::getScaleMultiplier(scale);
else
{
FromFieldType converted_value = convertDecimals<FromDataType, FromDataType>(value, scale, 0);
@ -320,14 +319,15 @@ inline std::enable_if_t<IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDa
convertToDecimal(const typename FromDataType::FieldType & value, UInt32 scale)
{
using FromFieldType = typename FromDataType::FieldType;
using ToNativeType = typename ToDataType::FieldType::NativeType;
using ToFieldType = typename ToDataType::FieldType;
using ToNativeType = typename ToFieldType::NativeType;
if constexpr (std::is_floating_point_v<FromFieldType>)
{
if (!std::isfinite(value))
throw Exception("Decimal convert overflow. Cannot convert infinity or NaN to decimal", ErrorCodes::DECIMAL_OVERFLOW);
auto out = value * ToDataType::getScaleMultiplier(scale);
auto out = value * ToFieldType::getScaleMultiplier(scale);
if constexpr (std::is_same_v<ToNativeType, Int128>)
{
static constexpr __int128 min_int128 = __int128(0x8000000000000000ll) << 64;

View File

@ -3,8 +3,8 @@
#include <Columns/ColumnsNumber.h>
#include <Common/ProfilingScopedRWLock.h>
#include <Common/typeid_cast.h>
#include <common/DateLUT.h>
#include <DataStreams/IBlockInputStream.h>
#include <ext/chrono_io.h>
#include <ext/map.h>
#include <ext/range.h>
#include <ext/size.h>
@ -334,7 +334,7 @@ void CacheDictionary::update(
backoff_end_time = now + std::chrono::seconds(calculateDurationWithBackoff(rnd_engine, error_count));
tryLogException(last_exception, log, "Could not update cache dictionary '" + getName() +
"', next update is scheduled at " + DateLUT::instance().timeToString(std::chrono::system_clock::to_time_t(backoff_end_time)));
"', next update is scheduled at " + ext::to_string(backoff_end_time));
}
}

View File

@ -281,6 +281,8 @@ void registerInputFormatProcessorTSKV(FormatFactory & factory);
void registerOutputFormatProcessorTSKV(FormatFactory & factory);
void registerInputFormatProcessorJSONEachRow(FormatFactory & factory);
void registerOutputFormatProcessorJSONEachRow(FormatFactory & factory);
void registerInputFormatProcessorJSONCompactEachRow(FormatFactory & factory);
void registerOutputFormatProcessorJSONCompactEachRow(FormatFactory & factory);
void registerInputFormatProcessorParquet(FormatFactory & factory);
void registerInputFormatProcessorORC(FormatFactory & factory);
void registerOutputFormatProcessorParquet(FormatFactory & factory);
@ -336,6 +338,8 @@ FormatFactory::FormatFactory()
registerOutputFormatProcessorTSKV(*this);
registerInputFormatProcessorJSONEachRow(*this);
registerOutputFormatProcessorJSONEachRow(*this);
registerInputFormatProcessorJSONCompactEachRow(*this);
registerOutputFormatProcessorJSONCompactEachRow(*this);
registerInputFormatProcessorProtobuf(*this);
registerOutputFormatProcessorProtobuf(*this);
registerInputFormatProcessorCapnProto(*this);

View File

@ -508,7 +508,7 @@ class FunctionBinaryArithmetic : public IFunction
}
std::stringstream function_name;
function_name << (function_is_plus ? "add" : "subtract") << interval_data_type->kindToString() << 's';
function_name << (function_is_plus ? "add" : "subtract") << interval_data_type->getKind().toString() << 's';
return FunctionFactory::instance().get(function_name.str(), context);
}

View File

@ -735,7 +735,7 @@ struct NameToDecimal128 { static constexpr auto name = "toDecimal128"; };
struct NameToInterval ## INTERVAL_KIND \
{ \
static constexpr auto name = "toInterval" #INTERVAL_KIND; \
static constexpr int kind = DataTypeInterval::INTERVAL_KIND; \
static constexpr auto kind = IntervalKind::INTERVAL_KIND; \
};
DEFINE_NAME_TO_INTERVAL(Second)
@ -786,7 +786,7 @@ public:
if constexpr (std::is_same_v<ToDataType, DataTypeInterval>)
{
return std::make_shared<DataTypeInterval>(DataTypeInterval::Kind(Name::kind));
return std::make_shared<DataTypeInterval>(Name::kind);
}
else if constexpr (to_decimal)
{

View File

@ -0,0 +1,134 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Interpreters/Context.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeUUID.h>
#include <Access/QuotaContext.h>
#include <Core/Field.h>
namespace DB
{
class FunctionCurrentQuota : public IFunction
{
const String quota_name;
public:
static constexpr auto name = "currentQuota";
static FunctionPtr create(const Context & context)
{
return std::make_shared<FunctionCurrentQuota>(context.getQuota()->getUsageInfo().quota_name);
}
explicit FunctionCurrentQuota(const String & quota_name_) : quota_name{quota_name_}
{
}
String getName() const override
{
return name;
}
size_t getNumberOfArguments() const override
{
return 0;
}
DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override
{
return std::make_shared<DataTypeString>();
}
bool isDeterministic() const override { return false; }
void executeImpl(Block & block, const ColumnNumbers &, size_t result, size_t input_rows_count) override
{
block.getByPosition(result).column = DataTypeString().createColumnConst(input_rows_count, quota_name);
}
};
class FunctionCurrentQuotaId : public IFunction
{
const UUID quota_id;
public:
static constexpr auto name = "currentQuotaID";
static FunctionPtr create(const Context & context)
{
return std::make_shared<FunctionCurrentQuotaId>(context.getQuota()->getUsageInfo().quota_id);
}
explicit FunctionCurrentQuotaId(const UUID quota_id_) : quota_id{quota_id_}
{
}
String getName() const override
{
return name;
}
size_t getNumberOfArguments() const override
{
return 0;
}
DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override
{
return std::make_shared<DataTypeUUID>();
}
bool isDeterministic() const override { return false; }
void executeImpl(Block & block, const ColumnNumbers &, size_t result, size_t input_rows_count) override
{
block.getByPosition(result).column = DataTypeUUID().createColumnConst(input_rows_count, quota_id);
}
};
class FunctionCurrentQuotaKey : public IFunction
{
const String quota_key;
public:
static constexpr auto name = "currentQuotaKey";
static FunctionPtr create(const Context & context)
{
return std::make_shared<FunctionCurrentQuotaKey>(context.getQuota()->getUsageInfo().quota_key);
}
explicit FunctionCurrentQuotaKey(const String & quota_key_) : quota_key{quota_key_}
{
}
String getName() const override
{
return name;
}
size_t getNumberOfArguments() const override
{
return 0;
}
DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override
{
return std::make_shared<DataTypeString>();
}
bool isDeterministic() const override { return false; }
void executeImpl(Block & block, const ColumnNumbers &, size_t result, size_t input_rows_count) override
{
block.getByPosition(result).column = DataTypeString().createColumnConst(input_rows_count, quota_key);
}
};
void registerFunctionCurrentQuota(FunctionFactory & factory)
{
factory.registerFunction<FunctionCurrentQuota>();
factory.registerFunction<FunctionCurrentQuotaId>();
factory.registerFunction<FunctionCurrentQuotaKey>();
}
}

View File

@ -7,6 +7,7 @@ class FunctionFactory;
void registerFunctionCurrentDatabase(FunctionFactory &);
void registerFunctionCurrentUser(FunctionFactory &);
void registerFunctionCurrentQuota(FunctionFactory &);
void registerFunctionHostName(FunctionFactory &);
void registerFunctionFQDN(FunctionFactory &);
void registerFunctionVisibleWidth(FunctionFactory &);
@ -62,6 +63,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory)
{
registerFunctionCurrentDatabase(factory);
registerFunctionCurrentUser(factory);
registerFunctionCurrentQuota(factory);
registerFunctionHostName(factory);
registerFunctionFQDN(factory);
registerFunctionVisibleWidth(factory);

View File

@ -23,11 +23,11 @@ namespace
{
static constexpr auto function_name = "toStartOfInterval";
template <DataTypeInterval::Kind unit>
template <IntervalKind::Kind unit>
struct Transform;
template <>
struct Transform<DataTypeInterval::Year>
struct Transform<IntervalKind::Year>
{
static UInt16 execute(UInt16 d, UInt64 years, const DateLUTImpl & time_zone)
{
@ -41,7 +41,7 @@ namespace
};
template <>
struct Transform<DataTypeInterval::Quarter>
struct Transform<IntervalKind::Quarter>
{
static UInt16 execute(UInt16 d, UInt64 quarters, const DateLUTImpl & time_zone)
{
@ -55,7 +55,7 @@ namespace
};
template <>
struct Transform<DataTypeInterval::Month>
struct Transform<IntervalKind::Month>
{
static UInt16 execute(UInt16 d, UInt64 months, const DateLUTImpl & time_zone)
{
@ -69,7 +69,7 @@ namespace
};
template <>
struct Transform<DataTypeInterval::Week>
struct Transform<IntervalKind::Week>
{
static UInt16 execute(UInt16 d, UInt64 weeks, const DateLUTImpl & time_zone)
{
@ -83,7 +83,7 @@ namespace
};
template <>
struct Transform<DataTypeInterval::Day>
struct Transform<IntervalKind::Day>
{
static UInt32 execute(UInt16 d, UInt64 days, const DateLUTImpl & time_zone)
{
@ -97,7 +97,7 @@ namespace
};
template <>
struct Transform<DataTypeInterval::Hour>
struct Transform<IntervalKind::Hour>
{
static UInt32 execute(UInt16, UInt64, const DateLUTImpl &) { return dateIsNotSupported(function_name); }
@ -105,7 +105,7 @@ namespace
};
template <>
struct Transform<DataTypeInterval::Minute>
struct Transform<IntervalKind::Minute>
{
static UInt32 execute(UInt16, UInt64, const DateLUTImpl &) { return dateIsNotSupported(function_name); }
@ -116,7 +116,7 @@ namespace
};
template <>
struct Transform<DataTypeInterval::Second>
struct Transform<IntervalKind::Second>
{
static UInt32 execute(UInt16, UInt64, const DateLUTImpl &) { return dateIsNotSupported(function_name); }
@ -163,9 +163,9 @@ public:
"Illegal type " + arguments[1].type->getName() + " of argument of function " + getName()
+ ". Should be an interval of time",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
result_type_is_date = (interval_type->getKind() == DataTypeInterval::Year)
|| (interval_type->getKind() == DataTypeInterval::Quarter) || (interval_type->getKind() == DataTypeInterval::Month)
|| (interval_type->getKind() == DataTypeInterval::Week);
result_type_is_date = (interval_type->getKind() == IntervalKind::Year)
|| (interval_type->getKind() == IntervalKind::Quarter) || (interval_type->getKind() == IntervalKind::Month)
|| (interval_type->getKind() == IntervalKind::Week);
};
auto check_timezone_argument = [&]
@ -177,7 +177,7 @@ public:
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (first_argument_is_date && result_type_is_date)
throw Exception(
"The timezone argument of function " + getName() + " with interval type " + interval_type->kindToString()
"The timezone argument of function " + getName() + " with interval type " + interval_type->getKind().toString()
+ " is allowed only when the 1st argument has the type DateTime",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
};
@ -269,28 +269,28 @@ private:
switch (interval_type->getKind())
{
case DataTypeInterval::Second:
return execute<FromType, UInt32, DataTypeInterval::Second>(time_column, num_units, time_zone);
case DataTypeInterval::Minute:
return execute<FromType, UInt32, DataTypeInterval::Minute>(time_column, num_units, time_zone);
case DataTypeInterval::Hour:
return execute<FromType, UInt32, DataTypeInterval::Hour>(time_column, num_units, time_zone);
case DataTypeInterval::Day:
return execute<FromType, UInt32, DataTypeInterval::Day>(time_column, num_units, time_zone);
case DataTypeInterval::Week:
return execute<FromType, UInt16, DataTypeInterval::Week>(time_column, num_units, time_zone);
case DataTypeInterval::Month:
return execute<FromType, UInt16, DataTypeInterval::Month>(time_column, num_units, time_zone);
case DataTypeInterval::Quarter:
return execute<FromType, UInt16, DataTypeInterval::Quarter>(time_column, num_units, time_zone);
case DataTypeInterval::Year:
return execute<FromType, UInt16, DataTypeInterval::Year>(time_column, num_units, time_zone);
case IntervalKind::Second:
return execute<FromType, UInt32, IntervalKind::Second>(time_column, num_units, time_zone);
case IntervalKind::Minute:
return execute<FromType, UInt32, IntervalKind::Minute>(time_column, num_units, time_zone);
case IntervalKind::Hour:
return execute<FromType, UInt32, IntervalKind::Hour>(time_column, num_units, time_zone);
case IntervalKind::Day:
return execute<FromType, UInt32, IntervalKind::Day>(time_column, num_units, time_zone);
case IntervalKind::Week:
return execute<FromType, UInt16, IntervalKind::Week>(time_column, num_units, time_zone);
case IntervalKind::Month:
return execute<FromType, UInt16, IntervalKind::Month>(time_column, num_units, time_zone);
case IntervalKind::Quarter:
return execute<FromType, UInt16, IntervalKind::Quarter>(time_column, num_units, time_zone);
case IntervalKind::Year:
return execute<FromType, UInt16, IntervalKind::Year>(time_column, num_units, time_zone);
}
__builtin_unreachable();
}
template <typename FromType, typename ToType, DataTypeInterval::Kind unit>
template <typename FromType, typename ToType, IntervalKind::Kind unit>
ColumnPtr execute(const ColumnVector<FromType> & time_column, UInt64 num_units, const DateLUTImpl & time_zone)
{
const auto & time_data = time_column.getData();

View File

@ -14,10 +14,12 @@ const int DEFAULT_S3_MAX_FOLLOW_GET_REDIRECT = 2;
ReadBufferFromS3::ReadBufferFromS3(const Poco::URI & uri_,
const String & access_key_id_,
const String & secret_access_key_,
const ConnectionTimeouts & timeouts)
const ConnectionTimeouts & timeouts,
const RemoteHostFilter & remote_host_filter_)
: ReadBuffer(nullptr, 0)
, uri {uri_}
, session {makeHTTPSession(uri_, timeouts)}
, remote_host_filter {remote_host_filter_}
{
Poco::Net::HTTPResponse response;
std::unique_ptr<Poco::Net::HTTPRequest> request;
@ -52,6 +54,7 @@ ReadBufferFromS3::ReadBufferFromS3(const Poco::URI & uri_,
break;
uri = location_iterator->second;
remote_host_filter.checkURL(uri);
session = makeHTTPSession(uri, timeouts);
}

View File

@ -21,11 +21,14 @@ protected:
std::istream * istr; /// owned by session
std::unique_ptr<ReadBuffer> impl;
RemoteHostFilter remote_host_filter;
public:
explicit ReadBufferFromS3(const Poco::URI & uri_,
const String & access_key_id_,
const String & secret_access_key_,
const ConnectionTimeouts & timeouts = {});
const ConnectionTimeouts & timeouts = {},
const RemoteHostFilter & remote_host_filter_ = {});
bool nextImpl() override;
};

View File

@ -877,6 +877,30 @@ inline T parse(const char * data, size_t size)
return res;
}
/// Read something from text format, but expect complete parse of given text
/// For example: 723145 -- ok, 213MB -- not ok
template <typename T>
inline T completeParse(const char * data, size_t size)
{
T res;
ReadBufferFromMemory buf(data, size);
readText(res, buf);
assertEOF(buf);
return res;
}
template <typename T>
inline T completeParse(const String & s)
{
return completeParse<T>(s.data(), s.size());
}
template <typename T>
inline T completeParse(const char * data)
{
return completeParse<T>(data, strlen(data));
}
template <typename T>
inline T parse(const char * data)
{
@ -916,12 +940,12 @@ void skipToUnescapedNextLineOrEOF(ReadBuffer & buf);
template <class TReadBuffer, class... Types>
std::unique_ptr<ReadBuffer> getReadBuffer(const DB::CompressionMethod method, Types&&... args)
{
if (method == DB::CompressionMethod::Gzip)
{
auto read_buf = std::make_unique<TReadBuffer>(std::forward<Types>(args)...);
return std::make_unique<ZlibInflatingReadBuffer>(std::move(read_buf), method);
}
return std::make_unique<TReadBuffer>(args...);
if (method == DB::CompressionMethod::Gzip)
{
auto read_buf = std::make_unique<TReadBuffer>(std::forward<Types>(args)...);
return std::make_unique<ZlibInflatingReadBuffer>(std::move(read_buf), method);
}
return std::make_unique<TReadBuffer>(args...);
}
/** This function just copies the data from buffer's internal position (in.position())

View File

@ -101,6 +101,7 @@ namespace detail
const Poco::Net::HTTPBasicCredentials & credentials;
std::vector<Poco::Net::HTTPCookie> cookies;
HTTPHeaderEntries http_header_entries;
RemoteHostFilter remote_host_filter;
std::istream * call(const Poco::URI uri_, Poco::Net::HTTPResponse & response)
{
@ -157,7 +158,8 @@ namespace detail
OutStreamCallback out_stream_callback_ = {},
const Poco::Net::HTTPBasicCredentials & credentials_ = {},
size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE,
HTTPHeaderEntries http_header_entries_ = {})
HTTPHeaderEntries http_header_entries_ = {},
const RemoteHostFilter & remote_host_filter_ = {})
: ReadBuffer(nullptr, 0)
, uri {uri_}
, method {!method_.empty() ? method_ : out_stream_callback_ ? Poco::Net::HTTPRequest::HTTP_POST : Poco::Net::HTTPRequest::HTTP_GET}
@ -165,6 +167,7 @@ namespace detail
, out_stream_callback {out_stream_callback_}
, credentials {credentials_}
, http_header_entries {http_header_entries_}
, remote_host_filter {remote_host_filter_}
{
Poco::Net::HTTPResponse response;
@ -173,6 +176,7 @@ namespace detail
while (isRedirect(response.getStatus()))
{
Poco::URI uri_redirect(response.get("Location"));
remote_host_filter.checkURL(uri_redirect);
session->updateSession(uri_redirect);
@ -243,8 +247,9 @@ public:
const DB::SettingUInt64 max_redirects = 0,
const Poco::Net::HTTPBasicCredentials & credentials_ = {},
size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE,
const HTTPHeaderEntries & http_header_entries_ = {})
: Parent(std::make_shared<UpdatableSession>(uri_, timeouts, max_redirects), uri_, method_, out_stream_callback_, credentials_, buffer_size_, http_header_entries_)
const HTTPHeaderEntries & http_header_entries_ = {},
const RemoteHostFilter & remote_host_filter_ = {})
: Parent(std::make_shared<UpdatableSession>(uri_, timeouts, max_redirects), uri_, method_, out_stream_callback_, credentials_, buffer_size_, http_header_entries_, remote_host_filter_)
{
}
};

View File

@ -34,7 +34,8 @@ WriteBufferFromS3::WriteBufferFromS3(
const String & access_key_id_,
const String & secret_access_key_,
size_t minimum_upload_part_size_,
const ConnectionTimeouts & timeouts_)
const ConnectionTimeouts & timeouts_,
const RemoteHostFilter & remote_host_filter_)
: BufferWithOwnMemory<WriteBuffer>(DBMS_DEFAULT_BUFFER_SIZE, nullptr, 0)
, uri {uri_}
, access_key_id {access_key_id_}
@ -43,6 +44,7 @@ WriteBufferFromS3::WriteBufferFromS3(
, timeouts {timeouts_}
, temporary_buffer {std::make_unique<WriteBufferFromString>(buffer_string)}
, last_part_size {0}
, remote_host_filter(remote_host_filter_)
{
initiate();
@ -134,6 +136,7 @@ void WriteBufferFromS3::initiate()
break;
initiate_uri = location_iterator->second;
remote_host_filter.checkURL(initiate_uri);
}
assertResponseIsOk(*request_ptr, response, *istr);

View File

@ -28,6 +28,7 @@ private:
String buffer_string;
std::unique_ptr<WriteBufferFromString> temporary_buffer;
size_t last_part_size;
RemoteHostFilter remote_host_filter;
/// Upload in S3 is made in parts.
/// We initiate upload, then upload each part and get ETag as a response, and then finish upload with listing all our parts.
@ -39,7 +40,8 @@ public:
const String & access_key_id,
const String & secret_access_key,
size_t minimum_upload_part_size_,
const ConnectionTimeouts & timeouts = {});
const ConnectionTimeouts & timeouts = {},
const RemoteHostFilter & remote_host_filter_ = {});
void nextImpl() override;

View File

@ -18,7 +18,6 @@
#include <Common/Exception.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/UInt128.h>
#include <Common/intExp.h>
#include <IO/CompressionMethod.h>
#include <IO/WriteBuffer.h>
@ -764,12 +763,6 @@ inline void writeText(const LocalDateTime & x, WriteBuffer & buf) { writeDateTim
inline void writeText(const UUID & x, WriteBuffer & buf) { writeUUIDText(x, buf); }
inline void writeText(const UInt128 & x, WriteBuffer & buf) { writeText(UUID(x), buf); }
template <typename T> inline T decimalScaleMultiplier(UInt32 scale);
template <> inline Int32 decimalScaleMultiplier<Int32>(UInt32 scale) { return common::exp10_i32(scale); }
template <> inline Int64 decimalScaleMultiplier<Int64>(UInt32 scale) { return common::exp10_i64(scale); }
template <> inline Int128 decimalScaleMultiplier<Int128>(UInt32 scale) { return common::exp10_i128(scale); }
template <typename T>
void writeText(Decimal<T> value, UInt32 scale, WriteBuffer & ostr)
{
@ -781,7 +774,7 @@ void writeText(Decimal<T> value, UInt32 scale, WriteBuffer & ostr)
T whole_part = value;
if (scale)
whole_part = value / decimalScaleMultiplier<T>(scale);
whole_part = value / Decimal<T>::getScaleMultiplier(scale);
writeIntText(whole_part, ostr);
if (scale)

View File

@ -5,6 +5,7 @@
#include <Poco/Mutex.h>
#include <Poco/UUID.h>
#include <Poco/Net/IPAddress.h>
#include <Poco/Util/Application.h>
#include <Common/Macros.h>
#include <Common/escapeForFileName.h>
#include <Common/setThreadName.h>
@ -24,9 +25,11 @@
#include <TableFunctions/TableFunctionFactory.h>
#include <Interpreters/ActionLocksManager.h>
#include <Core/Settings.h>
#include <Access/AccessControlManager.h>
#include <Access/SettingsConstraints.h>
#include <Access/QuotaContext.h>
#include <Interpreters/ExpressionJIT.h>
#include <Interpreters/UsersManager.h>
#include <Interpreters/Quota.h>
#include <Dictionaries/Embedded/GeoDictionariesLoader.h>
#include <Interpreters/EmbeddedDictionaries.h>
#include <Interpreters/ExternalLoaderXMLConfigRepository.h>
@ -37,7 +40,6 @@
#include <Interpreters/ProcessList.h>
#include <Interpreters/Cluster.h>
#include <Interpreters/InterserverIOHandler.h>
#include <Access/SettingsConstraints.h>
#include <Interpreters/SystemLog.h>
#include <Interpreters/Context.h>
#include <Interpreters/DDLWorker.h>
@ -53,7 +55,7 @@
#include <Common/ShellCommand.h>
#include <Common/TraceCollector.h>
#include <common/logger_useful.h>
#include <Common/RemoteHostFilter.h>
namespace ProfileEvents
{
@ -91,6 +93,7 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
extern const int SCALAR_ALREADY_EXISTS;
extern const int UNKNOWN_SCALAR;
extern const int NOT_ENOUGH_PRIVILEGES;
}
@ -130,8 +133,8 @@ struct ContextShared
mutable std::optional<ExternalModelsLoader> external_models_loader;
String default_profile_name; /// Default profile name used for default values.
String system_profile_name; /// Profile used by system processes
AccessControlManager access_control_manager;
std::unique_ptr<UsersManager> users_manager; /// Known users.
Quotas quotas; /// Known quotas for resource use.
mutable UncompressedCachePtr uncompressed_cache; /// The cache of decompressed blocks.
mutable MarkCachePtr mark_cache; /// Cache of marks in compressed files.
ProcessList process_list; /// Executing queries at the moment.
@ -158,8 +161,9 @@ struct ContextShared
ActionLocksManagerPtr action_locks_manager; /// Set of storages' action lockers
std::optional<SystemLogs> system_logs; /// Used to log queries and operations on parts
std::unique_ptr<TraceCollector> trace_collector; /// Thread collecting traces from threads executing queries
RemoteHostFilter remote_host_filter; /// Allowed URL from config.xml
std::unique_ptr<TraceCollector> trace_collector; /// Thread collecting traces from threads executing queries
/// Named sessions. The user could specify session identifier to reuse settings and temporary tables in subsequent requests.
class SessionKeyHash
@ -325,7 +329,7 @@ Context & Context::operator=(const Context &) = default;
Context Context::createGlobal()
{
Context res;
res.quota = std::make_shared<QuotaForIntervals>();
res.quota = std::make_shared<QuotaContext>();
res.shared = std::make_shared<ContextShared>();
return res;
}
@ -584,12 +588,31 @@ const Poco::Util::AbstractConfiguration & Context::getConfigRef() const
return shared->config ? *shared->config : Poco::Util::Application::instance().config();
}
AccessControlManager & Context::getAccessControlManager()
{
auto lock = getLock();
return shared->access_control_manager;
}
const AccessControlManager & Context::getAccessControlManager() const
{
auto lock = getLock();
return shared->access_control_manager;
}
void Context::checkQuotaManagementIsAllowed()
{
if (!is_quota_management_allowed)
throw Exception(
"User " + client_info.current_user + " doesn't have enough privileges to manage quotas", ErrorCodes::NOT_ENOUGH_PRIVILEGES);
}
void Context::setUsersConfig(const ConfigurationPtr & config)
{
auto lock = getLock();
shared->users_config = config;
shared->access_control_manager.loadFromConfig(*shared->users_config);
shared->users_manager->loadFromConfig(*shared->users_config);
shared->quotas.loadFromConfig(*shared->users_config);
}
ConfigurationPtr Context::getUsersConfig()
@ -630,7 +653,8 @@ void Context::calculateUserSettings()
{
auto lock = getLock();
String profile = shared->users_manager->getUser(client_info.current_user)->profile;
auto user = getUser(client_info.current_user);
String profile = user->profile;
/// 1) Set default settings (hardcoded values)
/// NOTE: we ignore global_context settings (from which it is usually copied)
@ -645,6 +669,10 @@ void Context::calculateUserSettings()
/// 3) Apply settings from current user
setProfile(profile);
quota = getAccessControlManager().createQuotaContext(
client_info.current_user, client_info.current_address.host(), client_info.quota_key);
is_quota_management_allowed = user->is_quota_management_allowed;
}
@ -677,24 +705,9 @@ void Context::setUser(const String & name, const String & password, const Poco::
client_info.quota_key = quota_key;
calculateUserSettings();
setQuota(user_props->quota, quota_key, name, address.host());
}
void Context::setQuota(const String & name, const String & quota_key, const String & user_name, const Poco::Net::IPAddress & address)
{
auto lock = getLock();
quota = shared->quotas.get(name, quota_key, user_name, address);
}
QuotaForIntervals & Context::getQuota()
{
auto lock = getLock();
return *quota;
}
void Context::checkDatabaseAccessRights(const std::string & database_name) const
{
auto lock = getLock();
@ -1583,6 +1596,16 @@ String Context::getInterserverScheme() const
return shared->interserver_scheme;
}
void Context::setRemoteHostFilter(const Poco::Util::AbstractConfiguration & config)
{
shared->remote_host_filter.setValuesFromConfig(config);
}
const RemoteHostFilter & Context::getRemoteHostFilter() const
{
return shared->remote_host_filter;
}
UInt16 Context::getTCPPort() const
{
auto lock = getLock();

View File

@ -22,6 +22,7 @@
#include <mutex>
#include <optional>
#include <thread>
#include <Common/RemoteHostFilter.h>
namespace Poco
@ -43,7 +44,7 @@ namespace DB
struct ContextShared;
class Context;
class QuotaForIntervals;
class QuotaContext;
class EmbeddedDictionaries;
class ExternalDictionariesLoader;
class ExternalModelsLoader;
@ -76,7 +77,9 @@ class ActionLocksManager;
using ActionLocksManagerPtr = std::shared_ptr<ActionLocksManager>;
class ShellCommand;
class ICompressionCodec;
class AccessControlManager;
class SettingsConstraints;
class RemoteHostFilter;
class IOutputFormat;
using OutputFormatPtr = std::shared_ptr<IOutputFormat>;
@ -135,7 +138,8 @@ private:
InputInitializer input_initializer_callback;
InputBlocksReader input_blocks_reader;
std::shared_ptr<QuotaForIntervals> quota; /// Current quota. By default - empty quota, that have no limits.
std::shared_ptr<QuotaContext> quota; /// Current quota. By default - empty quota, that have no limits.
bool is_quota_management_allowed = false; /// Whether the current user is allowed to manage quotas via SQL commands.
String current_database;
Settings settings; /// Setting for query execution.
std::shared_ptr<const SettingsConstraints> settings_constraints;
@ -199,6 +203,11 @@ public:
void setConfig(const ConfigurationPtr & config);
const Poco::Util::AbstractConfiguration & getConfigRef() const;
AccessControlManager & getAccessControlManager();
const AccessControlManager & getAccessControlManager() const;
std::shared_ptr<QuotaContext> getQuota() const { return quota; }
void checkQuotaManagementIsAllowed();
/** Take the list of users, quotas and configuration profiles from this config.
* The list of users is completely replaced.
* The accumulated quota values are not reset if the quota is not deleted.
@ -238,9 +247,6 @@ public:
ClientInfo & getClientInfo() { return client_info; }
const ClientInfo & getClientInfo() const { return client_info; }
void setQuota(const String & name, const String & quota_key, const String & user_name, const Poco::Net::IPAddress & address);
QuotaForIntervals & getQuota();
void addDependency(const DatabaseAndTableName & from, const DatabaseAndTableName & where);
void removeDependency(const DatabaseAndTableName & from, const DatabaseAndTableName & where);
Dependencies getDependencies(const String & database_name, const String & table_name) const;
@ -354,6 +360,10 @@ public:
void setInterserverScheme(const String & scheme);
String getInterserverScheme() const;
/// Storage of allowed hosts from config.xml
void setRemoteHostFilter(const Poco::Util::AbstractConfiguration & config);
const RemoteHostFilter & getRemoteHostFilter() const;
/// The port that the server listens for executing SQL queries.
UInt16 getTCPPort() const;
@ -404,7 +414,6 @@ public:
const Settings & getSettingsRef() const { return settings; }
Settings & getSettingsRef() { return settings; }
void setProgressCallback(ProgressCallback callback);
/// Used in InterpreterSelectQuery to pass it to the IBlockInputStream.
ProgressCallback getProgressCallback() const;

View File

@ -167,7 +167,7 @@ private:
size_t canMoveEqualsToJoinOn(const ASTFunction & node)
{
if (!node.arguments)
throw Exception("Logical error: function requires argiment", ErrorCodes::LOGICAL_ERROR);
throw Exception("Logical error: function requires arguments", ErrorCodes::LOGICAL_ERROR);
if (node.arguments->children.size() != 2)
return false;

View File

@ -2,13 +2,13 @@
#include <mutex>
#include <pcg_random.hpp>
#include <common/DateLUT.h>
#include <Common/Config/AbstractConfigurationComparison.h>
#include <Common/Exception.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/ThreadPool.h>
#include <Common/randomSeed.h>
#include <Common/setThreadName.h>
#include <ext/chrono_io.h>
#include <ext/scope_guard.h>
@ -288,7 +288,7 @@ class ExternalLoader::LoadingDispatcher : private boost::noncopyable
public:
/// Called to load or reload an object.
using CreateObjectFunction = std::function<LoadablePtr(
const String & /* name */, const ObjectConfig & /* config */, bool config_changed, const LoadablePtr & /* previous_version */)>;
const String & /* name */, const ObjectConfig & /* config */, const LoadablePtr & /* previous_version */)>;
LoadingDispatcher(
const CreateObjectFunction & create_object_function_,
@ -560,8 +560,8 @@ public:
/// The function doesn't touch the objects which were never tried to load.
void reloadOutdated()
{
/// Iterate through all the objects and find loaded ones which should be checked if they were modified.
std::unordered_map<LoadablePtr, bool> is_modified_map;
/// Iterate through all the objects and find loaded ones which should be checked if they need update.
std::unordered_map<LoadablePtr, bool> should_update_map;
{
std::lock_guard lock{mutex};
TimePoint now = std::chrono::system_clock::now();
@ -569,22 +569,26 @@ public:
{
const auto & info = name_and_info.second;
if ((now >= info.next_update_time) && !info.loading() && info.loaded())
is_modified_map.emplace(info.object, true);
should_update_map.emplace(info.object, info.failedToReload());
}
}
/// Find out which of the loaded objects were modified.
/// We couldn't perform these checks while we were building `is_modified_map` because
/// We couldn't perform these checks while we were building `should_update_map` because
/// the `mutex` should be unlocked while we're calling the function object->isModified()
for (auto & [object, is_modified_flag] : is_modified_map)
for (auto & [object, should_update_flag] : should_update_map)
{
try
{
is_modified_flag = object->isModified();
/// Maybe alredy true, if we have an exception
if (!should_update_flag)
should_update_flag = object->isModified();
}
catch (...)
{
tryLogCurrentException(log, "Could not check if " + type_name + " '" + object->getName() + "' was modified");
/// Cannot check isModified, so update
should_update_flag = true;
}
}
@ -598,19 +602,18 @@ public:
{
if (info.loaded())
{
auto it = is_modified_map.find(info.object);
if (it == is_modified_map.end())
continue; /// Object has been just loaded (it wasn't loaded while we were building the map `is_modified_map`), so we don't have to reload it right now.
auto it = should_update_map.find(info.object);
if (it == should_update_map.end())
continue; /// Object has been just loaded (it wasn't loaded while we were building the map `should_update_map`), so we don't have to reload it right now.
bool is_modified_flag = it->second;
if (!is_modified_flag)
bool should_update_flag = it->second;
if (!should_update_flag)
{
/// Object wasn't modified so we only have to set `next_update_time`.
info.next_update_time = calculateNextUpdateTime(info.object, info.error_count);
continue;
}
/// Object was modified and should be reloaded.
/// Object was modified or it was failed to reload last time, so it should be reloaded.
startLoading(name, info);
}
else if (info.failed())
@ -633,6 +636,7 @@ private:
bool loading() const { return loading_id != 0; }
bool wasLoading() const { return loaded() || failed() || loading(); }
bool ready() const { return (loaded() || failed()) && !forced_to_reload; }
bool failedToReload() const { return loaded() && exception != nullptr; }
Status status() const
{
@ -787,14 +791,13 @@ private:
std::pair<LoadablePtr, std::exception_ptr> loadOneObject(
const String & name,
const ObjectConfig & config,
bool config_changed,
LoadablePtr previous_version)
{
LoadablePtr new_object;
std::exception_ptr new_exception;
try
{
new_object = create_object(name, config, config_changed, previous_version);
new_object = create_object(name, config, previous_version);
}
catch (...)
{
@ -874,8 +877,7 @@ private:
{
if (next_update_time == TimePoint::max())
return String();
return ", next update is scheduled at "
+ DateLUT::instance().timeToString(std::chrono::system_clock::to_time_t(next_update_time));
return ", next update is scheduled at " + ext::to_string(next_update_time);
};
if (previous_version)
tryLogException(new_exception, log, "Could not update " + type_name + " '" + name + "'"
@ -915,7 +917,8 @@ private:
/// Use `create_function` to perform the actual loading.
/// It's much better to do it with `mutex` unlocked because the loading can take a lot of time
/// and require access to other objects.
auto [new_object, new_exception] = loadOneObject(name, info->object_config, info->config_changed, info->object);
bool need_complete_loading = !info->object || info->config_changed || info->forced_to_reload;
auto [new_object, new_exception] = loadOneObject(name, info->object_config, need_complete_loading ? nullptr : info->object);
if (!new_object && !new_exception)
throw Exception("No object created and no exception raised for " + type_name, ErrorCodes::LOGICAL_ERROR);
@ -1072,7 +1075,7 @@ private:
ExternalLoader::ExternalLoader(const String & type_name_, Logger * log)
: config_files_reader(std::make_unique<LoadablesConfigReader>(type_name_, log))
, loading_dispatcher(std::make_unique<LoadingDispatcher>(
std::bind(&ExternalLoader::createObject, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4),
std::bind(&ExternalLoader::createObject, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3),
type_name_,
log))
, periodic_updater(std::make_unique<PeriodicUpdater>(*config_files_reader, *loading_dispatcher))
@ -1222,9 +1225,9 @@ void ExternalLoader::addObjectAndLoad(
ExternalLoader::LoadablePtr ExternalLoader::createObject(
const String & name, const ObjectConfig & config, bool config_changed, const LoadablePtr & previous_version) const
const String & name, const ObjectConfig & config, const LoadablePtr & previous_version) const
{
if (previous_version && !config_changed)
if (previous_version)
return previous_version->clone();
return create(name, *config.config, config.key_in_config);

View File

@ -27,7 +27,7 @@ struct ExternalLoaderConfigSettings
};
/** Iterface for manage user-defined objects.
/** Interface for manage user-defined objects.
* Monitors configuration file and automatically reloads objects in separate threads.
* The monitoring thread wakes up every 'check_period_sec' seconds and checks
* modification time of objects' configuration file. If said time is greater than
@ -175,7 +175,7 @@ protected:
private:
struct ObjectConfig;
LoadablePtr createObject(const String & name, const ObjectConfig & config, bool config_changed, const LoadablePtr & previous_version) const;
LoadablePtr createObject(const String & name, const ObjectConfig & config, const LoadablePtr & previous_version) const;
class LoadablesConfigReader;
std::unique_ptr<LoadablesConfigReader> config_files_reader;

View File

@ -22,6 +22,9 @@ public:
virtual bool canExecuteWithProcessors() const { return false; }
virtual bool ignoreQuota() const { return false; }
virtual bool ignoreLimits() const { return false; }
virtual ~IInterpreter() {}
};

View File

@ -0,0 +1,121 @@
#include <Interpreters/InterpreterCreateQuotaQuery.h>
#include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTRoleList.h>
#include <Interpreters/Context.h>
#include <Access/AccessControlManager.h>
#include <ext/range.h>
#include <boost/range/algorithm/find_if.hpp>
#include <boost/range/algorithm/upper_bound.hpp>
#include <boost/range/algorithm/sort.hpp>
namespace DB
{
BlockIO InterpreterCreateQuotaQuery::execute()
{
context.checkQuotaManagementIsAllowed();
const auto & query = query_ptr->as<const ASTCreateQuotaQuery &>();
auto & access_control = context.getAccessControlManager();
if (query.alter)
{
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{
auto updated_quota = typeid_cast<std::shared_ptr<Quota>>(entity->clone());
updateQuotaFromQuery(*updated_quota, query);
return updated_quota;
};
if (query.if_exists)
{
if (auto id = access_control.find<Quota>(query.name))
access_control.tryUpdate(*id, update_func);
}
else
access_control.update(access_control.getID<Quota>(query.name), update_func);
}
else
{
auto new_quota = std::make_shared<Quota>();
updateQuotaFromQuery(*new_quota, query);
if (query.if_not_exists)
access_control.tryInsert(new_quota);
else if (query.or_replace)
access_control.insertOrReplace(new_quota);
else
access_control.insert(new_quota);
}
return {};
}
void InterpreterCreateQuotaQuery::updateQuotaFromQuery(Quota & quota, const ASTCreateQuotaQuery & query)
{
if (query.alter)
{
if (!query.new_name.empty())
quota.setName(query.new_name);
}
else
quota.setName(query.name);
if (query.key_type)
quota.key_type = *query.key_type;
auto & quota_all_limits = quota.all_limits;
for (const auto & query_limits : query.all_limits)
{
auto duration = query_limits.duration;
auto it = boost::range::find_if(quota_all_limits, [&](const Quota::Limits & x) { return x.duration == duration; });
if (query_limits.unset_tracking)
{
if (it != quota_all_limits.end())
quota_all_limits.erase(it);
continue;
}
if (it == quota_all_limits.end())
{
/// We keep `all_limits` sorted by duration.
it = quota_all_limits.insert(
boost::range::upper_bound(
quota_all_limits,
duration,
[](const std::chrono::seconds & lhs, const Quota::Limits & rhs) { return lhs < rhs.duration; }),
Quota::Limits{});
it->duration = duration;
}
auto & quota_limits = *it;
quota_limits.randomize_interval = query_limits.randomize_interval;
for (auto resource_type : ext::range(Quota::MAX_RESOURCE_TYPE))
{
if (query_limits.max[resource_type])
quota_limits.max[resource_type] = *query_limits.max[resource_type];
}
}
if (query.roles)
{
const auto & query_roles = *query.roles;
/// We keep `roles` sorted.
quota.roles = query_roles.roles;
if (query_roles.current_user)
quota.roles.push_back(context.getClientInfo().current_user);
boost::range::sort(quota.roles);
quota.roles.erase(std::unique(quota.roles.begin(), quota.roles.end()), quota.roles.end());
quota.all_roles = query_roles.all_roles;
/// We keep `except_roles` sorted.
quota.except_roles = query_roles.except_roles;
if (query_roles.except_current_user)
quota.except_roles.push_back(context.getClientInfo().current_user);
boost::range::sort(quota.except_roles);
quota.except_roles.erase(std::unique(quota.except_roles.begin(), quota.except_roles.end()), quota.except_roles.end());
}
}
}

View File

@ -0,0 +1,29 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class ASTCreateQuotaQuery;
struct Quota;
class InterpreterCreateQuotaQuery : public IInterpreter
{
public:
InterpreterCreateQuotaQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
bool ignoreQuota() const override { return true; }
bool ignoreLimits() const override { return true; }
private:
void updateQuotaFromQuery(Quota & quota, const ASTCreateQuotaQuery & query);
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -0,0 +1,31 @@
#include <Interpreters/InterpreterDropAccessEntityQuery.h>
#include <Parsers/ASTDropAccessEntityQuery.h>
#include <Interpreters/Context.h>
#include <Access/AccessControlManager.h>
#include <Access/Quota.h>
namespace DB
{
BlockIO InterpreterDropAccessEntityQuery::execute()
{
const auto & query = query_ptr->as<const ASTDropAccessEntityQuery &>();
auto & access_control = context.getAccessControlManager();
using Kind = ASTDropAccessEntityQuery::Kind;
switch (query.kind)
{
case Kind::QUOTA:
{
context.checkQuotaManagementIsAllowed();
if (query.if_exists)
access_control.tryRemove(access_control.find<Quota>(query.names));
else
access_control.remove(access_control.getIDs<Quota>(query.names));
return {};
}
}
__builtin_unreachable();
}
}

View File

@ -0,0 +1,20 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class InterpreterDropAccessEntityQuery : public IInterpreter
{
public:
InterpreterDropAccessEntityQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
private:
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -1,6 +1,8 @@
#include <Parsers/ASTAlterQuery.h>
#include <Parsers/ASTCheckQuery.h>
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTDropAccessEntityQuery.h>
#include <Parsers/ASTDropQuery.h>
#include <Parsers/ASTInsertQuery.h>
#include <Parsers/ASTKillQueryQuery.h>
@ -9,7 +11,9 @@
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ASTSetQuery.h>
#include <Parsers/ASTShowCreateAccessEntityQuery.h>
#include <Parsers/ASTShowProcesslistQuery.h>
#include <Parsers/ASTShowQuotasQuery.h>
#include <Parsers/ASTShowTablesQuery.h>
#include <Parsers/ASTUseQuery.h>
#include <Parsers/ASTExplainQuery.h>
@ -19,8 +23,10 @@
#include <Interpreters/InterpreterAlterQuery.h>
#include <Interpreters/InterpreterCheckQuery.h>
#include <Interpreters/InterpreterCreateQuery.h>
#include <Interpreters/InterpreterCreateQuotaQuery.h>
#include <Interpreters/InterpreterDescribeQuery.h>
#include <Interpreters/InterpreterExplainQuery.h>
#include <Interpreters/InterpreterDropAccessEntityQuery.h>
#include <Interpreters/InterpreterDropQuery.h>
#include <Interpreters/InterpreterExistsQuery.h>
#include <Interpreters/InterpreterFactory.h>
@ -31,8 +37,10 @@
#include <Interpreters/InterpreterSelectQuery.h>
#include <Interpreters/InterpreterSelectWithUnionQuery.h>
#include <Interpreters/InterpreterSetQuery.h>
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
#include <Interpreters/InterpreterShowCreateQuery.h>
#include <Interpreters/InterpreterShowProcesslistQuery.h>
#include <Interpreters/InterpreterShowQuotasQuery.h>
#include <Interpreters/InterpreterShowTablesQuery.h>
#include <Interpreters/InterpreterSystemQuery.h>
#include <Interpreters/InterpreterUseQuery.h>
@ -187,6 +195,22 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
{
return std::make_unique<InterpreterWatchQuery>(query, context);
}
else if (query->as<ASTCreateQuotaQuery>())
{
return std::make_unique<InterpreterCreateQuotaQuery>(query, context);
}
else if (query->as<ASTDropAccessEntityQuery>())
{
return std::make_unique<InterpreterDropAccessEntityQuery>(query, context);
}
else if (query->as<ASTShowCreateAccessEntityQuery>())
{
return std::make_unique<InterpreterShowCreateAccessEntityQuery>(query, context);
}
else if (query->as<ASTShowQuotasQuery>())
{
return std::make_unique<InterpreterShowQuotasQuery>(query, context);
}
else
throw Exception("Unknown type of query: " + query->getID(), ErrorCodes::UNKNOWN_TYPE_OF_QUERY);
}

View File

@ -419,6 +419,17 @@ InterpreterSelectQuery::InterpreterSelectQuery(
/// null non-const columns to avoid useless memory allocations. However, a valid block sample
/// requires all columns to be of size 0, thus we need to sanitize the block here.
sanitizeBlock(result_header);
/// Remove limits for some tables in the `system` database.
if (storage && (storage->getDatabaseName() == "system"))
{
String table_name = storage->getTableName();
if ((table_name == "quotas") || (table_name == "quota_usage") || (table_name == "one"))
{
options.ignore_quota = true;
options.ignore_limits = true;
}
}
}
@ -1776,14 +1787,14 @@ void InterpreterSelectQuery::executeFetchColumns(
limits.speed_limits.timeout_before_checking_execution_speed = settings.timeout_before_checking_execution_speed;
}
QuotaForIntervals & quota = context->getQuota();
auto quota = context->getQuota();
for (auto & stream : streams)
{
if (!options.ignore_limits)
stream->setLimits(limits);
if (options.to_stage == QueryProcessingStage::Complete)
if (!options.ignore_quota && (options.to_stage == QueryProcessingStage::Complete))
stream->setQuota(quota);
}
@ -1793,7 +1804,7 @@ void InterpreterSelectQuery::executeFetchColumns(
if (!options.ignore_limits)
pipe.setLimits(limits);
if (options.to_stage == QueryProcessingStage::Complete)
if (!options.ignore_quota && (options.to_stage == QueryProcessingStage::Complete))
pipe.setQuota(quota);
}
}

View File

@ -74,6 +74,9 @@ public:
QueryPipeline executeWithProcessors() override;
bool canExecuteWithProcessors() const override { return true; }
bool ignoreLimits() const override { return options.ignore_limits; }
bool ignoreQuota() const override { return options.ignore_quota; }
Block getSampleBlock();
void ignoreWithTotals();
@ -260,7 +263,7 @@ private:
*/
void initSettings();
const SelectQueryOptions options;
SelectQueryOptions options;
ASTPtr query_ptr;
std::shared_ptr<Context> context;
SyntaxAnalyzerResultPtr syntax_analyzer_result;

View File

@ -107,6 +107,19 @@ InterpreterSelectWithUnionQuery::InterpreterSelectWithUnionQuery(
result_header = getCommonHeaderForUnion(headers);
}
/// InterpreterSelectWithUnionQuery ignores limits if all nested interpreters ignore limits.
bool all_nested_ignore_limits = true;
bool all_nested_ignore_quota = true;
for (auto & interpreter : nested_interpreters)
{
if (!interpreter->ignoreLimits())
all_nested_ignore_limits = false;
if (!interpreter->ignoreQuota())
all_nested_ignore_quota = false;
}
options.ignore_limits |= all_nested_ignore_limits;
options.ignore_quota |= all_nested_ignore_quota;
}

View File

@ -34,6 +34,9 @@ public:
QueryPipeline executeWithProcessors() override;
bool canExecuteWithProcessors() const override { return true; }
bool ignoreLimits() const override { return options.ignore_limits; }
bool ignoreQuota() const override { return options.ignore_quota; }
Block getSampleBlock();
static Block getSampleBlock(
@ -45,7 +48,7 @@ public:
ASTPtr getQuery() const { return query_ptr; }
private:
const SelectQueryOptions options;
SelectQueryOptions options;
ASTPtr query_ptr;
std::shared_ptr<Context> context;

View File

@ -0,0 +1,89 @@
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTShowCreateAccessEntityQuery.h>
#include <Parsers/ASTRoleList.h>
#include <Parsers/formatAST.h>
#include <Access/AccessControlManager.h>
#include <Access/QuotaContext.h>
#include <Columns/ColumnString.h>
#include <DataStreams/OneBlockInputStream.h>
#include <DataTypes/DataTypeString.h>
#include <ext/range.h>
#include <sstream>
namespace DB
{
BlockIO InterpreterShowCreateAccessEntityQuery::execute()
{
BlockIO res;
res.in = executeImpl();
return res;
}
BlockInputStreamPtr InterpreterShowCreateAccessEntityQuery::executeImpl()
{
const auto & show_query = query_ptr->as<ASTShowCreateAccessEntityQuery &>();
/// Build a create query.
ASTPtr create_query = getCreateQuotaQuery(show_query);
/// Build the result column.
std::stringstream create_query_ss;
formatAST(*create_query, create_query_ss, false, true);
String create_query_str = create_query_ss.str();
MutableColumnPtr column = ColumnString::create();
column->insert(create_query_str);
/// Prepare description of the result column.
std::stringstream desc_ss;
formatAST(show_query, desc_ss, false, true);
String desc = desc_ss.str();
String prefix = "SHOW ";
if (startsWith(desc, prefix))
desc = desc.substr(prefix.length()); /// `desc` always starts with "SHOW ", so we can trim this prefix.
return std::make_shared<OneBlockInputStream>(Block{{std::move(column), std::make_shared<DataTypeString>(), desc}});
}
ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuotaQuery(const ASTShowCreateAccessEntityQuery & show_query) const
{
auto & access_control = context.getAccessControlManager();
QuotaPtr quota;
if (show_query.current_quota)
quota = access_control.read<Quota>(context.getQuota()->getUsageInfo().quota_id);
else
quota = access_control.read<Quota>(show_query.name);
auto create_query = std::make_shared<ASTCreateQuotaQuery>();
create_query->name = quota->getName();
create_query->key_type = quota->key_type;
create_query->all_limits.reserve(quota->all_limits.size());
for (const auto & limits : quota->all_limits)
{
ASTCreateQuotaQuery::Limits create_query_limits;
create_query_limits.duration = limits.duration;
create_query_limits.randomize_interval = limits.randomize_interval;
for (auto resource_type : ext::range(Quota::MAX_RESOURCE_TYPE))
if (limits.max[resource_type])
create_query_limits.max[resource_type] = limits.max[resource_type];
create_query->all_limits.push_back(create_query_limits);
}
if (!quota->roles.empty() || quota->all_roles)
{
auto create_query_roles = std::make_shared<ASTRoleList>();
create_query_roles->roles = quota->roles;
create_query_roles->all_roles = quota->all_roles;
create_query_roles->except_roles = quota->except_roles;
create_query->roles = std::move(create_query_roles);
}
return create_query;
}
}

View File

@ -0,0 +1,35 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class Context;
class ASTShowCreateAccessEntityQuery;
/** Returns a single item containing a statement which could be used to create a specified role.
*/
class InterpreterShowCreateAccessEntityQuery : public IInterpreter
{
public:
InterpreterShowCreateAccessEntityQuery(const ASTPtr & query_ptr_, const Context & context_)
: query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
bool ignoreQuota() const override { return true; }
bool ignoreLimits() const override { return true; }
private:
ASTPtr query_ptr;
const Context & context;
BlockInputStreamPtr executeImpl();
ASTPtr getCreateQuotaQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
};
}

View File

@ -0,0 +1,73 @@
#include <Interpreters/InterpreterShowQuotasQuery.h>
#include <Interpreters/executeQuery.h>
#include <Parsers/ASTShowQuotasQuery.h>
#include <Parsers/formatAST.h>
#include <Access/Quota.h>
#include <Common/quoteString.h>
#include <Common/StringUtils/StringUtils.h>
#include <ext/range.h>
namespace DB
{
InterpreterShowQuotasQuery::InterpreterShowQuotasQuery(const ASTPtr & query_ptr_, Context & context_)
: query_ptr(query_ptr_), context(context_)
{
}
String InterpreterShowQuotasQuery::getRewrittenQuery()
{
const auto & query = query_ptr->as<ASTShowQuotasQuery &>();
/// Transform the query into some kind of "SELECT from system.quotas" query.
String expr;
String filter;
String table_name;
String order_by;
if (query.usage)
{
expr = "name || ' key=\\'' || key || '\\'' || if(isNull(end_of_interval), '', ' interval=[' || "
"toString(end_of_interval - duration) || ' .. ' || "
"toString(end_of_interval) || ']'";
for (auto resource_type : ext::range_with_static_cast<Quota::ResourceType>(Quota::MAX_RESOURCE_TYPE))
{
String column_name = Quota::resourceTypeToColumnName(resource_type);
expr += String{" || ' "} + column_name + "=' || toString(" + column_name + ")";
expr += String{" || if(max_"} + column_name + "=0, '', '/' || toString(max_" + column_name + "))";
}
expr += ")";
if (query.current)
filter = "(id = currentQuotaID()) AND (key = currentQuotaKey())";
table_name = "system.quota_usage";
order_by = "name, key, duration";
}
else
{
expr = "name";
table_name = "system.quotas";
order_by = "name";
}
/// Prepare description of the result column.
std::stringstream ss;
formatAST(query, ss, false, true);
String desc = ss.str();
String prefix = "SHOW ";
if (startsWith(desc, prefix))
desc = desc.substr(prefix.length()); /// `desc` always starts with "SHOW ", so we can trim this prefix.
/// Build a new query.
return "SELECT " + expr + " AS " + backQuote(desc) + " FROM " + table_name + (filter.empty() ? "" : (" WHERE " + filter))
+ (order_by.empty() ? "" : (" ORDER BY " + order_by));
}
BlockIO InterpreterShowQuotasQuery::execute()
{
return executeQuery(getRewrittenQuery(), context, true);
}
}

View File

@ -0,0 +1,28 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class Context;
class InterpreterShowQuotasQuery : public IInterpreter
{
public:
InterpreterShowQuotasQuery(const ASTPtr & query_ptr_, Context & context_);
BlockIO execute() override;
bool ignoreQuota() const override { return true; }
bool ignoreLimits() const override { return true; }
private:
ASTPtr query_ptr;
Context & context;
String getRewrittenQuery();
};
}

View File

@ -121,7 +121,9 @@ void startStopAction(Context & context, ASTSystemQuery & query, StorageActionBlo
InterpreterSystemQuery::InterpreterSystemQuery(const ASTPtr & query_ptr_, Context & context_)
: query_ptr(query_ptr_->clone()), context(context_), log(&Poco::Logger::get("InterpreterSystemQuery")) {}
: query_ptr(query_ptr_->clone()), context(context_), log(&Poco::Logger::get("InterpreterSystemQuery"))
{
}
BlockIO InterpreterSystemQuery::execute()

View File

@ -20,6 +20,9 @@ public:
BlockIO execute() override;
bool ignoreQuota() const override { return true; }
bool ignoreLimits() const override { return true; }
private:
ASTPtr query_ptr;
Context & context;

View File

@ -193,10 +193,10 @@ static const IColumn * extractAsofColumn(const ColumnRawPtrs & key_columns)
return key_columns.back();
}
template<typename KeyGetter, ASTTableJoin::Strictness STRICTNESS>
template<typename KeyGetter, bool is_asof_join>
static KeyGetter createKeyGetter(const ColumnRawPtrs & key_columns, const Sizes & key_sizes)
{
if constexpr (STRICTNESS == ASTTableJoin::Strictness::Asof)
if constexpr (is_asof_join)
{
auto key_column_copy = key_columns;
auto key_size_copy = key_sizes;
@ -360,28 +360,19 @@ void Join::setSampleBlock(const Block & block)
namespace
{
/// Inserting an element into a hash table of the form `key -> reference to a string`, which will then be used by JOIN.
template <ASTTableJoin::Strictness STRICTNESS, typename Map, typename KeyGetter>
template <typename Map, typename KeyGetter>
struct Inserter
{
static void insert(const Join &, Map & map, KeyGetter & key_getter, Block * stored_block, size_t i, Arena & pool);
};
template <typename Map, typename KeyGetter>
struct Inserter<ASTTableJoin::Strictness::Any, Map, KeyGetter>
{
static ALWAYS_INLINE void insert(const Join & join, Map & map, KeyGetter & key_getter, Block * stored_block, size_t i, Arena & pool)
static ALWAYS_INLINE void insertOne(const Join & join, Map & map, KeyGetter & key_getter, Block * stored_block, size_t i,
Arena & pool)
{
auto emplace_result = key_getter.emplaceKey(map, i, pool);
if (emplace_result.isInserted() || join.anyTakeLastRow())
new (&emplace_result.getMapped()) typename Map::mapped_type(stored_block, i);
}
};
template <typename Map, typename KeyGetter>
struct Inserter<ASTTableJoin::Strictness::All, Map, KeyGetter>
{
static ALWAYS_INLINE void insert(const Join &, Map & map, KeyGetter & key_getter, Block * stored_block, size_t i, Arena & pool)
static ALWAYS_INLINE void insertAll(const Join &, Map & map, KeyGetter & key_getter, Block * stored_block, size_t i, Arena & pool)
{
auto emplace_result = key_getter.emplaceKey(map, i, pool);
@ -393,13 +384,9 @@ namespace
emplace_result.getMapped().insert({stored_block, i}, pool);
}
}
};
template <typename Map, typename KeyGetter>
struct Inserter<ASTTableJoin::Strictness::Asof, Map, KeyGetter>
{
static ALWAYS_INLINE void insert(Join & join, Map & map, KeyGetter & key_getter, Block * stored_block, size_t i, Arena & pool,
const IColumn * asof_column)
static ALWAYS_INLINE void insertAsof(Join & join, Map & map, KeyGetter & key_getter, Block * stored_block, size_t i, Arena & pool,
const IColumn * asof_column)
{
auto emplace_result = key_getter.emplaceKey(map, i, pool);
typename Map::mapped_type * time_series_map = &emplace_result.getMapped();
@ -416,21 +403,27 @@ namespace
Join & join, Map & map, size_t rows, const ColumnRawPtrs & key_columns,
const Sizes & key_sizes, Block * stored_block, ConstNullMapPtr null_map, Arena & pool)
{
constexpr bool mapped_one = std::is_same_v<typename Map::mapped_type, JoinStuff::MappedOne> ||
std::is_same_v<typename Map::mapped_type, JoinStuff::MappedOneFlagged>;
constexpr bool is_asof_join = STRICTNESS == ASTTableJoin::Strictness::Asof;
const IColumn * asof_column [[maybe_unused]] = nullptr;
if constexpr (STRICTNESS == ASTTableJoin::Strictness::Asof)
if constexpr (is_asof_join)
asof_column = extractAsofColumn(key_columns);
auto key_getter = createKeyGetter<KeyGetter, STRICTNESS>(key_columns, key_sizes);
auto key_getter = createKeyGetter<KeyGetter, is_asof_join>(key_columns, key_sizes);
for (size_t i = 0; i < rows; ++i)
{
if (has_null_map && (*null_map)[i])
continue;
if constexpr (STRICTNESS == ASTTableJoin::Strictness::Asof)
Inserter<STRICTNESS, Map, KeyGetter>::insert(join, map, key_getter, stored_block, i, pool, asof_column);
if constexpr (is_asof_join)
Inserter<Map, KeyGetter>::insertAsof(join, map, key_getter, stored_block, i, pool, asof_column);
else if constexpr (mapped_one)
Inserter<Map, KeyGetter>::insertOne(join, map, key_getter, stored_block, i, pool);
else
Inserter<STRICTNESS, Map, KeyGetter>::insert(join, map, key_getter, stored_block, i, pool);
Inserter<Map, KeyGetter>::insertAll(join, map, key_getter, stored_block, i, pool);
}
}
@ -508,7 +501,7 @@ void Join::initRightBlockStructure()
JoinCommon::convertColumnsToNullable(saved_block_sample, (isFull(kind) ? right_table_keys.columns() : 0));
}
Block * Join::storeRightBlock(const Block & source_block)
Block Join::structureRightBlock(const Block & source_block) const
{
/// Rare case, when joined columns are constant. To avoid code bloat, simply materialize them.
Block block = materializeBlock(source_block);
@ -522,14 +515,11 @@ Block * Join::storeRightBlock(const Block & source_block)
structured_block.insert(column);
}
blocks.push_back(structured_block);
return &blocks.back();
return structured_block;
}
bool Join::addJoinedBlock(const Block & block)
{
std::unique_lock lock(rwlock);
if (empty())
throw Exception("Logical error: Join was not initialized", ErrorCodes::LOGICAL_ERROR);
@ -541,32 +531,45 @@ bool Join::addJoinedBlock(const Block & block)
ConstNullMapPtr null_map{};
ColumnPtr null_map_holder = extractNestedColumnsAndNullMap(key_columns, null_map);
size_t rows = block.rows();
if (rows)
has_no_rows_in_maps = false;
Block * stored_block = storeRightBlock(block);
if (kind != ASTTableJoin::Kind::Cross)
{
joinDispatch(kind, strictness, maps, [&](auto, auto strictness_, auto & map)
{
insertFromBlockImpl<strictness_>(*this, type, map, rows, key_columns, key_sizes, stored_block, null_map, pool);
});
}
/// If RIGHT or FULL save blocks with nulls for NonJoinedBlockInputStream
UInt8 save_nullmap = 0;
if (isRightOrFull(kind) && null_map)
{
UInt8 has_null = 0;
for (size_t i = 0; !has_null && i < null_map->size(); ++i)
has_null |= (*null_map)[i];
if (has_null)
blocks_nullmaps.emplace_back(stored_block, null_map_holder);
for (size_t i = 0; !save_nullmap && i < null_map->size(); ++i)
save_nullmap |= (*null_map)[i];
}
return table_join->sizeLimits().check(getTotalRowCount(), getTotalByteCount(), "JOIN", ErrorCodes::SET_SIZE_LIMIT_EXCEEDED);
Block structured_block = structureRightBlock(block);
size_t total_rows = 0;
size_t total_bytes = 0;
{
std::unique_lock lock(rwlock);
blocks.emplace_back(std::move(structured_block));
Block * stored_block = &blocks.back();
size_t rows = block.rows();
if (rows)
has_no_rows_in_maps = false;
if (kind != ASTTableJoin::Kind::Cross)
{
joinDispatch(kind, strictness, maps, [&](auto, auto strictness_, auto & map)
{
insertFromBlockImpl<strictness_>(*this, type, map, rows, key_columns, key_sizes, stored_block, null_map, pool);
});
}
if (save_nullmap)
blocks_nullmaps.emplace_back(stored_block, null_map_holder);
/// TODO: Do not calculate them every time
total_rows = getTotalRowCount();
total_bytes = getTotalByteCount();
}
return table_join->sizeLimits().check(total_rows, total_bytes, "JOIN", ErrorCodes::SET_SIZE_LIMIT_EXCEEDED);
}
@ -582,7 +585,15 @@ public:
const Block & block_with_columns_to_add,
const Block & block,
const Block & saved_block_sample,
const ColumnsWithTypeAndName & extras)
const ColumnsWithTypeAndName & extras,
const Join & join_,
const ColumnRawPtrs & key_columns_,
const Sizes & key_sizes_)
: join(join_)
, key_columns(key_columns_)
, key_sizes(key_sizes_)
, rows_to_add(block.rows())
, need_filter(false)
{
size_t num_columns_to_add = sample_block_with_columns_to_add.columns();
@ -613,23 +624,43 @@ public:
return ColumnWithTypeAndName(std::move(columns[i]), type_name[i].first, type_name[i].second);
}
template <bool has_defaults>
void appendFromBlock(const Block & block, size_t row_num)
{
if constexpr (has_defaults)
applyLazyDefaults();
for (size_t j = 0; j < right_indexes.size(); ++j)
columns[j]->insertFrom(*block.getByPosition(right_indexes[j]).column, row_num);
}
void appendDefaultRow()
{
for (size_t j = 0; j < right_indexes.size(); ++j)
columns[j]->insertDefault();
++lazy_defaults_count;
}
void applyLazyDefaults()
{
if (lazy_defaults_count)
{
for (size_t j = 0; j < right_indexes.size(); ++j)
columns[j]->insertManyDefaults(lazy_defaults_count);
lazy_defaults_count = 0;
}
}
const Join & join;
const ColumnRawPtrs & key_columns;
const Sizes & key_sizes;
size_t rows_to_add;
std::unique_ptr<IColumn::Offsets> offsets_to_replicate;
bool need_filter;
private:
TypeAndNames type_name;
MutableColumns columns;
std::vector<size_t> right_indexes;
size_t lazy_defaults_count = 0;
void addColumn(const ColumnWithTypeAndName & src_column)
{
@ -639,131 +670,190 @@ private:
}
};
template <ASTTableJoin::Strictness STRICTNESS, typename Map>
void addFoundRow(const typename Map::mapped_type & mapped, AddedColumns & added, IColumn::Offset & current_offset [[maybe_unused]])
template <typename Map, bool add_missing>
void addFoundRowAll(const typename Map::mapped_type & mapped, AddedColumns & added, IColumn::Offset & current_offset)
{
if constexpr (STRICTNESS == ASTTableJoin::Strictness::Any)
{
added.appendFromBlock(*mapped.block, mapped.row_num);
}
if constexpr (add_missing)
added.applyLazyDefaults();
if constexpr (STRICTNESS == ASTTableJoin::Strictness::All)
for (auto it = mapped.begin(); it.ok(); ++it)
{
for (auto it = mapped.begin(); it.ok(); ++it)
{
added.appendFromBlock(*it->block, it->row_num);
++current_offset;
}
added.appendFromBlock<false>(*it->block, it->row_num);
++current_offset;
}
};
template <bool _add_missing>
template <bool add_missing, bool need_offset>
void addNotFoundRow(AddedColumns & added [[maybe_unused]], IColumn::Offset & current_offset [[maybe_unused]])
{
if constexpr (_add_missing)
if constexpr (add_missing)
{
added.appendDefaultRow();
++current_offset;
if constexpr (need_offset)
++current_offset;
}
}
template <bool need_filter>
void setUsed(IColumn::Filter & filter [[maybe_unused]], size_t pos [[maybe_unused]])
{
if constexpr (need_filter)
filter[pos] = 1;
}
/// Joins right table columns which indexes are present in right_indexes using specified map.
/// Makes filter (1 if row presented in right table) and returns offsets to replicate (for ALL JOINS).
template <bool _add_missing, ASTTableJoin::Strictness STRICTNESS, typename KeyGetter, typename Map, bool _has_null_map>
std::unique_ptr<IColumn::Offsets> NO_INLINE joinRightIndexedColumns(
const Join & join, const Map & map, size_t rows, const ColumnRawPtrs & key_columns, const Sizes & key_sizes,
AddedColumns & added_columns, ConstNullMapPtr null_map, IColumn::Filter & filter)
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename KeyGetter, typename Map, bool need_filter, bool has_null_map>
NO_INLINE IColumn::Filter joinRightColumns(const Map & map, AddedColumns & added_columns, const ConstNullMapPtr & null_map [[maybe_unused]])
{
std::unique_ptr<IColumn::Offsets> offsets_to_replicate;
if constexpr (STRICTNESS == ASTTableJoin::Strictness::All)
offsets_to_replicate = std::make_unique<IColumn::Offsets>(rows);
constexpr bool is_any_join = STRICTNESS == ASTTableJoin::Strictness::Any;
constexpr bool is_all_join = STRICTNESS == ASTTableJoin::Strictness::All;
constexpr bool is_asof_join = STRICTNESS == ASTTableJoin::Strictness::Asof;
constexpr bool is_semi_join = STRICTNESS == ASTTableJoin::Strictness::Semi;
constexpr bool is_anti_join = STRICTNESS == ASTTableJoin::Strictness::Anti;
constexpr bool left = KIND == ASTTableJoin::Kind::Left;
constexpr bool right = KIND == ASTTableJoin::Kind::Right;
constexpr bool full = KIND == ASTTableJoin::Kind::Full;
constexpr bool add_missing = (left || full) && !is_semi_join;
constexpr bool need_replication = is_all_join || (is_any_join && right) || (is_semi_join && right);
size_t rows = added_columns.rows_to_add;
IColumn::Filter filter;
if constexpr (need_filter)
filter = IColumn::Filter(rows, 0);
Arena pool;
const IColumn * asof_column [[maybe_unused]] = nullptr;
if constexpr (STRICTNESS == ASTTableJoin::Strictness::Asof)
asof_column = extractAsofColumn(key_columns);
auto key_getter = createKeyGetter<KeyGetter, STRICTNESS>(key_columns, key_sizes);
if constexpr (need_replication)
added_columns.offsets_to_replicate = std::make_unique<IColumn::Offsets>(rows);
const IColumn * asof_column [[maybe_unused]] = nullptr;
if constexpr (is_asof_join)
asof_column = extractAsofColumn(added_columns.key_columns);
auto key_getter = createKeyGetter<KeyGetter, is_asof_join>(added_columns.key_columns, added_columns.key_sizes);
IColumn::Offset current_offset = 0;
for (size_t i = 0; i < rows; ++i)
{
if (_has_null_map && (*null_map)[i])
if constexpr (has_null_map)
{
addNotFoundRow<_add_missing>(added_columns, current_offset);
if ((*null_map)[i])
{
addNotFoundRow<add_missing, need_replication>(added_columns, current_offset);
if constexpr (need_replication)
(*added_columns.offsets_to_replicate)[i] = current_offset;
continue;
}
}
auto find_result = key_getter.findKey(map, i, pool);
if (find_result.isFound())
{
auto & mapped = find_result.getMapped();
if constexpr (is_asof_join)
{
const Join & join = added_columns.join;
if (const RowRef * found = mapped.findAsof(join.getAsofType(), join.getAsofInequality(), asof_column, i))
{
setUsed<need_filter>(filter, i);
mapped.setUsed();
added_columns.appendFromBlock<add_missing>(*found->block, found->row_num);
}
else
addNotFoundRow<add_missing, need_replication>(added_columns, current_offset);
}
else if constexpr (is_all_join)
{
setUsed<need_filter>(filter, i);
mapped.setUsed();
addFoundRowAll<Map, add_missing>(mapped, added_columns, current_offset);
}
else if constexpr ((is_any_join || is_semi_join) && right)
{
/// Use first appered left key + it needs left columns replication
if (mapped.setUsedOnce())
{
setUsed<need_filter>(filter, i);
addFoundRowAll<Map, add_missing>(mapped, added_columns, current_offset);
}
}
else if constexpr (is_any_join && KIND == ASTTableJoin::Kind::Inner)
{
/// Use first appered left key only
if (mapped.setUsedOnce())
{
setUsed<need_filter>(filter, i);
added_columns.appendFromBlock<add_missing>(*mapped.block, mapped.row_num);
}
}
else if constexpr (is_any_join && full)
{
/// TODO
}
else if constexpr (is_anti_join)
{
if constexpr (right)
mapped.setUsed();
}
else /// ANY LEFT, SEMI LEFT, old ANY (RightAny)
{
setUsed<need_filter>(filter, i);
mapped.setUsed();
added_columns.appendFromBlock<add_missing>(*mapped.block, mapped.row_num);
}
}
else
{
auto find_result = key_getter.findKey(map, i, pool);
if (find_result.isFound())
{
auto & mapped = find_result.getMapped();
if constexpr (STRICTNESS == ASTTableJoin::Strictness::Asof)
{
if (const RowRef * found = mapped.findAsof(join.getAsofType(), join.getAsofInequality(), asof_column, i))
{
filter[i] = 1;
mapped.setUsed();
added_columns.appendFromBlock(*found->block, found->row_num);
}
else
addNotFoundRow<_add_missing>(added_columns, current_offset);
}
else
{
filter[i] = 1;
mapped.setUsed();
addFoundRow<STRICTNESS, Map>(mapped, added_columns, current_offset);
}
}
else
addNotFoundRow<_add_missing>(added_columns, current_offset);
if constexpr (is_anti_join && left)
setUsed<need_filter>(filter, i);
addNotFoundRow<add_missing, need_replication>(added_columns, current_offset);
}
if constexpr (STRICTNESS == ASTTableJoin::Strictness::All)
(*offsets_to_replicate)[i] = current_offset;
if constexpr (need_replication)
(*added_columns.offsets_to_replicate)[i] = current_offset;
}
return offsets_to_replicate;
}
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename KeyGetter, typename Map>
IColumn::Filter joinRightColumns(
const Join & join, const Map & map, size_t rows, const ColumnRawPtrs & key_columns, const Sizes & key_sizes,
AddedColumns & added_columns, ConstNullMapPtr null_map, std::unique_ptr<IColumn::Offsets> & offsets_to_replicate)
{
constexpr bool left_or_full = static_in_v<KIND, ASTTableJoin::Kind::Left, ASTTableJoin::Kind::Full>;
IColumn::Filter filter(rows, 0);
if (null_map)
offsets_to_replicate = joinRightIndexedColumns<left_or_full, STRICTNESS, KeyGetter, Map, true>(
join, map, rows, key_columns, key_sizes, added_columns, null_map, filter);
else
offsets_to_replicate = joinRightIndexedColumns<left_or_full, STRICTNESS, KeyGetter, Map, false>(
join, map, rows, key_columns, key_sizes, added_columns, null_map, filter);
added_columns.applyLazyDefaults();
return filter;
}
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename KeyGetter, typename Map>
IColumn::Filter joinRightColumnsSwitchNullability(const Map & map, AddedColumns & added_columns, const ConstNullMapPtr & null_map)
{
if (added_columns.need_filter)
{
if (null_map)
return joinRightColumns<KIND, STRICTNESS, KeyGetter, Map, true, true>(map, added_columns, null_map);
else
return joinRightColumns<KIND, STRICTNESS, KeyGetter, Map, true, false>(map, added_columns, nullptr);
}
else
{
if (null_map)
return joinRightColumns<KIND, STRICTNESS, KeyGetter, Map, false, true>(map, added_columns, null_map);
else
return joinRightColumns<KIND, STRICTNESS, KeyGetter, Map, false, false>(map, added_columns, nullptr);
}
}
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Maps>
IColumn::Filter switchJoinRightColumns(
Join::Type type, const Join & join,
const Maps & maps_, size_t rows, const ColumnRawPtrs & key_columns, const Sizes & key_sizes,
AddedColumns & added_columns, ConstNullMapPtr null_map,
std::unique_ptr<IColumn::Offsets> & offsets_to_replicate)
IColumn::Filter switchJoinRightColumns(const Maps & maps_, AddedColumns & added_columns, Join::Type type, const ConstNullMapPtr & null_map)
{
switch (type)
{
#define M(TYPE) \
case Join::Type::TYPE: \
return joinRightColumns<KIND, STRICTNESS, typename KeyGetterForType<Join::Type::TYPE, const std::remove_reference_t<decltype(*maps_.TYPE)>>::Type>(\
join, *maps_.TYPE, rows, key_columns, key_sizes, added_columns, null_map, offsets_to_replicate);
return joinRightColumnsSwitchNullability<KIND, STRICTNESS,\
typename KeyGetterForType<Join::Type::TYPE, const std::remove_reference_t<decltype(*maps_.TYPE)>>::Type>(\
*maps_.TYPE, added_columns, null_map);\
break;
APPLY_FOR_JOIN_VARIANTS(M)
#undef M
@ -782,6 +872,20 @@ void Join::joinBlockImpl(
const Block & block_with_columns_to_add,
const Maps & maps_) const
{
constexpr bool is_any_join = STRICTNESS == ASTTableJoin::Strictness::Any;
constexpr bool is_all_join = STRICTNESS == ASTTableJoin::Strictness::All;
constexpr bool is_asof_join = STRICTNESS == ASTTableJoin::Strictness::Asof;
constexpr bool is_semi_join = STRICTNESS == ASTTableJoin::Strictness::Semi;
constexpr bool is_anti_join = STRICTNESS == ASTTableJoin::Strictness::Anti;
constexpr bool left = KIND == ASTTableJoin::Kind::Left;
constexpr bool right = KIND == ASTTableJoin::Kind::Right;
constexpr bool inner = KIND == ASTTableJoin::Kind::Inner;
constexpr bool full = KIND == ASTTableJoin::Kind::Full;
constexpr bool need_replication = is_all_join || (is_any_join && right) || (is_semi_join && right);
constexpr bool need_filter = !need_replication && (inner || right || (is_semi_join && left) || (is_anti_join && left));
/// Rare case, when keys are constant. To avoid code bloat, simply materialize them.
Columns materialized_columns;
ColumnRawPtrs key_columns = JoinCommon::temporaryMaterializeColumns(block, key_names_left, materialized_columns);
@ -796,8 +900,7 @@ void Join::joinBlockImpl(
* Because if they are constants, then in the "not joined" rows, they may have different values
* - default values, which can differ from the values of these constants.
*/
constexpr bool right_or_full = static_in_v<KIND, ASTTableJoin::Kind::Right, ASTTableJoin::Kind::Full>;
if constexpr (right_or_full)
if constexpr (right || full)
{
materializeBlockInplace(block);
@ -811,25 +914,22 @@ void Join::joinBlockImpl(
* For ASOF, the last column is used as the ASOF column
*/
ColumnsWithTypeAndName extras;
if constexpr (STRICTNESS == ASTTableJoin::Strictness::Asof)
if constexpr (is_asof_join)
extras.push_back(right_table_keys.getByName(key_names_right.back()));
AddedColumns added(sample_block_with_columns_to_add, block_with_columns_to_add, block, saved_block_sample, extras);
std::unique_ptr<IColumn::Offsets> offsets_to_replicate;
AddedColumns added_columns(sample_block_with_columns_to_add, block_with_columns_to_add, block, saved_block_sample,
extras, *this, key_columns, key_sizes);
bool has_required_right_keys = (required_right_keys.columns() != 0);
added_columns.need_filter = need_filter || has_required_right_keys;
IColumn::Filter row_filter = switchJoinRightColumns<KIND, STRICTNESS>(
type, *this, maps_, block.rows(), key_columns, key_sizes, added, null_map, offsets_to_replicate);
IColumn::Filter row_filter = switchJoinRightColumns<KIND, STRICTNESS>(maps_, added_columns, type, null_map);
for (size_t i = 0; i < added.size(); ++i)
block.insert(added.moveColumn(i));
/// Filter & insert missing rows
constexpr bool is_all_join = STRICTNESS == ASTTableJoin::Strictness::All;
constexpr bool inner_or_right = static_in_v<KIND, ASTTableJoin::Kind::Inner, ASTTableJoin::Kind::Right>;
for (size_t i = 0; i < added_columns.size(); ++i)
block.insert(added_columns.moveColumn(i));
std::vector<size_t> right_keys_to_replicate [[maybe_unused]];
if constexpr (!is_all_join && inner_or_right)
if constexpr (need_filter)
{
/// If ANY INNER | RIGHT JOIN - filter all the columns except the new ones.
for (size_t i = 0; i < existing_columns; ++i)
@ -846,7 +946,7 @@ void Join::joinBlockImpl(
block.insert(correctNullability({col.column, col.type, right_key.name}, is_nullable));
}
}
else
else if (has_required_right_keys)
{
/// Some trash to represent IColumn::Filter as ColumnUInt8 needed for ColumnNullable::applyNullMap()
auto null_map_filter_ptr = ColumnUInt8::create();
@ -866,15 +966,14 @@ void Join::joinBlockImpl(
ColumnPtr thin_column = filterWithBlanks(col.column, filter);
block.insert(correctNullability({thin_column, col.type, right_key.name}, is_nullable, null_map_filter));
if constexpr (is_all_join)
if constexpr (need_replication)
right_keys_to_replicate.push_back(block.getPositionByName(right_key.name));
}
}
if constexpr (is_all_join)
if constexpr (need_replication)
{
if (!offsets_to_replicate)
throw Exception("No data to filter columns", ErrorCodes::LOGICAL_ERROR);
std::unique_ptr<IColumn::Offsets> & offsets_to_replicate = added_columns.offsets_to_replicate;
/// If ALL ... JOIN - we replicate all the columns except the new ones.
for (size_t i = 0; i < existing_columns; ++i)
@ -964,7 +1063,7 @@ DataTypePtr Join::joinGetReturnType(const String & column_name) const
template <typename Maps>
void Join::joinGetImpl(Block & block, const String & column_name, const Maps & maps_) const
{
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::RightAny>(
block, {block.getByPosition(0).name}, {sample_block_with_columns_to_add.getByName(column_name)}, maps_);
}
@ -981,9 +1080,10 @@ void Join::joinGet(Block & block, const String & column_name) const
checkTypeOfKey(block, right_table_keys);
if (kind == ASTTableJoin::Kind::Left && strictness == ASTTableJoin::Strictness::Any)
if ((strictness == ASTTableJoin::Strictness::Any || strictness == ASTTableJoin::Strictness::RightAny) &&
kind == ASTTableJoin::Kind::Left)
{
joinGetImpl(block, column_name, std::get<MapsAny>(maps));
joinGetImpl(block, column_name, std::get<MapsOne>(maps));
}
else
throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::LOGICAL_ERROR);
@ -1017,50 +1117,44 @@ void Join::joinTotals(Block & block) const
}
template <ASTTableJoin::Strictness STRICTNESS, typename Mapped>
struct AdderNonJoined;
template <typename Mapped>
struct AdderNonJoined<ASTTableJoin::Strictness::Any, Mapped>
struct AdderNonJoined
{
static void add(const Mapped & mapped, size_t & rows_added, MutableColumns & columns_right)
{
for (size_t j = 0; j < columns_right.size(); ++j)
constexpr bool mapped_asof = std::is_same_v<Mapped, JoinStuff::MappedAsof>;
constexpr bool mapped_one = std::is_same_v<Mapped, JoinStuff::MappedOne> || std::is_same_v<Mapped, JoinStuff::MappedOneFlagged>;
if constexpr (mapped_asof)
{
const auto & mapped_column = mapped.block->getByPosition(j).column;
columns_right[j]->insertFrom(*mapped_column, mapped.row_num);
/// Do nothing
}
++rows_added;
}
};
template <typename Mapped>
struct AdderNonJoined<ASTTableJoin::Strictness::All, Mapped>
{
static void add(const Mapped & mapped, size_t & rows_added, MutableColumns & columns_right)
{
for (auto it = mapped.begin(); it.ok(); ++it)
else if constexpr (mapped_one)
{
for (size_t j = 0; j < columns_right.size(); ++j)
{
const auto & mapped_column = it->block->getByPosition(j).column;
columns_right[j]->insertFrom(*mapped_column, it->row_num);
const auto & mapped_column = mapped.block->getByPosition(j).column;
columns_right[j]->insertFrom(*mapped_column, mapped.row_num);
}
++rows_added;
}
else
{
for (auto it = mapped.begin(); it.ok(); ++it)
{
for (size_t j = 0; j < columns_right.size(); ++j)
{
const auto & mapped_column = it->block->getByPosition(j).column;
columns_right[j]->insertFrom(*mapped_column, it->row_num);
}
++rows_added;
}
}
}
};
template <typename Mapped>
struct AdderNonJoined<ASTTableJoin::Strictness::Asof, Mapped>
{
static void add(const Mapped & /*mapped*/, size_t & /*rows_added*/, MutableColumns & /*columns_right*/)
{
// If we have a leftover match in the right hand side, not required to join because we are only support asof left/inner
}
};
/// Stream from not joined earlier rows of the right table.
class NonJoinedBlockInputStream : public IBlockInputStream
@ -1269,10 +1363,11 @@ private:
for (; it != end; ++it)
{
const Mapped & mapped = it->getMapped();
if (mapped.getUsed())
continue;
AdderNonJoined<STRICTNESS, Mapped>::add(mapped, rows_added, columns_keys_and_right);
AdderNonJoined<Mapped>::add(mapped, rows_added, columns_keys_and_right);
if (rows_added >= max_block_size)
{
@ -1312,6 +1407,10 @@ private:
BlockInputStreamPtr Join::createStreamWithNonJoinedRows(const Block & result_sample_block, UInt64 max_block_size) const
{
if (table_join->strictness() == ASTTableJoin::Strictness::Asof ||
table_join->strictness() == ASTTableJoin::Strictness::Semi)
return {};
if (isRightOrFull(table_join->kind()))
return std::make_shared<NonJoinedBlockInputStream>(*this, result_sample_block, max_block_size);
return {};

View File

@ -44,6 +44,16 @@ struct WithFlags<T, true> : T
mutable std::atomic<bool> used {};
void setUsed() const { used.store(true, std::memory_order_relaxed); } /// Could be set simultaneously from different threads.
bool getUsed() const { return used; }
bool setUsedOnce() const
{
/// fast check to prevent heavy CAS with seq_cst order
if (used.load(std::memory_order_relaxed))
return false;
bool expected = false;
return used.compare_exchange_strong(expected, true);
}
};
template <typename T>
@ -54,13 +64,14 @@ struct WithFlags<T, false> : T
void setUsed() const {}
bool getUsed() const { return true; }
bool setUsedOnce() const { return true; }
};
using MappedAny = WithFlags<RowRef, false>;
using MappedAll = WithFlags<RowRefList, false>;
using MappedAnyFull = WithFlags<RowRef, true>;
using MappedAllFull = WithFlags<RowRefList, true>;
using MappedAsof = WithFlags<AsofRowRefs, false>;
using MappedOne = WithFlags<RowRef, false>;
using MappedAll = WithFlags<RowRefList, false>;
using MappedOneFlagged = WithFlags<RowRef, true>;
using MappedAllFlagged = WithFlags<RowRefList, true>;
using MappedAsof = WithFlags<AsofRowRefs, false>;
}
@ -68,11 +79,23 @@ using MappedAsof = WithFlags<AsofRowRefs, false>;
* It is just a hash table: keys -> rows of joined ("right") table.
* Additionally, CROSS JOIN is supported: instead of hash table, it use just set of blocks without keys.
*
* JOIN-s could be of nine types: ANY/ALL × LEFT/INNER/RIGHT/FULL, and also CROSS.
* JOIN-s could be of these types:
* - ALL × LEFT/INNER/RIGHT/FULL
* - ANY × LEFT/INNER/RIGHT
* - SEMI/ANTI x LEFT/RIGHT
* - ASOF x LEFT/INNER
* - CROSS
*
* If ANY is specified - then select only one row from the "right" table, (first encountered row), even if there was more matching rows.
* If ALL is specified - usual JOIN, when rows are multiplied by number of matching rows from the "right" table.
* ANY is more efficient.
* ALL means usual JOIN, when rows are multiplied by number of matching rows from the "right" table.
* ANY uses one line per unique key from right talbe. For LEFT JOIN it would be any row (with needed joined key) from the right table,
* for RIGHT JOIN it would be any row from the left table and for INNER one it would be any row from right and any row from left.
* SEMI JOIN filter left table by keys that are present in right table for LEFT JOIN, and filter right table by keys from left table
* for RIGHT JOIN. In other words SEMI JOIN returns only rows which joining keys present in another table.
* ANTI JOIN is the same as SEMI JOIN but returns rows with joining keys that are NOT present in another table.
* SEMI/ANTI JOINs allow to get values from both tables. For filter table it gets any row with joining same key. For ANTI JOIN it returns
* defaults other table columns.
* ASOF JOIN is not-equi join. For one key column it finds nearest value to join according to join inequality.
* It's expected that ANY|SEMI LEFT JOIN is more efficient that ALL one.
*
* If INNER is specified - leave only rows that have matching rows from "right" table.
* If LEFT is specified - in case when there is no matching row in "right" table, fill it with default values instead.
@ -264,13 +287,13 @@ public:
}
};
using MapsAny = MapsTemplate<JoinStuff::MappedAny>;
using MapsOne = MapsTemplate<JoinStuff::MappedOne>;
using MapsAll = MapsTemplate<JoinStuff::MappedAll>;
using MapsAnyFull = MapsTemplate<JoinStuff::MappedAnyFull>;
using MapsAllFull = MapsTemplate<JoinStuff::MappedAllFull>;
using MapsOneFlagged = MapsTemplate<JoinStuff::MappedOneFlagged>;
using MapsAllFlagged = MapsTemplate<JoinStuff::MappedAllFlagged>;
using MapsAsof = MapsTemplate<JoinStuff::MappedAsof>;
using MapsVariant = std::variant<MapsAny, MapsAll, MapsAnyFull, MapsAllFull, MapsAsof>;
using MapsVariant = std::variant<MapsOne, MapsAll, MapsOneFlagged, MapsAllFlagged, MapsAsof>;
private:
friend class NonJoinedBlockInputStream;
@ -341,8 +364,8 @@ private:
*/
void setSampleBlock(const Block & block);
/// Modify (structure) and save right block, @returns pointer to saved block
Block * storeRightBlock(const Block & stored_block);
/// Modify (structure) right block to save it in block list
Block structureRightBlock(const Block & stored_block) const;
void initRightBlockStructure();
void initRequiredRightKeys();

View File

@ -1,345 +0,0 @@
#include <iomanip>
#include <common/logger_useful.h>
#include <Common/SipHash.h>
#include <Common/StringUtils/StringUtils.h>
#include <IO/ReadHelpers.h>
#include <Interpreters/Quota.h>
#include <set>
#include <random>
namespace DB
{
namespace ErrorCodes
{
extern const int QUOTA_EXPIRED;
extern const int QUOTA_DOESNT_ALLOW_KEYS;
extern const int UNKNOWN_QUOTA;
}
template <typename Counter>
void QuotaValues<Counter>::initFromConfig(const String & config_elem, const Poco::Util::AbstractConfiguration & config)
{
queries = config.getUInt64(config_elem + ".queries", 0);
errors = config.getUInt64(config_elem + ".errors", 0);
result_rows = config.getUInt64(config_elem + ".result_rows", 0);
result_bytes = config.getUInt64(config_elem + ".result_bytes", 0);
read_rows = config.getUInt64(config_elem + ".read_rows", 0);
read_bytes = config.getUInt64(config_elem + ".read_bytes", 0);
execution_time_usec = config.getUInt64(config_elem + ".execution_time", 0) * 1000000ULL;
}
template void QuotaValues<size_t>::initFromConfig(const String & config_elem, const Poco::Util::AbstractConfiguration & config);
template void QuotaValues<std::atomic<size_t>>::initFromConfig(const String & config_elem, const Poco::Util::AbstractConfiguration & config);
void QuotaForInterval::initFromConfig(
const String & config_elem, time_t duration_, bool randomize_, time_t offset_, const Poco::Util::AbstractConfiguration & config)
{
rounded_time.store(0, std::memory_order_relaxed);
duration = duration_;
randomize = randomize_;
offset = offset_;
max.initFromConfig(config_elem, config);
}
void QuotaForInterval::checkExceeded(time_t current_time, const String & quota_name, const String & user_name)
{
updateTime(current_time);
check(max.queries, used.queries, quota_name, user_name, "Queries");
check(max.errors, used.errors, quota_name, user_name, "Errors");
check(max.result_rows, used.result_rows, quota_name, user_name, "Total result rows");
check(max.result_bytes, used.result_bytes, quota_name, user_name, "Total result bytes");
check(max.read_rows, used.read_rows, quota_name, user_name, "Total rows read");
check(max.read_bytes, used.read_bytes, quota_name, user_name, "Total bytes read");
check(max.execution_time_usec / 1000000, used.execution_time_usec / 1000000, quota_name, user_name, "Total execution time");
}
String QuotaForInterval::toString() const
{
std::stringstream res;
auto loaded_rounded_time = rounded_time.load(std::memory_order_relaxed);
res << std::fixed << std::setprecision(3)
<< "Interval: " << LocalDateTime(loaded_rounded_time) << " - " << LocalDateTime(loaded_rounded_time + duration) << ".\n"
<< "Queries: " << used.queries << ".\n"
<< "Errors: " << used.errors << ".\n"
<< "Result rows: " << used.result_rows << ".\n"
<< "Result bytes: " << used.result_bytes << ".\n"
<< "Read rows: " << used.read_rows << ".\n"
<< "Read bytes: " << used.read_bytes << ".\n"
<< "Execution time: " << used.execution_time_usec / 1000000.0 << " sec.\n";
return res.str();
}
void QuotaForInterval::addQuery() noexcept
{
++used.queries;
}
void QuotaForInterval::addError() noexcept
{
++used.errors;
}
void QuotaForInterval::checkAndAddResultRowsBytes(time_t current_time, const String & quota_name, const String & user_name, size_t rows, size_t bytes)
{
used.result_rows += rows;
used.result_bytes += bytes;
checkExceeded(current_time, quota_name, user_name);
}
void QuotaForInterval::checkAndAddReadRowsBytes(time_t current_time, const String & quota_name, const String & user_name, size_t rows, size_t bytes)
{
used.read_rows += rows;
used.read_bytes += bytes;
checkExceeded(current_time, quota_name, user_name);
}
void QuotaForInterval::checkAndAddExecutionTime(time_t current_time, const String & quota_name, const String & user_name, Poco::Timespan amount)
{
/// Information about internals of Poco::Timespan used.
used.execution_time_usec += amount.totalMicroseconds();
checkExceeded(current_time, quota_name, user_name);
}
void QuotaForInterval::updateTime(time_t current_time)
{
/** If current time is greater than end of interval,
* then clear accumulated quota values and switch to next interval [rounded_time, rounded_time + duration).
*/
auto loaded_rounded_time = rounded_time.load(std::memory_order_acquire);
while (true)
{
if (current_time < loaded_rounded_time + static_cast<time_t>(duration))
break;
time_t new_rounded_time = (current_time - offset) / duration * duration + offset;
if (rounded_time.compare_exchange_strong(loaded_rounded_time, new_rounded_time))
{
used.clear();
break;
}
}
}
void QuotaForInterval::check(
size_t max_amount, size_t used_amount,
const String & quota_name, const String & user_name, const char * resource_name)
{
if (max_amount && used_amount > max_amount)
{
std::stringstream message;
message << "Quota for user '" << user_name << "' for ";
if (duration == 3600)
message << "1 hour";
else if (duration == 60)
message << "1 minute";
else if (duration % 3600 == 0)
message << (duration / 3600) << " hours";
else if (duration % 60 == 0)
message << (duration / 60) << " minutes";
else
message << duration << " seconds";
message << " has been exceeded. "
<< resource_name << ": " << used_amount << ", max: " << max_amount << ". "
<< "Interval will end at " << LocalDateTime(rounded_time.load(std::memory_order_relaxed) + duration) << ". "
<< "Name of quota template: '" << quota_name << "'.";
throw Exception(message.str(), ErrorCodes::QUOTA_EXPIRED);
}
}
void QuotaForIntervals::initFromConfig(const String & config_elem, const Poco::Util::AbstractConfiguration & config, pcg64 & rng)
{
Poco::Util::AbstractConfiguration::Keys config_keys;
config.keys(config_elem, config_keys);
for (Poco::Util::AbstractConfiguration::Keys::const_iterator it = config_keys.begin(); it != config_keys.end(); ++it)
{
if (!startsWith(*it, "interval"))
continue;
String interval_config_elem = config_elem + "." + *it;
time_t duration = config.getInt(interval_config_elem + ".duration", 0);
time_t offset = 0;
if (!duration) /// Skip quotas with zero duration
continue;
bool randomize = config.getBool(interval_config_elem + ".randomize", false);
if (randomize)
offset = std::uniform_int_distribution<decltype(duration)>(0, duration - 1)(rng);
cont[duration].initFromConfig(interval_config_elem, duration, randomize, offset, config);
}
}
void QuotaForIntervals::setMax(const QuotaForIntervals & quota)
{
for (Container::iterator it = cont.begin(); it != cont.end();)
{
if (quota.cont.count(it->first))
++it;
else
cont.erase(it++);
}
for (auto & x : quota.cont)
{
if (!cont.count(x.first))
cont.emplace(x.first, x.second);
else
cont[x.first].max = x.second.max;
}
}
void QuotaForIntervals::checkExceeded(time_t current_time)
{
for (Container::reverse_iterator it = cont.rbegin(); it != cont.rend(); ++it)
it->second.checkExceeded(current_time, quota_name, user_name);
}
void QuotaForIntervals::addQuery() noexcept
{
for (Container::reverse_iterator it = cont.rbegin(); it != cont.rend(); ++it)
it->second.addQuery();
}
void QuotaForIntervals::addError() noexcept
{
for (Container::reverse_iterator it = cont.rbegin(); it != cont.rend(); ++it)
it->second.addError();
}
void QuotaForIntervals::checkAndAddResultRowsBytes(time_t current_time, size_t rows, size_t bytes)
{
for (Container::reverse_iterator it = cont.rbegin(); it != cont.rend(); ++it)
it->second.checkAndAddResultRowsBytes(current_time, quota_name, user_name, rows, bytes);
}
void QuotaForIntervals::checkAndAddReadRowsBytes(time_t current_time, size_t rows, size_t bytes)
{
for (Container::reverse_iterator it = cont.rbegin(); it != cont.rend(); ++it)
it->second.checkAndAddReadRowsBytes(current_time, quota_name, user_name, rows, bytes);
}
void QuotaForIntervals::checkAndAddExecutionTime(time_t current_time, Poco::Timespan amount)
{
for (Container::reverse_iterator it = cont.rbegin(); it != cont.rend(); ++it)
it->second.checkAndAddExecutionTime(current_time, quota_name, user_name, amount);
}
String QuotaForIntervals::toString() const
{
std::stringstream res;
for (Container::const_reverse_iterator it = cont.rbegin(); it != cont.rend(); ++it)
res << std::endl << it->second.toString();
return res.str();
}
void Quota::loadFromConfig(const String & config_elem, const String & name_, const Poco::Util::AbstractConfiguration & config, pcg64 & rng)
{
name = name_;
bool new_keyed_by_ip = config.has(config_elem + ".keyed_by_ip");
bool new_is_keyed = new_keyed_by_ip || config.has(config_elem + ".keyed");
if (new_is_keyed != is_keyed || new_keyed_by_ip != keyed_by_ip)
{
keyed_by_ip = new_keyed_by_ip;
is_keyed = new_is_keyed;
/// Meaning of keys has been changed. Throw away accumulated values.
quota_for_keys.clear();
}
ignore_key_if_not_keyed = config.has(config_elem + ".ignore_key_if_not_keyed");
QuotaForIntervals new_max(name, {});
new_max.initFromConfig(config_elem, config, rng);
if (!new_max.hasEqualConfiguration(max))
{
max = new_max;
for (auto & quota : quota_for_keys)
quota.second->setMax(max);
}
}
QuotaForIntervalsPtr Quota::get(const String & quota_key, const String & user_name, const Poco::Net::IPAddress & ip)
{
if (!quota_key.empty() && !ignore_key_if_not_keyed && (!is_keyed || keyed_by_ip))
throw Exception("Quota " + name + " (for user " + user_name + ") doesn't allow client supplied keys.",
ErrorCodes::QUOTA_DOESNT_ALLOW_KEYS);
/** Quota is calculated separately:
* - for each IP-address, if 'keyed_by_ip';
* - otherwise for each 'quota_key', if present;
* - otherwise for each 'user_name'.
*/
UInt64 quota_key_hashed = sipHash64(
keyed_by_ip
? ip.toString()
: (!quota_key.empty()
? quota_key
: user_name));
std::lock_guard lock(mutex);
Container::iterator it = quota_for_keys.find(quota_key_hashed);
if (quota_for_keys.end() == it)
it = quota_for_keys.emplace(quota_key_hashed, std::make_shared<QuotaForIntervals>(max, user_name)).first;
return it->second;
}
void Quotas::loadFromConfig(const Poco::Util::AbstractConfiguration & config)
{
pcg64 rng;
Poco::Util::AbstractConfiguration::Keys config_keys;
config.keys("quotas", config_keys);
/// Remove keys, that now absent in config.
std::set<std::string> keys_set(config_keys.begin(), config_keys.end());
for (Container::iterator it = cont.begin(); it != cont.end();)
{
if (keys_set.count(it->first))
++it;
else
cont.erase(it++);
}
for (Poco::Util::AbstractConfiguration::Keys::const_iterator it = config_keys.begin(); it != config_keys.end(); ++it)
{
if (!cont.count(*it))
cont.try_emplace(*it);
cont[*it].loadFromConfig("quotas." + *it, *it, config, rng);
}
}
QuotaForIntervalsPtr Quotas::get(const String & name, const String & quota_key, const String & user_name, const Poco::Net::IPAddress & ip)
{
Container::iterator it = cont.find(name);
if (cont.end() == it)
throw Exception("Unknown quota " + name, ErrorCodes::UNKNOWN_QUOTA);
return it->second.get(quota_key, user_name, ip);
}
}

View File

@ -1,263 +0,0 @@
#pragma once
#include <cstring>
#include <unordered_map>
#include <memory>
#include <pcg_random.hpp>
#include <Poco/Timespan.h>
#include <Poco/Util/Application.h>
#include <Poco/Util/AbstractConfiguration.h>
#include <Poco/Net/IPAddress.h>
#include <Core/Types.h>
#include <Common/Exception.h>
#include <IO/WriteHelpers.h>
namespace DB
{
/** Quota for resources consumption for specific interval.
* Used to limit resource usage by user.
* Quota is applied "softly" - could be slightly exceed, because it is checked usually only on each block of processed data.
* Accumulated values are not persisted and are lost on server restart.
* Quota is local to server,
* but for distributed queries, accumulated values for read rows and bytes
* are collected from all participating servers and accumulated locally.
*/
/// Used both for maximum allowed values and for counters of current accumulated values.
template <typename Counter> /// either size_t or std::atomic<size_t>
struct QuotaValues
{
/// Zero values (for maximums) means no limit.
Counter queries; /// Number of queries.
Counter errors; /// Number of queries with exceptions.
Counter result_rows; /// Number of rows returned as result.
Counter result_bytes; /// Number of bytes returned as result.
Counter read_rows; /// Number of rows read from tables.
Counter read_bytes; /// Number of bytes read from tables.
Counter execution_time_usec; /// Total amount of query execution time in microseconds.
QuotaValues()
{
clear();
}
QuotaValues(const QuotaValues & rhs)
{
tuple() = rhs.tuple();
}
QuotaValues & operator=(const QuotaValues & rhs)
{
tuple() = rhs.tuple();
return *this;
}
void clear()
{
tuple() = std::make_tuple(0, 0, 0, 0, 0, 0, 0);
}
void initFromConfig(const String & config_elem, const Poco::Util::AbstractConfiguration & config);
bool operator== (const QuotaValues & rhs) const
{
return tuple() == rhs.tuple();
}
private:
auto tuple()
{
return std::forward_as_tuple(queries, errors, result_rows, result_bytes, read_rows, read_bytes, execution_time_usec);
}
auto tuple() const
{
return std::make_tuple(queries, errors, result_rows, result_bytes, read_rows, read_bytes, execution_time_usec);
}
};
template <>
inline auto QuotaValues<std::atomic<size_t>>::tuple() const
{
return std::make_tuple(
queries.load(std::memory_order_relaxed),
errors.load(std::memory_order_relaxed),
result_rows.load(std::memory_order_relaxed),
result_bytes.load(std::memory_order_relaxed),
read_rows.load(std::memory_order_relaxed),
read_bytes.load(std::memory_order_relaxed),
execution_time_usec.load(std::memory_order_relaxed));
}
/// Time, rounded down to start of interval; limits for that interval and accumulated values.
struct QuotaForInterval
{
std::atomic<time_t> rounded_time {0};
size_t duration = 0;
bool randomize = false;
time_t offset = 0; /// Offset of interval for randomization (to avoid DoS if intervals for many users end at one time).
QuotaValues<size_t> max;
QuotaValues<std::atomic<size_t>> used;
QuotaForInterval() = default;
QuotaForInterval(time_t duration_) : duration(duration_) {}
void initFromConfig(const String & config_elem, time_t duration_, bool randomize_, time_t offset_, const Poco::Util::AbstractConfiguration & config);
/// Increase current value.
void addQuery() noexcept;
void addError() noexcept;
/// Check if quota is already exceeded. If that, throw an exception.
void checkExceeded(time_t current_time, const String & quota_name, const String & user_name);
/// Check corresponding value. If exceeded, throw an exception. Otherwise, increase that value.
void checkAndAddResultRowsBytes(time_t current_time, const String & quota_name, const String & user_name, size_t rows, size_t bytes);
void checkAndAddReadRowsBytes(time_t current_time, const String & quota_name, const String & user_name, size_t rows, size_t bytes);
void checkAndAddExecutionTime(time_t current_time, const String & quota_name, const String & user_name, Poco::Timespan amount);
/// Get a text, describing what quota is exceeded.
String toString() const;
/// Only compare configuration, not accumulated (used) values or random offsets.
bool operator== (const QuotaForInterval & rhs) const
{
return randomize == rhs.randomize
&& duration == rhs.duration
&& max == rhs.max;
}
QuotaForInterval & operator= (const QuotaForInterval & rhs)
{
rounded_time.store(rhs.rounded_time.load(std::memory_order_relaxed));
duration = rhs.duration;
randomize = rhs.randomize;
offset = rhs.offset;
max = rhs.max;
used = rhs.used;
return *this;
}
QuotaForInterval(const QuotaForInterval & rhs)
{
*this = rhs;
}
private:
/// Reset counters of used resources, if interval for quota is expired.
void updateTime(time_t current_time);
void check(size_t max_amount, size_t used_amount,
const String & quota_name, const String & user_name, const char * resource_name);
};
struct Quota;
/// Length of interval -> quota: maximum allowed and currently accumulated values for that interval (example: 3600 -> values for current hour).
class QuotaForIntervals
{
private:
/// While checking, will walk through intervals in order of decreasing size - from largest to smallest.
/// To report first about largest interval on what quota was exceeded.
using Container = std::map<size_t, QuotaForInterval>;
Container cont;
std::string quota_name;
std::string user_name; /// user name is set only for current counters for user, not for object that contain maximum values (limits).
public:
QuotaForIntervals(const std::string & quota_name_, const std::string & user_name_)
: quota_name(quota_name_), user_name(user_name_) {}
QuotaForIntervals(const QuotaForIntervals & other, const std::string & user_name_)
: QuotaForIntervals(other)
{
user_name = user_name_;
}
QuotaForIntervals() = default;
QuotaForIntervals(const QuotaForIntervals &) = default;
QuotaForIntervals & operator=(const QuotaForIntervals &) = default;
/// Is there at least one interval for counting quota?
bool empty() const
{
return cont.empty();
}
void initFromConfig(const String & config_elem, const Poco::Util::AbstractConfiguration & config, pcg64 & rng);
/// Set maximum values (limits) from passed argument.
/// Remove intervals that does not exist in argument. Add intervals from argument, that we don't have.
void setMax(const QuotaForIntervals & quota);
void addQuery() noexcept;
void addError() noexcept;
void checkExceeded(time_t current_time);
void checkAndAddResultRowsBytes(time_t current_time, size_t rows, size_t bytes);
void checkAndAddReadRowsBytes(time_t current_time, size_t rows, size_t bytes);
void checkAndAddExecutionTime(time_t current_time, Poco::Timespan amount);
/// Get text, describing what part of quota has been exceeded.
String toString() const;
bool hasEqualConfiguration(const QuotaForIntervals & rhs) const
{
return cont == rhs.cont && quota_name == rhs.quota_name;
}
};
using QuotaForIntervalsPtr = std::shared_ptr<QuotaForIntervals>;
/// Quota key -> quotas (max and current values) for intervals. If quota doesn't have keys, then values stored at key 0.
struct Quota
{
using Container = std::unordered_map<UInt64, QuotaForIntervalsPtr>;
String name;
/// Maximum values from config.
QuotaForIntervals max;
/// Maximum and accumulated values for different keys.
/// For all keys, maximum values are the same and taken from 'max'.
Container quota_for_keys;
std::mutex mutex;
bool is_keyed = false;
/// If the quota is not keyed, but the user passed some key, ignore it instead of throwing exception.
/// For transitional periods, when you want to enable quota keys
/// - first, enable passing keys from your application, then make quota keyed in ClickHouse users config.
bool ignore_key_if_not_keyed = false;
bool keyed_by_ip = false;
void loadFromConfig(const String & config_elem, const String & name_, const Poco::Util::AbstractConfiguration & config, pcg64 & rng);
QuotaForIntervalsPtr get(const String & quota_key, const String & user_name, const Poco::Net::IPAddress & ip);
};
class Quotas
{
private:
/// Name of quota -> quota.
using Container = std::unordered_map<String, Quota>;
Container cont;
public:
void loadFromConfig(const Poco::Util::AbstractConfiguration & config);
QuotaForIntervalsPtr get(const String & name, const String & quota_key,
const String & user_name, const Poco::Net::IPAddress & ip);
};
}

View File

@ -24,19 +24,16 @@ struct SelectQueryOptions
{
QueryProcessingStage::Enum to_stage;
size_t subquery_depth;
bool only_analyze;
bool modify_inplace;
bool remove_duplicates;
bool ignore_limits;
bool only_analyze = false;
bool modify_inplace = false;
bool remove_duplicates = false;
bool ignore_quota = false;
bool ignore_limits = false;
SelectQueryOptions(QueryProcessingStage::Enum stage = QueryProcessingStage::Complete, size_t depth = 0)
: to_stage(stage)
, subquery_depth(depth)
, only_analyze(false)
, modify_inplace(false)
, remove_duplicates(false)
, ignore_limits(false)
{}
: to_stage(stage), subquery_depth(depth)
{
}
SelectQueryOptions copy() const { return *this; }

View File

@ -540,7 +540,7 @@ void getArrayJoinedColumns(ASTPtr & query, SyntaxAnalyzerResult & result, const
}
}
void setJoinStrictness(ASTSelectQuery & select_query, JoinStrictness join_default_strictness, ASTTableJoin & out_table_join)
void setJoinStrictness(ASTSelectQuery & select_query, JoinStrictness join_default_strictness, bool old_any, ASTTableJoin & out_table_join)
{
const ASTTablesInSelectQueryElement * node = select_query.join();
if (!node)
@ -560,6 +560,9 @@ void setJoinStrictness(ASTSelectQuery & select_query, JoinStrictness join_defaul
DB::ErrorCodes::EXPECTED_ALL_OR_ANY);
}
if (old_any && table_join.strictness == ASTTableJoin::Strictness::Any)
table_join.strictness = ASTTableJoin::Strictness::RightAny;
out_table_join = table_join;
}
@ -628,13 +631,8 @@ void checkJoin(const ASTTablesInSelectQueryElement * join)
const auto & table_join = join->table_join->as<ASTTableJoin &>();
if (table_join.strictness == ASTTableJoin::Strictness::Any)
if (table_join.kind != ASTTableJoin::Kind::Left)
throw Exception("Old ANY INNER|RIGHT|FULL JOINs are disabled by default. Their logic would be changed. "
"Old logic is many-to-one for all kinds of ANY JOINs. It's equil to apply distinct for right table keys. "
"Default bahaviour is reserved for many-to-one LEFT JOIN, one-to-many RIGHT JOIN and one-to-one INNER JOIN. "
"It would be equal to apply distinct for keys to right, left and both tables respectively. "
"Set any_join_distinct_right_table_keys=1 to enable old bahaviour.",
ErrorCodes::NOT_IMPLEMENTED);
if (table_join.kind == ASTTableJoin::Kind::Full)
throw Exception("ANY FULL JOINs are not implemented.", ErrorCodes::NOT_IMPLEMENTED);
}
std::vector<const ASTFunction *> getAggregates(const ASTPtr & query)
@ -958,7 +956,8 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyze(
/// Push the predicate expression down to the subqueries.
result.rewrite_subqueries = PredicateExpressionsOptimizer(select_query, settings, context).optimize();
setJoinStrictness(*select_query, settings.join_default_strictness, result.analyzed_join->table_join);
setJoinStrictness(*select_query, settings.join_default_strictness, settings.any_join_distinct_right_table_keys,
result.analyzed_join->table_join);
collectJoinedColumns(*result.analyzed_join, *select_query, source_columns_set, result.aliases);
}

View File

@ -49,7 +49,6 @@ User::User(const String & name_, const String & config_elem, const Poco::Util::A
}
profile = config.getString(config_elem + ".profile");
quota = config.getString(config_elem + ".quota");
/// Fill list of allowed hosts.
const auto config_networks = config_elem + ".networks";
@ -130,7 +129,9 @@ User::User(const String & name_, const String & config_elem, const Poco::Util::A
}
}
}
if (config.has(config_elem + ".allow_quota_management"))
is_quota_management_allowed = config.getBool(config_elem + ".allow_quota_management");
}
}

View File

@ -30,7 +30,6 @@ struct User
Authentication authentication;
String profile;
String quota;
AllowedClientHosts allowed_client_hosts;
@ -48,6 +47,8 @@ struct User
using DatabaseMap = std::unordered_map<std::string /* database */, TableMap /* tables */>;
DatabaseMap table_props;
bool is_quota_management_allowed = false;
User(const String & name_, const String & config_elem, const Poco::Util::AbstractConfiguration & config);
};

View File

@ -24,7 +24,7 @@
#include <Storages/StorageInput.h>
#include <Interpreters/Quota.h>
#include <Access/QuotaContext.h>
#include <Interpreters/InterpreterFactory.h>
#include <Interpreters/ProcessList.h>
#include <Interpreters/QueryLog.h>
@ -150,7 +150,7 @@ static void logException(Context & context, QueryLogElement & elem)
static void onExceptionBeforeStart(const String & query_for_logging, Context & context, time_t current_time)
{
/// Exception before the query execution.
context.getQuota().addError();
context.getQuota()->used(Quota::ERRORS, 1, /* check_exceeded = */ false);
const Settings & settings = context.getSettingsRef();
@ -271,11 +271,6 @@ static std::tuple<ASTPtr, BlockIO> executeQueryImpl(
/// Check the limits.
checkASTSizeLimits(*ast, settings);
QuotaForIntervals & quota = context.getQuota();
quota.addQuery(); /// NOTE Seems that when new time interval has come, first query is not accounted in number of queries.
quota.checkExceeded(current_time);
/// Put query to process list. But don't put SHOW PROCESSLIST query itself.
ProcessList::EntryPtr process_list_entry;
if (!internal && !ast->as<ASTShowProcesslistQuery>())
@ -313,6 +308,21 @@ static std::tuple<ASTPtr, BlockIO> executeQueryImpl(
auto interpreter = InterpreterFactory::get(ast, context, stage);
bool use_processors = settings.experimental_use_processors && allow_processors && interpreter->canExecuteWithProcessors();
QuotaContextPtr quota;
if (!interpreter->ignoreQuota())
{
quota = context.getQuota();
quota->used(Quota::QUERIES, 1);
quota->checkExceeded(Quota::ERRORS);
}
IBlockInputStream::LocalLimits limits;
if (!interpreter->ignoreLimits())
{
limits.mode = IBlockInputStream::LIMITS_CURRENT;
limits.size_limits = SizeLimits(settings.max_result_rows, settings.max_result_bytes, settings.result_overflow_mode);
}
if (use_processors)
pipeline = interpreter->executeWithProcessors();
else
@ -339,17 +349,12 @@ static std::tuple<ASTPtr, BlockIO> executeQueryImpl(
/// Hold element of process list till end of query execution.
res.process_list_entry = process_list_entry;
IBlockInputStream::LocalLimits limits;
limits.mode = IBlockInputStream::LIMITS_CURRENT;
limits.size_limits = SizeLimits(settings.max_result_rows, settings.max_result_bytes, settings.result_overflow_mode);
if (use_processors)
{
pipeline.setProgressCallback(context.getProgressCallback());
pipeline.setProcessListElement(context.getProcessListElement());
/// Limits on the result, the quota on the result, and also callback for progress.
/// Limits apply only to the final result.
pipeline.setProgressCallback(context.getProgressCallback());
pipeline.setProcessListElement(context.getProcessListElement());
if (stage == QueryProcessingStage::Complete)
{
pipeline.resize(1);
@ -363,17 +368,18 @@ static std::tuple<ASTPtr, BlockIO> executeQueryImpl(
}
else
{
/// Limits on the result, the quota on the result, and also callback for progress.
/// Limits apply only to the final result.
if (res.in)
{
res.in->setProgressCallback(context.getProgressCallback());
res.in->setProcessListElement(context.getProcessListElement());
/// Limits on the result, the quota on the result, and also callback for progress.
/// Limits apply only to the final result.
if (stage == QueryProcessingStage::Complete)
{
res.in->setLimits(limits);
res.in->setQuota(quota);
if (!interpreter->ignoreQuota())
res.in->setQuota(quota);
if (!interpreter->ignoreLimits())
res.in->setLimits(limits);
}
}
@ -484,7 +490,7 @@ static std::tuple<ASTPtr, BlockIO> executeQueryImpl(
auto exception_callback = [elem, &context, log_queries] () mutable
{
context.getQuota().addError();
context.getQuota()->used(Quota::ERRORS, 1, /* check_exceeded = */ false);
elem.type = QueryLogElement::EXCEPTION_WHILE_PROCESSING;

View File

@ -12,57 +12,50 @@
namespace DB
{
template <bool fill_right, typename ASTTableJoin::Strictness>
struct MapGetterImpl;
template <ASTTableJoin::Kind kind, typename ASTTableJoin::Strictness>
struct MapGetter;
template <>
struct MapGetterImpl<false, ASTTableJoin::Strictness::Any>
{
using Map = Join::MapsAny;
};
template <> struct MapGetter<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::RightAny> { using Map = Join::MapsOne; };
template <> struct MapGetter<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::RightAny> { using Map = Join::MapsOne; };
template <> struct MapGetter<ASTTableJoin::Kind::Right, ASTTableJoin::Strictness::RightAny> { using Map = Join::MapsOneFlagged; };
template <> struct MapGetter<ASTTableJoin::Kind::Full, ASTTableJoin::Strictness::RightAny> { using Map = Join::MapsOneFlagged; };
template <>
struct MapGetterImpl<true, ASTTableJoin::Strictness::Any>
{
using Map = Join::MapsAnyFull;
};
template <> struct MapGetter<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any> { using Map = Join::MapsOne; };
template <> struct MapGetter<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::Any> { using Map = Join::MapsOneFlagged; };
template <> struct MapGetter<ASTTableJoin::Kind::Right, ASTTableJoin::Strictness::Any> { using Map = Join::MapsAllFlagged; };
template <> struct MapGetter<ASTTableJoin::Kind::Full, ASTTableJoin::Strictness::Any> { using Map = Join::MapsAllFlagged; };
template <>
struct MapGetterImpl<false, ASTTableJoin::Strictness::All>
{
using Map = Join::MapsAll;
};
template <> struct MapGetter<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::All> { using Map = Join::MapsAll; };
template <> struct MapGetter<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::All> { using Map = Join::MapsAll; };
template <> struct MapGetter<ASTTableJoin::Kind::Right, ASTTableJoin::Strictness::All> { using Map = Join::MapsAllFlagged; };
template <> struct MapGetter<ASTTableJoin::Kind::Full, ASTTableJoin::Strictness::All> { using Map = Join::MapsAllFlagged; };
template <>
struct MapGetterImpl<true, ASTTableJoin::Strictness::All>
{
using Map = Join::MapsAllFull;
};
/// Only SEMI LEFT and SEMI RIGHT are valid. INNER and FULL are here for templates instantiation.
template <> struct MapGetter<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Semi> { using Map = Join::MapsOne; };
template <> struct MapGetter<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::Semi> { using Map = Join::MapsOne; };
template <> struct MapGetter<ASTTableJoin::Kind::Right, ASTTableJoin::Strictness::Semi> { using Map = Join::MapsAllFlagged; };
template <> struct MapGetter<ASTTableJoin::Kind::Full, ASTTableJoin::Strictness::Semi> { using Map = Join::MapsOne; };
template <bool fill_right>
struct MapGetterImpl<fill_right, ASTTableJoin::Strictness::Asof>
/// Only SEMI LEFT and SEMI RIGHT are valid. INNER and FULL are here for templates instantiation.
template <> struct MapGetter<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Anti> { using Map = Join::MapsOne; };
template <> struct MapGetter<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::Anti> { using Map = Join::MapsOne; };
template <> struct MapGetter<ASTTableJoin::Kind::Right, ASTTableJoin::Strictness::Anti> { using Map = Join::MapsAllFlagged; };
template <> struct MapGetter<ASTTableJoin::Kind::Full, ASTTableJoin::Strictness::Anti> { using Map = Join::MapsOne; };
template <ASTTableJoin::Kind kind>
struct MapGetter<kind, ASTTableJoin::Strictness::Asof>
{
using Map = Join::MapsAsof;
};
template <ASTTableJoin::Kind KIND>
struct KindTrait
{
// Affects the Adder trait so that when the right part is empty, adding a default value on the left
static constexpr bool fill_left = static_in_v<KIND, ASTTableJoin::Kind::Left, ASTTableJoin::Kind::Full>;
// Affects the Map trait so that a `used` flag is attached to map slots in order to
// generate default values on the right when the left part is empty
static constexpr bool fill_right = static_in_v<KIND, ASTTableJoin::Kind::Right, ASTTableJoin::Kind::Full>;
};
template <ASTTableJoin::Kind kind, ASTTableJoin::Strictness strictness>
using Map = typename MapGetterImpl<KindTrait<kind>::fill_right, strictness>::Map;
static constexpr std::array<ASTTableJoin::Strictness, 3> STRICTNESSES = {
static constexpr std::array<ASTTableJoin::Strictness, 6> STRICTNESSES = {
ASTTableJoin::Strictness::RightAny,
ASTTableJoin::Strictness::Any,
ASTTableJoin::Strictness::All,
ASTTableJoin::Strictness::Asof
ASTTableJoin::Strictness::Asof,
ASTTableJoin::Strictness::Semi,
ASTTableJoin::Strictness::Anti,
};
static constexpr std::array<ASTTableJoin::Kind, 4> KINDS = {
@ -81,7 +74,7 @@ inline bool joinDispatchInit(ASTTableJoin::Kind kind, ASTTableJoin::Strictness s
constexpr auto j = ij % STRICTNESSES.size();
if (kind == KINDS[i] && strictness == STRICTNESSES[j])
{
maps = Map<KINDS[i], STRICTNESSES[j]>();
maps = typename MapGetter<KINDS[i], STRICTNESSES[j]>::Map();
return true;
}
return false;
@ -103,7 +96,7 @@ inline bool joinDispatch(ASTTableJoin::Kind kind, ASTTableJoin::Strictness stric
func(
std::integral_constant<ASTTableJoin::Kind, KINDS[i]>(),
std::integral_constant<ASTTableJoin::Strictness, STRICTNESSES[j]>(),
std::get<Map<KINDS[i], STRICTNESSES[j]>>(maps));
std::get<typename MapGetter<KINDS[i], STRICTNESSES[j]>::Map>(maps));
return true;
}
return false;

View File

@ -0,0 +1,142 @@
#include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTRoleList.h>
#include <Common/quoteString.h>
#include <Common/IntervalKind.h>
#include <ext/range.h>
namespace DB
{
namespace
{
using KeyType = Quota::KeyType;
using ResourceType = Quota::ResourceType;
using ResourceAmount = Quota::ResourceAmount;
void formatKeyType(const KeyType & key_type, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " KEYED BY " << (settings.hilite ? IAST::hilite_none : "") << "'"
<< Quota::getNameOfKeyType(key_type) << "'";
}
void formatRenameTo(const String & new_name, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " RENAME TO " << (settings.hilite ? IAST::hilite_none : "")
<< backQuote(new_name);
}
void formatLimit(ResourceType resource_type, ResourceAmount max, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " MAX " << Quota::resourceTypeToKeyword(resource_type)
<< (settings.hilite ? IAST::hilite_none : "");
settings.ostr << (settings.hilite ? IAST::hilite_operator : "") << " = " << (settings.hilite ? IAST::hilite_none : "");
if (max == Quota::UNLIMITED)
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "ANY" << (settings.hilite ? IAST::hilite_none : "");
else if (resource_type == Quota::EXECUTION_TIME)
settings.ostr << Quota::executionTimeToSeconds(max);
else
settings.ostr << max;
}
void formatLimits(const ASTCreateQuotaQuery::Limits & limits, const IAST::FormatSettings & settings)
{
auto interval_kind = IntervalKind::fromAvgSeconds(limits.duration.count());
Int64 num_intervals = limits.duration.count() / interval_kind.toAvgSeconds();
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "")
<< " FOR"
<< (limits.randomize_interval ? " RANDOMIZED" : "")
<< " INTERVAL "
<< (settings.hilite ? IAST::hilite_none : "")
<< num_intervals << " "
<< (settings.hilite ? IAST::hilite_keyword : "")
<< interval_kind.toKeyword()
<< (settings.hilite ? IAST::hilite_none : "");
if (limits.unset_tracking)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " UNSET TRACKING" << (settings.hilite ? IAST::hilite_none : "");
}
else
{
bool limit_found = false;
for (auto resource_type : ext::range_with_static_cast<ResourceType>(Quota::MAX_RESOURCE_TYPE))
{
if (limits.max[resource_type])
{
if (limit_found)
settings.ostr << ",";
limit_found = true;
formatLimit(resource_type, *limits.max[resource_type], settings);
}
}
if (!limit_found)
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TRACKING" << (settings.hilite ? IAST::hilite_none : "");
}
}
void formatAllLimits(const std::vector<ASTCreateQuotaQuery::Limits> & all_limits, const IAST::FormatSettings & settings)
{
bool need_comma = false;
for (auto & limits : all_limits)
{
if (need_comma)
settings.ostr << ",";
need_comma = true;
formatLimits(limits, settings);
}
}
void formatRoles(const ASTRoleList & roles, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : "");
roles.format(settings);
}
}
String ASTCreateQuotaQuery::getID(char) const
{
return "CreateQuotaQuery";
}
ASTPtr ASTCreateQuotaQuery::clone() const
{
return std::make_shared<ASTCreateQuotaQuery>(*this);
}
void ASTCreateQuotaQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << (alter ? "ALTER QUOTA" : "CREATE QUOTA")
<< (settings.hilite ? hilite_none : "");
if (if_exists)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF EXISTS" << (settings.hilite ? hilite_none : "");
else if (if_not_exists)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF NOT EXISTS" << (settings.hilite ? hilite_none : "");
else if (or_replace)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " OR REPLACE" << (settings.hilite ? hilite_none : "");
settings.ostr << " " << backQuoteIfNeed(name);
if (!new_name.empty())
formatRenameTo(new_name, settings);
if (key_type)
formatKeyType(*key_type, settings);
formatAllLimits(all_limits, settings);
if (roles)
formatRoles(*roles, settings);
}
}

View File

@ -0,0 +1,62 @@
#pragma once
#include <Parsers/IAST.h>
#include <Access/Quota.h>
namespace DB
{
class ASTRoleList;
/** CREATE QUOTA [IF NOT EXISTS | OR REPLACE] name
* [KEYED BY {'none' | 'user name' | 'ip address' | 'client key' | 'client key or user name' | 'client key or ip address'}]
* [FOR [RANDOMIZED] INTERVAL number {SECOND | MINUTE | HOUR | DAY}
* {[SET] MAX {{QUERIES | ERRORS | RESULT ROWS | RESULT BYTES | READ ROWS | READ BYTES | EXECUTION TIME} = {number | ANY} } [,...] |
* [SET] TRACKING} [,...]]
* [TO {role [,...] | ALL | ALL EXCEPT role [,...]}]
*
* ALTER QUOTA [IF EXISTS] name
* [RENAME TO new_name]
* [KEYED BY {'none' | 'user name' | 'ip address' | 'client key' | 'client key or user name' | 'client key or ip address'}]
* [FOR [RANDOMIZED] INTERVAL number {SECOND | MINUTE | HOUR | DAY}
* {[SET] MAX {{QUERIES | ERRORS | RESULT ROWS | RESULT BYTES | READ ROWS | READ BYTES | EXECUTION TIME} = {number | ANY} } [,...] |
* [SET] TRACKING |
* UNSET TRACKING} [,...]]
* [TO {role [,...] | ALL | ALL EXCEPT role [,...]}]
*/
class ASTCreateQuotaQuery : public IAST
{
public:
bool alter = false;
bool if_exists = false;
bool if_not_exists = false;
bool or_replace = false;
String name;
String new_name;
using KeyType = Quota::KeyType;
std::optional<KeyType> key_type;
using ResourceType = Quota::ResourceType;
using ResourceAmount = Quota::ResourceAmount;
static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE;
struct Limits
{
std::optional<ResourceAmount> max[MAX_RESOURCE_TYPE];
bool unset_tracking = false;
std::chrono::seconds duration = std::chrono::seconds::zero();
bool randomize_interval = false;
};
std::vector<Limits> all_limits;
std::shared_ptr<ASTRoleList> roles;
String getID(char) const override;
ASTPtr clone() const override;
void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;
};
}

View File

@ -0,0 +1,56 @@
#include <Parsers/ASTDropAccessEntityQuery.h>
#include <Common/quoteString.h>
namespace DB
{
namespace
{
using Kind = ASTDropAccessEntityQuery::Kind;
const char * kindToKeyword(Kind kind)
{
switch (kind)
{
case Kind::QUOTA: return "QUOTA";
}
__builtin_unreachable();
}
}
ASTDropAccessEntityQuery::ASTDropAccessEntityQuery(Kind kind_)
: kind(kind_), keyword(kindToKeyword(kind_))
{
}
String ASTDropAccessEntityQuery::getID(char) const
{
return String("DROP ") + keyword + " query";
}
ASTPtr ASTDropAccessEntityQuery::clone() const
{
return std::make_shared<ASTDropAccessEntityQuery>(*this);
}
void ASTDropAccessEntityQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
settings.ostr << (settings.hilite ? hilite_keyword : "")
<< "DROP " << keyword
<< (if_exists ? " IF EXISTS" : "")
<< (settings.hilite ? hilite_none : "");
bool need_comma = false;
for (const auto & name : names)
{
if (need_comma)
settings.ostr << ',';
need_comma = true;
settings.ostr << ' ' << backQuoteIfNeed(name);
}
}
}

View File

@ -0,0 +1,28 @@
#pragma once
#include <Parsers/IAST.h>
namespace DB
{
/** DROP QUOTA [IF EXISTS] name [,...]
*/
class ASTDropAccessEntityQuery : public IAST
{
public:
enum class Kind
{
QUOTA,
};
const Kind kind;
const char * const keyword;
bool if_exists = false;
Strings names;
ASTDropAccessEntityQuery(Kind kind_);
String getID(char) const override;
ASTPtr clone() const override;
void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;
};
}

Some files were not shown because too many files have changed in this diff Show More