Pass Settings to aggregate function creator

This commit is contained in:
vdimir 2021-04-07 16:46:57 +03:00
parent ad85467128
commit ebc846b9f8
No known key found for this signature in database
GPG Key ID: F57B3E10A21DBB31
3 changed files with 36 additions and 7 deletions

View File

@ -35,9 +35,25 @@ const String & getAggregateFunctionCanonicalNameIfAny(const String & name)
return AggregateFunctionFactory::instance().getCanonicalNameIfAny(name); return AggregateFunctionFactory::instance().getCanonicalNameIfAny(name);
} }
bool AggregateFunctionWithProperties::hasCreator() const
{
return std::visit([](auto func) { return func != nullptr; }, creator);
}
AggregateFunctionPtr
AggregateFunctionWithProperties::create(String name, const DataTypes & argument_types, const Array & params, const Settings & settings) const
{
if (std::holds_alternative<AggregateFunctionCreator>(creator))
return std::get<AggregateFunctionCreator>(creator)(name, argument_types, params);
if (std::holds_alternative<AggregateFunctionCreatorWithSettings>(creator))
return std::get<AggregateFunctionCreatorWithSettings>(creator)(name, argument_types, params, settings);
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Unhandled aggregate function creator type");
}
void AggregateFunctionFactory::registerFunction(const String & name, Value creator_with_properties, CaseSensitiveness case_sensitiveness) void AggregateFunctionFactory::registerFunction(const String & name, Value creator_with_properties, CaseSensitiveness case_sensitiveness)
{ {
if (creator_with_properties.creator == nullptr) if (!creator_with_properties.hasCreator())
throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided " throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided "
" a null constructor", ErrorCodes::LOGICAL_ERROR); " a null constructor", ErrorCodes::LOGICAL_ERROR);
@ -125,7 +141,7 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
if (CurrentThread::isInitialized()) if (CurrentThread::isInitialized())
query_context = CurrentThread::get().getQueryContext(); query_context = CurrentThread::get().getQueryContext();
if (found.creator) if (found.hasCreator())
{ {
out_properties = found.properties; out_properties = found.properties;
@ -137,7 +153,7 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
if (!out_properties.returns_default_when_only_null && has_null_arguments) if (!out_properties.returns_default_when_only_null && has_null_arguments)
return nullptr; return nullptr;
return found.creator(name, argument_types, parameters); return found.create(name, argument_types, parameters, query_context->getSettingsRef());
} }
/// Combinators of aggregate functions. /// Combinators of aggregate functions.
@ -197,7 +213,7 @@ std::optional<AggregateFunctionProperties> AggregateFunctionFactory::tryGetPrope
if (auto jt = case_insensitive_aggregate_functions.find(Poco::toLower(name)); jt != case_insensitive_aggregate_functions.end()) if (auto jt = case_insensitive_aggregate_functions.find(Poco::toLower(name)); jt != case_insensitive_aggregate_functions.end())
found = jt->second; found = jt->second;
if (found.creator) if (found.hasCreator())
return found.properties; return found.properties;
/// Combinators of aggregate functions. /// Combinators of aggregate functions.

View File

@ -10,12 +10,14 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <optional> #include <optional>
#include <variant>
namespace DB namespace DB
{ {
class Context; class Context;
struct Settings;
class IDataType; class IDataType;
using DataTypePtr = std::shared_ptr<const IDataType>; using DataTypePtr = std::shared_ptr<const IDataType>;
@ -27,10 +29,12 @@ using DataTypes = std::vector<DataTypePtr>;
* For example, in quantileWeighted(0.9)(x, weight), 0.9 is "parameter" and x, weight are "arguments". * For example, in quantileWeighted(0.9)(x, weight), 0.9 is "parameter" and x, weight are "arguments".
*/ */
using AggregateFunctionCreator = std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &)>; using AggregateFunctionCreator = std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &)>;
using AggregateFunctionCreatorWithSettings
= std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &, const Settings &)>;
struct AggregateFunctionWithProperties struct AggregateFunctionWithProperties
{ {
AggregateFunctionCreator creator; std::variant<AggregateFunctionCreator, AggregateFunctionCreatorWithSettings> creator;
AggregateFunctionProperties properties; AggregateFunctionProperties properties;
AggregateFunctionWithProperties() = default; AggregateFunctionWithProperties() = default;
@ -42,6 +46,9 @@ struct AggregateFunctionWithProperties
: creator(std::forward<Creator>(creator_)), properties(std::move(properties_)) : creator(std::forward<Creator>(creator_)), properties(std::move(properties_))
{ {
} }
bool hasCreator() const;
AggregateFunctionPtr create(String name, const DataTypes & argument_types, const Array & params, const Settings & settings) const;
}; };

View File

@ -131,8 +131,14 @@ namespace
void registerAggregateFunctionUniqCombined(AggregateFunctionFactory & factory) void registerAggregateFunctionUniqCombined(AggregateFunctionFactory & factory)
{ {
using namespace std::placeholders; using namespace std::placeholders;
factory.registerFunction("uniqCombined", std::bind(createAggregateFunctionUniqCombined, false, _1, _2, _3)); // NOLINT factory.registerFunction("uniqCombined", [](const std::string & name, const DataTypes & argument_types, const Array & params)
factory.registerFunction("uniqCombined64", std::bind(createAggregateFunctionUniqCombined, true, _1, _2, _3)); // NOLINT {
return createAggregateFunctionUniqCombined(false, name, argument_types, params);
});
factory.registerFunction("uniqCombined64", [](const std::string & name, const DataTypes & argument_types, const Array & params)
{
return createAggregateFunctionUniqCombined(true, name, argument_types, params);
});
} }
} }