mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-02 20:42:04 +00:00
Pass Settings to aggregate function creator
This commit is contained in:
parent
ad85467128
commit
ebc846b9f8
@ -35,9 +35,25 @@ const String & getAggregateFunctionCanonicalNameIfAny(const String & name)
|
||||
return AggregateFunctionFactory::instance().getCanonicalNameIfAny(name);
|
||||
}
|
||||
|
||||
|
||||
bool AggregateFunctionWithProperties::hasCreator() const
|
||||
{
|
||||
return std::visit([](auto func) { return func != nullptr; }, creator);
|
||||
}
|
||||
|
||||
AggregateFunctionPtr
|
||||
AggregateFunctionWithProperties::create(String name, const DataTypes & argument_types, const Array & params, const Settings & settings) const
|
||||
{
|
||||
if (std::holds_alternative<AggregateFunctionCreator>(creator))
|
||||
return std::get<AggregateFunctionCreator>(creator)(name, argument_types, params);
|
||||
if (std::holds_alternative<AggregateFunctionCreatorWithSettings>(creator))
|
||||
return std::get<AggregateFunctionCreatorWithSettings>(creator)(name, argument_types, params, settings);
|
||||
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Unhandled aggregate function creator type");
|
||||
}
|
||||
|
||||
void AggregateFunctionFactory::registerFunction(const String & name, Value creator_with_properties, CaseSensitiveness case_sensitiveness)
|
||||
{
|
||||
if (creator_with_properties.creator == nullptr)
|
||||
if (!creator_with_properties.hasCreator())
|
||||
throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided "
|
||||
" a null constructor", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
@ -125,7 +141,7 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
|
||||
if (CurrentThread::isInitialized())
|
||||
query_context = CurrentThread::get().getQueryContext();
|
||||
|
||||
if (found.creator)
|
||||
if (found.hasCreator())
|
||||
{
|
||||
out_properties = found.properties;
|
||||
|
||||
@ -137,7 +153,7 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
|
||||
if (!out_properties.returns_default_when_only_null && has_null_arguments)
|
||||
return nullptr;
|
||||
|
||||
return found.creator(name, argument_types, parameters);
|
||||
return found.create(name, argument_types, parameters, query_context->getSettingsRef());
|
||||
}
|
||||
|
||||
/// Combinators of aggregate functions.
|
||||
@ -197,7 +213,7 @@ std::optional<AggregateFunctionProperties> AggregateFunctionFactory::tryGetPrope
|
||||
if (auto jt = case_insensitive_aggregate_functions.find(Poco::toLower(name)); jt != case_insensitive_aggregate_functions.end())
|
||||
found = jt->second;
|
||||
|
||||
if (found.creator)
|
||||
if (found.hasCreator())
|
||||
return found.properties;
|
||||
|
||||
/// Combinators of aggregate functions.
|
||||
|
@ -10,12 +10,14 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class Context;
|
||||
struct Settings;
|
||||
class IDataType;
|
||||
|
||||
using DataTypePtr = std::shared_ptr<const IDataType>;
|
||||
@ -27,10 +29,12 @@ using DataTypes = std::vector<DataTypePtr>;
|
||||
* For example, in quantileWeighted(0.9)(x, weight), 0.9 is "parameter" and x, weight are "arguments".
|
||||
*/
|
||||
using AggregateFunctionCreator = std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &)>;
|
||||
using AggregateFunctionCreatorWithSettings
|
||||
= std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &, const Settings &)>;
|
||||
|
||||
struct AggregateFunctionWithProperties
|
||||
{
|
||||
AggregateFunctionCreator creator;
|
||||
std::variant<AggregateFunctionCreator, AggregateFunctionCreatorWithSettings> creator;
|
||||
AggregateFunctionProperties properties;
|
||||
|
||||
AggregateFunctionWithProperties() = default;
|
||||
@ -42,6 +46,9 @@ struct AggregateFunctionWithProperties
|
||||
: creator(std::forward<Creator>(creator_)), properties(std::move(properties_))
|
||||
{
|
||||
}
|
||||
|
||||
bool hasCreator() const;
|
||||
AggregateFunctionPtr create(String name, const DataTypes & argument_types, const Array & params, const Settings & settings) const;
|
||||
};
|
||||
|
||||
|
||||
|
@ -131,8 +131,14 @@ namespace
|
||||
void registerAggregateFunctionUniqCombined(AggregateFunctionFactory & factory)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
factory.registerFunction("uniqCombined", std::bind(createAggregateFunctionUniqCombined, false, _1, _2, _3)); // NOLINT
|
||||
factory.registerFunction("uniqCombined64", std::bind(createAggregateFunctionUniqCombined, true, _1, _2, _3)); // NOLINT
|
||||
factory.registerFunction("uniqCombined", [](const std::string & name, const DataTypes & argument_types, const Array & params)
|
||||
{
|
||||
return createAggregateFunctionUniqCombined(false, name, argument_types, params);
|
||||
});
|
||||
factory.registerFunction("uniqCombined64", [](const std::string & name, const DataTypes & argument_types, const Array & params)
|
||||
{
|
||||
return createAggregateFunctionUniqCombined(true, name, argument_types, params);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user