Pass Settings to aggregate function creator

This commit is contained in:
vdimir 2021-04-07 16:46:57 +03:00
parent ad85467128
commit ebc846b9f8
No known key found for this signature in database
GPG Key ID: F57B3E10A21DBB31
3 changed files with 36 additions and 7 deletions

View File

@ -35,9 +35,25 @@ const String & getAggregateFunctionCanonicalNameIfAny(const String & name)
return AggregateFunctionFactory::instance().getCanonicalNameIfAny(name);
}
bool AggregateFunctionWithProperties::hasCreator() const
{
return std::visit([](auto func) { return func != nullptr; }, creator);
}
AggregateFunctionPtr
AggregateFunctionWithProperties::create(String name, const DataTypes & argument_types, const Array & params, const Settings & settings) const
{
if (std::holds_alternative<AggregateFunctionCreator>(creator))
return std::get<AggregateFunctionCreator>(creator)(name, argument_types, params);
if (std::holds_alternative<AggregateFunctionCreatorWithSettings>(creator))
return std::get<AggregateFunctionCreatorWithSettings>(creator)(name, argument_types, params, settings);
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Unhandled aggregate function creator type");
}
void AggregateFunctionFactory::registerFunction(const String & name, Value creator_with_properties, CaseSensitiveness case_sensitiveness)
{
if (creator_with_properties.creator == nullptr)
if (!creator_with_properties.hasCreator())
throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided "
" a null constructor", ErrorCodes::LOGICAL_ERROR);
@ -125,7 +141,7 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
if (CurrentThread::isInitialized())
query_context = CurrentThread::get().getQueryContext();
if (found.creator)
if (found.hasCreator())
{
out_properties = found.properties;
@ -137,7 +153,7 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
if (!out_properties.returns_default_when_only_null && has_null_arguments)
return nullptr;
return found.creator(name, argument_types, parameters);
return found.create(name, argument_types, parameters, query_context->getSettingsRef());
}
/// Combinators of aggregate functions.
@ -197,7 +213,7 @@ std::optional<AggregateFunctionProperties> AggregateFunctionFactory::tryGetPrope
if (auto jt = case_insensitive_aggregate_functions.find(Poco::toLower(name)); jt != case_insensitive_aggregate_functions.end())
found = jt->second;
if (found.creator)
if (found.hasCreator())
return found.properties;
/// Combinators of aggregate functions.

View File

@ -10,12 +10,14 @@
#include <unordered_map>
#include <vector>
#include <optional>
#include <variant>
namespace DB
{
class Context;
struct Settings;
class IDataType;
using DataTypePtr = std::shared_ptr<const IDataType>;
@ -27,10 +29,12 @@ using DataTypes = std::vector<DataTypePtr>;
* For example, in quantileWeighted(0.9)(x, weight), 0.9 is "parameter" and x, weight are "arguments".
*/
using AggregateFunctionCreator = std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &)>;
using AggregateFunctionCreatorWithSettings
= std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &, const Settings &)>;
struct AggregateFunctionWithProperties
{
AggregateFunctionCreator creator;
std::variant<AggregateFunctionCreator, AggregateFunctionCreatorWithSettings> creator;
AggregateFunctionProperties properties;
AggregateFunctionWithProperties() = default;
@ -42,6 +46,9 @@ struct AggregateFunctionWithProperties
: creator(std::forward<Creator>(creator_)), properties(std::move(properties_))
{
}
bool hasCreator() const;
AggregateFunctionPtr create(String name, const DataTypes & argument_types, const Array & params, const Settings & settings) const;
};

View File

@ -131,8 +131,14 @@ namespace
void registerAggregateFunctionUniqCombined(AggregateFunctionFactory & factory)
{
using namespace std::placeholders;
factory.registerFunction("uniqCombined", std::bind(createAggregateFunctionUniqCombined, false, _1, _2, _3)); // NOLINT
factory.registerFunction("uniqCombined64", std::bind(createAggregateFunctionUniqCombined, true, _1, _2, _3)); // NOLINT
factory.registerFunction("uniqCombined", [](const std::string & name, const DataTypes & argument_types, const Array & params)
{
return createAggregateFunctionUniqCombined(false, name, argument_types, params);
});
factory.registerFunction("uniqCombined64", [](const std::string & name, const DataTypes & argument_types, const Array & params)
{
return createAggregateFunctionUniqCombined(true, name, argument_types, params);
});
}
}