Better way of implementation

This commit is contained in:
Alexey Milovidov 2020-06-14 10:44:02 +03:00
parent d990b98b90
commit 394fb64a9c
43 changed files with 196 additions and 136 deletions

View File

@ -14,6 +14,7 @@ set (CLICKHOUSE_ODBC_BRIDGE_SOURCES
set (CLICKHOUSE_ODBC_BRIDGE_LINK set (CLICKHOUSE_ODBC_BRIDGE_LINK
PRIVATE PRIVATE
clickhouse_parsers clickhouse_parsers
clickhouse_aggregate_functions
daemon daemon
dbms dbms
Poco::Data Poco::Data

View File

@ -36,7 +36,10 @@ public:
} }
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array &) const override
{ {
return std::make_shared<AggregateFunctionArray>(nested_function, arguments); return std::make_shared<AggregateFunctionArray>(nested_function, arguments);
} }

View File

@ -7,6 +7,12 @@
namespace DB namespace DB
{ {
/// Null adapter supplied by `count` itself: builds an AggregateFunctionCountNotNullUnary
/// from the first argument type and the given parameters.
/// The pointer to the nested function is not used.
AggregateFunctionPtr AggregateFunctionCount::getOwnNullAdapter(
    const AggregateFunctionPtr & /*nested_function*/,
    const DataTypes & types,
    const Array & params) const
{
    const auto & first_argument_type = types[0];
    return std::make_shared<AggregateFunctionCountNotNullUnary>(first_argument_type, params);
}
namespace namespace
{ {
@ -22,7 +28,7 @@ AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, cons
void registerAggregateFunctionCount(AggregateFunctionFactory & factory) void registerAggregateFunctionCount(AggregateFunctionFactory & factory)
{ {
factory.registerFunction("count", createAggregateFunctionCount, AggregateFunctionFactory::CaseInsensitive); factory.registerFunction("count", {createAggregateFunctionCount, {true}}, AggregateFunctionFactory::CaseInsensitive);
} }
} }

View File

@ -68,16 +68,14 @@ public:
data(place).count = new_count; data(place).count = new_count;
} }
/// The function returns non-Nullable type even when wrapped with Null combinator. AggregateFunctionPtr getOwnNullAdapter(
bool returnDefaultWhenOnlyNull() const override const AggregateFunctionPtr &, const DataTypes & types, const Array & params) const override;
{
return true;
}
}; };
/// Simply count number of not-NULL values. /// Simply count number of not-NULL values.
class AggregateFunctionCountNotNullUnary final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary> class AggregateFunctionCountNotNullUnary final
: public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>
{ {
public: public:
AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params) AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params)

View File

@ -29,18 +29,18 @@ namespace ErrorCodes
} }
void AggregateFunctionFactory::registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness) void AggregateFunctionFactory::registerFunction(const String & name, Value creator_with_properties, CaseSensitiveness case_sensitiveness)
{ {
if (creator == nullptr) if (creator_with_properties.creator == nullptr)
throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided " throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided "
" a null constructor", ErrorCodes::LOGICAL_ERROR); " a null constructor", ErrorCodes::LOGICAL_ERROR);
if (!aggregate_functions.emplace(name, creator).second) if (!aggregate_functions.emplace(name, creator_with_properties).second)
throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique", throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique",
ErrorCodes::LOGICAL_ERROR); ErrorCodes::LOGICAL_ERROR);
if (case_sensitiveness == CaseInsensitive if (case_sensitiveness == CaseInsensitive
&& !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator).second) && !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator_with_properties).second)
throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique", throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique",
ErrorCodes::LOGICAL_ERROR); ErrorCodes::LOGICAL_ERROR);
} }
@ -59,6 +59,7 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
const String & name, const String & name,
const DataTypes & argument_types, const DataTypes & argument_types,
const Array & parameters, const Array & parameters,
AggregateFunctionProperties & out_properties,
int recursion_level) const int recursion_level) const
{ {
auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types); auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types);
@ -76,18 +77,11 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality); DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality);
Array nested_parameters = combinator->transformParameters(parameters); Array nested_parameters = combinator->transformParameters(parameters);
AggregateFunctionPtr nested_function; AggregateFunctionPtr nested_function = getImpl(name, nested_types, nested_parameters, out_properties, recursion_level);
return combinator->transformAggregateFunction(nested_function, out_properties, type_without_low_cardinality, parameters);
/// A little hack - if we have NULL arguments, don't even create nested function.
/// Combinator will check if nested_function was created.
if (name == "count" || std::none_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(),
[](const auto & type) { return type->onlyNull(); }))
nested_function = getImpl(name, nested_types, nested_parameters, recursion_level);
return combinator->transformAggregateFunction(nested_function, type_without_low_cardinality, parameters);
} }
auto res = getImpl(name, type_without_low_cardinality, parameters, recursion_level); auto res = getImpl(name, type_without_low_cardinality, parameters, out_properties, recursion_level);
if (!res) if (!res)
throw Exception("Logical error: AggregateFunctionFactory returned nullptr", ErrorCodes::LOGICAL_ERROR); throw Exception("Logical error: AggregateFunctionFactory returned nullptr", ErrorCodes::LOGICAL_ERROR);
return res; return res;
@ -98,19 +92,37 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
const String & name_param, const String & name_param,
const DataTypes & argument_types, const DataTypes & argument_types,
const Array & parameters, const Array & parameters,
AggregateFunctionProperties & out_properties,
int recursion_level) const int recursion_level) const
{ {
String name = getAliasToOrName(name_param); String name = getAliasToOrName(name_param);
Value found;
/// Find by exact match. /// Find by exact match.
if (auto it = aggregate_functions.find(name); it != aggregate_functions.end()) if (auto it = aggregate_functions.find(name); it != aggregate_functions.end())
return it->second(name, argument_types, parameters); {
found = it->second;
}
/// Find by case-insensitive name. /// Find by case-insensitive name.
/// Combinators cannot apply for case insensitive (SQL-style) aggregate function names. Only for native names. /// Combinators cannot apply for case insensitive (SQL-style) aggregate function names. Only for native names.
if (recursion_level == 0) else if (recursion_level == 0)
{ {
if (auto it = case_insensitive_aggregate_functions.find(Poco::toLower(name)); it != case_insensitive_aggregate_functions.end()) if (auto jt = case_insensitive_aggregate_functions.find(Poco::toLower(name)); jt != case_insensitive_aggregate_functions.end())
return it->second(name, argument_types, parameters); found = jt->second;
}
if (found.creator)
{
out_properties = found.properties;
/// The case when aggregate function should return NULL on NULL arguments. This case is handled in "get" method.
if (!out_properties.returns_default_when_only_null
&& std::any_of(argument_types.begin(), argument_types.end(), [](const auto & type) { return type->onlyNull(); }))
{
return nullptr;
}
return found.creator(name, argument_types, parameters);
} }
/// Combinators of aggregate functions. /// Combinators of aggregate functions.
@ -126,9 +138,8 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
DataTypes nested_types = combinator->transformArguments(argument_types); DataTypes nested_types = combinator->transformArguments(argument_types);
Array nested_parameters = combinator->transformParameters(parameters); Array nested_parameters = combinator->transformParameters(parameters);
AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, recursion_level + 1); AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, out_properties, recursion_level + 1);
return combinator->transformAggregateFunction(nested_function, out_properties, argument_types, parameters);
return combinator->transformAggregateFunction(nested_function, argument_types, parameters);
} }
auto hints = this->getHints(name); auto hints = this->getHints(name);
@ -140,10 +151,11 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
} }
AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types, const Array & parameters) const AggregateFunctionPtr AggregateFunctionFactory::tryGet(
const String & name, const DataTypes & argument_types, const Array & parameters, AggregateFunctionProperties & out_properties) const
{ {
return isAggregateFunctionName(name) return isAggregateFunctionName(name)
? get(name, argument_types, parameters) ? get(name, argument_types, parameters, out_properties)
: nullptr; : nullptr;
} }

View File

@ -26,34 +26,50 @@ using DataTypes = std::vector<DataTypePtr>;
*/ */
using AggregateFunctionCreator = std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &)>; using AggregateFunctionCreator = std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &)>;
/// Bundles the creator of an aggregate function with the properties registered alongside it.
struct AggregateFunctionWithProperties
{
    AggregateFunctionCreator creator;
    AggregateFunctionProperties properties;

    /// All special members are defaulted explicitly: declaring only the copy constructor
    /// would suppress the implicit move operations and make every move a copy.
    AggregateFunctionWithProperties() = default;
    AggregateFunctionWithProperties(const AggregateFunctionWithProperties &) = default;
    AggregateFunctionWithProperties(AggregateFunctionWithProperties &&) = default;
    AggregateFunctionWithProperties & operator=(const AggregateFunctionWithProperties &) = default;
    AggregateFunctionWithProperties & operator=(AggregateFunctionWithProperties &&) = default;

    /// Implicit converting constructor from anything that can initialize AggregateFunctionCreator
    /// (plus optional properties), so that both `creator` and `{creator, {properties}}` are accepted.
    /// The constraint keeps this template from being selected for the struct's own type.
    /// `creator_` is taken by value, so it is moved (not forwarded) into the member.
    template <typename Creator, std::enable_if_t<!std::is_same_v<Creator, AggregateFunctionWithProperties>> * = nullptr>
    AggregateFunctionWithProperties(Creator creator_, AggregateFunctionProperties properties_ = {})
        : creator(std::move(creator_)), properties(std::move(properties_))
    {
    }
};
/** Creates an aggregate function by name. /** Creates an aggregate function by name.
*/ */
class AggregateFunctionFactory final : private boost::noncopyable, public IFactoryWithAliases<AggregateFunctionCreator> class AggregateFunctionFactory final : private boost::noncopyable, public IFactoryWithAliases<AggregateFunctionWithProperties>
{ {
public: public:
static AggregateFunctionFactory & instance(); static AggregateFunctionFactory & instance();
/// Register a function by its name. /// Register a function by its name.
/// No locking, you must register all functions before usage of get. /// No locking, you must register all functions before usage of get.
void registerFunction( void registerFunction(
const String & name, const String & name,
Creator creator, Value creator,
CaseSensitiveness case_sensitiveness = CaseSensitive); CaseSensitiveness case_sensitiveness = CaseSensitive);
/// Throws an exception if not found. /// Throws an exception if not found.
AggregateFunctionPtr get( AggregateFunctionPtr get(
const String & name, const String & name,
const DataTypes & argument_types, const DataTypes & argument_types,
const Array & parameters = {}, const Array & parameters,
AggregateFunctionProperties & out_properties,
int recursion_level = 0) const; int recursion_level = 0) const;
/// Returns nullptr if not found. /// Returns nullptr if not found.
AggregateFunctionPtr tryGet( AggregateFunctionPtr tryGet(
const String & name, const String & name,
const DataTypes & argument_types, const DataTypes & argument_types,
const Array & parameters = {}) const; const Array & parameters,
AggregateFunctionProperties & out_properties) const;
bool isAggregateFunctionName(const String & name, int recursion_level = 0) const; bool isAggregateFunctionName(const String & name, int recursion_level = 0) const;
@ -62,19 +78,20 @@ private:
const String & name, const String & name,
const DataTypes & argument_types, const DataTypes & argument_types,
const Array & parameters, const Array & parameters,
AggregateFunctionProperties & out_properties,
int recursion_level) const; int recursion_level) const;
private: private:
using AggregateFunctions = std::unordered_map<String, Creator>; using AggregateFunctions = std::unordered_map<String, Value>;
AggregateFunctions aggregate_functions; AggregateFunctions aggregate_functions;
/// Case insensitive aggregate functions will be additionally added here with lowercased name. /// Case insensitive aggregate functions will be additionally added here with lowercased name.
AggregateFunctions case_insensitive_aggregate_functions; AggregateFunctions case_insensitive_aggregate_functions;
const AggregateFunctions & getCreatorMap() const override { return aggregate_functions; } const AggregateFunctions & getMap() const override { return aggregate_functions; }
const AggregateFunctions & getCaseInsensitiveCreatorMap() const override { return case_insensitive_aggregate_functions; } const AggregateFunctions & getCaseInsensitiveMap() const override { return case_insensitive_aggregate_functions; }
String getFactoryName() const override { return "AggregateFunctionFactory"; } String getFactoryName() const override { return "AggregateFunctionFactory"; }

View File

@ -33,7 +33,10 @@ public:
} }
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array &) const override
{ {
return std::make_shared<AggregateFunctionForEach>(nested_function, arguments); return std::make_shared<AggregateFunctionForEach>(nested_function, arguments);
} }

View File

@ -31,7 +31,10 @@ public:
} }
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array &) const override
{ {
return std::make_shared<AggregateFunctionIf>(nested_function, arguments); return std::make_shared<AggregateFunctionIf>(nested_function, arguments);
} }

View File

@ -34,7 +34,10 @@ public:
} }
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array &) const override
{ {
const DataTypePtr & argument = arguments[0]; const DataTypePtr & argument = arguments[0];

View File

@ -25,7 +25,7 @@ public:
DataTypePtr getReturnType() const override DataTypePtr getReturnType() const override
{ {
return std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>()); return argument_types.front();
} }
void create(AggregateDataPtr) const override void create(AggregateDataPtr) const override

View File

@ -31,13 +31,11 @@ public:
} }
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties & properties,
const DataTypes & arguments,
const Array & params) const override
{ {
/// Special case for 'count' function. It could be called with Nullable arguments
/// - that means - count number of calls, when all arguments are not NULL.
if (nested_function && nested_function->getName() == "count")
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0], params);
bool has_nullable_types = false; bool has_nullable_types = false;
bool has_null_types = false; bool has_null_types = false;
for (const auto & arg_type : arguments) for (const auto & arg_type : arguments)
@ -58,15 +56,23 @@ public:
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (has_null_types) if (has_null_types)
{
std::cerr << properties.returns_default_when_only_null << "\n";
/// Currently the only functions that return not-NULL on all-NULL arguments are count and uniq, and they return UInt64.
if (properties.returns_default_when_only_null)
return std::make_shared<AggregateFunctionNothing>(DataTypes{std::make_shared<DataTypeUInt64>()}, params);
else
return std::make_shared<AggregateFunctionNothing>(arguments, params); return std::make_shared<AggregateFunctionNothing>(arguments, params);
}
assert(nested_function); assert(nested_function);
if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params)) if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params))
return adapter; return adapter;
bool return_type_is_nullable = !nested_function->returnDefaultWhenOnlyNull() && nested_function->getReturnType()->canBeInsideNullable(); bool return_type_is_nullable = !properties.returns_default_when_only_null && nested_function->getReturnType()->canBeInsideNullable();
bool serialize_flag = return_type_is_nullable || nested_function->returnDefaultWhenOnlyNull(); bool serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null;
if (arguments.size() == 1) if (arguments.size() == 1)
{ {

View File

@ -21,6 +21,7 @@ public:
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments, const DataTypes & arguments,
const Array & params) const override const Array & params) const override
{ {

View File

@ -43,6 +43,7 @@ public:
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments, const DataTypes & arguments,
const Array & params) const override const Array & params) const override
{ {

View File

@ -24,7 +24,10 @@ public:
} }
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array & params) const override
{ {
return std::make_shared<AggregateFunctionState>(nested_function, arguments, params); return std::make_shared<AggregateFunctionState>(nested_function, arguments, params);
} }

View File

@ -123,13 +123,13 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory) void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory)
{ {
factory.registerFunction("uniq", factory.registerFunction("uniq",
createAggregateFunctionUniq<AggregateFunctionUniqUniquesHashSetData, AggregateFunctionUniqUniquesHashSetDataForVariadic>); {createAggregateFunctionUniq<AggregateFunctionUniqUniquesHashSetData, AggregateFunctionUniqUniquesHashSetDataForVariadic>, {true}});
factory.registerFunction("uniqHLL12", factory.registerFunction("uniqHLL12",
createAggregateFunctionUniq<false, AggregateFunctionUniqHLL12Data, AggregateFunctionUniqHLL12DataForVariadic>); {createAggregateFunctionUniq<false, AggregateFunctionUniqHLL12Data, AggregateFunctionUniqHLL12DataForVariadic>, {true}});
factory.registerFunction("uniqExact", factory.registerFunction("uniqExact",
createAggregateFunctionUniq<true, AggregateFunctionUniqExactData, AggregateFunctionUniqExactData<String>>); {createAggregateFunctionUniq<true, AggregateFunctionUniqExactData, AggregateFunctionUniqExactData<String>>, {true}});
} }
} }

View File

@ -244,12 +244,6 @@ public:
{ {
assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size()); assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
} }
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
}; };
@ -304,12 +298,6 @@ public:
{ {
assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size()); assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
} }
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
}; };
} }

View File

@ -85,7 +85,7 @@ AggregateFunctionPtr createAggregateFunctionUniqUpTo(const std::string & name, c
void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory & factory) void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory & factory)
{ {
factory.registerFunction("uniqUpTo", createAggregateFunctionUniqUpTo); factory.registerFunction("uniqUpTo", {createAggregateFunctionUniqUpTo, {true}});
} }
} }

View File

@ -166,17 +166,12 @@ public:
* nested_function is a smart pointer to this aggregate function itself. * nested_function is a smart pointer to this aggregate function itself.
* arguments and params are for nested_function. * arguments and params are for nested_function.
*/ */
virtual AggregateFunctionPtr getOwnNullAdapter(const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/, const Array & /*params*/) const virtual AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/, const Array & /*params*/) const
{ {
return nullptr; return nullptr;
} }
/** When the function is wrapped with Null combinator,
* should we return Nullable type with NULL when no values were aggregated
* or we should return non-Nullable type with default value (example: count, countDistinct).
*/
virtual bool returnDefaultWhenOnlyNull() const { return false; }
const DataTypes & getArgumentTypes() const { return argument_types; } const DataTypes & getArgumentTypes() const { return argument_types; }
const Array & getParameters() const { return parameters; } const Array & getParameters() const { return parameters; }
@ -286,4 +281,15 @@ public:
}; };
/// Per-function properties of an aggregate function; they depend neither on argument types nor on parameters.
struct AggregateFunctionProperties
{
    /// Applies when the function is wrapped with the Null combinator and no values were aggregated:
    ///  - false: the result is a Nullable type holding NULL;
    ///  - true: the result is a non-Nullable type holding the default value (example: count, countDistinct).
    bool returns_default_when_only_null{false};
};
} }

View File

@ -59,6 +59,7 @@ public:
*/ */
virtual AggregateFunctionPtr transformAggregateFunction( virtual AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties & properties,
const DataTypes & arguments, const DataTypes & arguments,
const Array & params) const = 0; const Array & params) const = 0;

View File

@ -381,6 +381,6 @@ if (ENABLE_TESTS AND USE_GTEST)
-Wno-gnu-zero-variadic-macro-arguments -Wno-gnu-zero-variadic-macro-arguments
) )
target_link_libraries(unit_tests_dbms PRIVATE ${GTEST_BOTH_LIBRARIES} clickhouse_functions clickhouse_parsers dbms clickhouse_common_zookeeper string_utils) target_link_libraries(unit_tests_dbms PRIVATE ${GTEST_BOTH_LIBRARIES} clickhouse_functions clickhouse_aggregate_functions clickhouse_parsers dbms clickhouse_common_zookeeper string_utils)
add_check(unit_tests_dbms) add_check(unit_tests_dbms)
endif () endif ()

View File

@ -17,13 +17,13 @@ namespace ErrorCodes
/** If stored objects may have several names (aliases) /** If stored objects may have several names (aliases)
* this interface may be helpful * this interface may be helpful
* template parameter is available as Creator * template parameter is available as Value
*/ */
template <typename CreatorFunc> template <typename ValueType>
class IFactoryWithAliases : public IHints<2, IFactoryWithAliases<CreatorFunc>> class IFactoryWithAliases : public IHints<2, IFactoryWithAliases<ValueType>>
{ {
protected: protected:
using Creator = CreatorFunc; using Value = ValueType;
String getAliasToOrName(const String & name) const String getAliasToOrName(const String & name) const
{ {
@ -43,13 +43,13 @@ public:
CaseInsensitive CaseInsensitive
}; };
/** Register additional name for creator /** Register additional name for value
* real_name have to be already registered. * real_name have to be already registered.
*/ */
void registerAlias(const String & alias_name, const String & real_name, CaseSensitiveness case_sensitiveness = CaseSensitive) void registerAlias(const String & alias_name, const String & real_name, CaseSensitiveness case_sensitiveness = CaseSensitive)
{ {
const auto & creator_map = getCreatorMap(); const auto & creator_map = getMap();
const auto & case_insensitive_creator_map = getCaseInsensitiveCreatorMap(); const auto & case_insensitive_creator_map = getCaseInsensitiveMap();
const String factory_name = getFactoryName(); const String factory_name = getFactoryName();
String real_dict_name; String real_dict_name;
@ -80,7 +80,7 @@ public:
{ {
std::vector<String> result; std::vector<String> result;
auto getter = [](const auto & pair) { return pair.first; }; auto getter = [](const auto & pair) { return pair.first; };
std::transform(getCreatorMap().begin(), getCreatorMap().end(), std::back_inserter(result), getter); std::transform(getMap().begin(), getMap().end(), std::back_inserter(result), getter);
std::transform(aliases.begin(), aliases.end(), std::back_inserter(result), getter); std::transform(aliases.begin(), aliases.end(), std::back_inserter(result), getter);
return result; return result;
} }
@ -88,7 +88,7 @@ public:
bool isCaseInsensitive(const String & name) const bool isCaseInsensitive(const String & name) const
{ {
String name_lowercase = Poco::toLower(name); String name_lowercase = Poco::toLower(name);
return getCaseInsensitiveCreatorMap().count(name_lowercase) || case_insensitive_aliases.count(name_lowercase); return getCaseInsensitiveMap().count(name_lowercase) || case_insensitive_aliases.count(name_lowercase);
} }
const String & aliasTo(const String & name) const const String & aliasTo(const String & name) const
@ -109,11 +109,11 @@ public:
virtual ~IFactoryWithAliases() override {} virtual ~IFactoryWithAliases() override {}
private: private:
using InnerMap = std::unordered_map<String, Creator>; // name -> creator using InnerMap = std::unordered_map<String, Value>; // name -> creator
using AliasMap = std::unordered_map<String, String>; // alias -> original type using AliasMap = std::unordered_map<String, String>; // alias -> original type
virtual const InnerMap & getCreatorMap() const = 0; virtual const InnerMap & getMap() const = 0;
virtual const InnerMap & getCaseInsensitiveCreatorMap() const = 0; virtual const InnerMap & getCaseInsensitiveMap() const = 0;
virtual String getFactoryName() const = 0; virtual String getFactoryName() const = 0;
/// Alias map to data_types from previous two maps /// Alias map to data_types from previous two maps

View File

@ -1,4 +1,4 @@
set(SRCS) set(SRCS)
add_executable (finish_sorting_stream finish_sorting_stream.cpp ${SRCS}) add_executable (finish_sorting_stream finish_sorting_stream.cpp ${SRCS})
target_link_libraries (finish_sorting_stream PRIVATE dbms) target_link_libraries (finish_sorting_stream PRIVATE clickhouse_aggregate_functions dbms)

View File

@ -392,7 +392,8 @@ static DataTypePtr create(const ASTPtr & arguments)
if (function_name.empty()) if (function_name.empty())
throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR); throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR);
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row); AggregateFunctionProperties properties;
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row, properties);
return std::make_shared<DataTypeAggregateFunction>(function, argument_types, params_row); return std::make_shared<DataTypeAggregateFunction>(function, argument_types, params_row);
} }

View File

@ -110,7 +110,8 @@ static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const ASTPtr & argum
if (function_name.empty()) if (function_name.empty())
throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR); throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR);
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row); AggregateFunctionProperties properties;
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row, properties);
// check function // check function
if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions)) if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions))

View File

@ -80,7 +80,7 @@ DataTypePtr DataTypeFactory::get(const String & family_name_param, const ASTPtr
} }
void DataTypeFactory::registerDataType(const String & family_name, Creator creator, CaseSensitiveness case_sensitiveness) void DataTypeFactory::registerDataType(const String & family_name, Value creator, CaseSensitiveness case_sensitiveness)
{ {
if (creator == nullptr) if (creator == nullptr)
throw Exception("DataTypeFactory: the data type family " + family_name + " has been provided " throw Exception("DataTypeFactory: the data type family " + family_name + " has been provided "
@ -136,7 +136,7 @@ void DataTypeFactory::registerSimpleDataTypeCustom(const String &name, SimpleCre
}, case_sensitiveness); }, case_sensitiveness);
} }
const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String & family_name) const const DataTypeFactory::Value & DataTypeFactory::findCreatorByName(const String & family_name) const
{ {
{ {
DataTypesDictionary::const_iterator it = data_types.find(family_name); DataTypesDictionary::const_iterator it = data_types.find(family_name);

View File

@ -23,7 +23,7 @@ class DataTypeFactory final : private boost::noncopyable, public IFactoryWithAli
{ {
private: private:
using SimpleCreator = std::function<DataTypePtr()>; using SimpleCreator = std::function<DataTypePtr()>;
using DataTypesDictionary = std::unordered_map<String, Creator>; using DataTypesDictionary = std::unordered_map<String, Value>;
using CreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>(const ASTPtr & parameters)>; using CreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>(const ASTPtr & parameters)>;
using SimpleCreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>()>; using SimpleCreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>()>;
@ -35,7 +35,7 @@ public:
DataTypePtr get(const ASTPtr & ast) const; DataTypePtr get(const ASTPtr & ast) const;
/// Register a type family by its name. /// Register a type family by its name.
void registerDataType(const String & family_name, Creator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); void registerDataType(const String & family_name, Value creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
/// Register a simple data type, that have no parameters. /// Register a simple data type, that have no parameters.
void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
@ -47,7 +47,7 @@ public:
void registerSimpleDataTypeCustom(const String & name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness = CaseSensitive); void registerSimpleDataTypeCustom(const String & name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
private: private:
const Creator& findCreatorByName(const String & family_name) const; const Value & findCreatorByName(const String & family_name) const;
private: private:
DataTypesDictionary data_types; DataTypesDictionary data_types;
@ -57,9 +57,9 @@ private:
DataTypeFactory(); DataTypeFactory();
const DataTypesDictionary & getCreatorMap() const override { return data_types; } const DataTypesDictionary & getMap() const override { return data_types; }
const DataTypesDictionary & getCaseInsensitiveCreatorMap() const override { return case_insensitive_data_types; } const DataTypesDictionary & getCaseInsensitiveMap() const override { return case_insensitive_data_types; }
String getFactoryName() const override { return "DataTypeFactory"; } String getFactoryName() const override { return "DataTypeFactory"; }
}; };

View File

@ -1,4 +1,4 @@
set(SRCS ) set(SRCS )
add_executable (tab_separated_streams tab_separated_streams.cpp ${SRCS}) add_executable (tab_separated_streams tab_separated_streams.cpp ${SRCS})
target_link_libraries (tab_separated_streams PRIVATE dbms) target_link_libraries (tab_separated_streams PRIVATE clickhouse_aggregate_functions dbms)

View File

@ -20,7 +20,7 @@ namespace ErrorCodes
void FunctionFactory::registerFunction(const void FunctionFactory::registerFunction(const
std::string & name, std::string & name,
Creator creator, Value creator,
CaseSensitiveness case_sensitiveness) CaseSensitiveness case_sensitiveness)
{ {
if (!functions.emplace(name, creator).second) if (!functions.emplace(name, creator).second)

View File

@ -53,7 +53,7 @@ public:
FunctionOverloadResolverImplPtr tryGetImpl(const std::string & name, const Context & context) const; FunctionOverloadResolverImplPtr tryGetImpl(const std::string & name, const Context & context) const;
private: private:
using Functions = std::unordered_map<std::string, Creator>; using Functions = std::unordered_map<std::string, Value>;
Functions functions; Functions functions;
Functions case_insensitive_functions; Functions case_insensitive_functions;
@ -64,9 +64,9 @@ private:
return std::make_unique<DefaultOverloadResolver>(Function::create(context)); return std::make_unique<DefaultOverloadResolver>(Function::create(context));
} }
const Functions & getCreatorMap() const override { return functions; } const Functions & getMap() const override { return functions; }
const Functions & getCaseInsensitiveCreatorMap() const override { return case_insensitive_functions; } const Functions & getCaseInsensitiveMap() const override { return case_insensitive_functions; }
String getFactoryName() const override { return "FunctionFactory"; } String getFactoryName() const override { return "FunctionFactory"; }
@ -74,7 +74,7 @@ private:
/// No locking, you must register all functions before usage of get. /// No locking, you must register all functions before usage of get.
void registerFunction( void registerFunction(
const std::string & name, const std::string & name,
Creator creator, Value creator,
CaseSensitiveness case_sensitiveness = CaseSensitive); CaseSensitiveness case_sensitiveness = CaseSensitive);
}; };

View File

@ -113,8 +113,9 @@ public:
auto nested_type = array_type->getNestedType(); auto nested_type = array_type->getNestedType();
DataTypes argument_types = {nested_type}; DataTypes argument_types = {nested_type};
Array params_row; Array params_row;
AggregateFunctionPtr bitmap_function AggregateFunctionProperties properties;
= AggregateFunctionFactory::instance().get(AggregateFunctionGroupBitmapData<UInt32>::name(), argument_types, params_row); AggregateFunctionPtr bitmap_function = AggregateFunctionFactory::instance().get(
AggregateFunctionGroupBitmapData<UInt32>::name(), argument_types, params_row, properties);
return std::make_shared<DataTypeAggregateFunction>(bitmap_function, argument_types, params_row); return std::make_shared<DataTypeAggregateFunction>(bitmap_function, argument_types, params_row);
} }
@ -156,8 +157,9 @@ private:
// output data // output data
Array params_row; Array params_row;
AggregateFunctionPtr bitmap_function AggregateFunctionProperties properties;
= AggregateFunctionFactory::instance().get(AggregateFunctionGroupBitmapData<UInt32>::name(), argument_types, params_row); AggregateFunctionPtr bitmap_function = AggregateFunctionFactory::instance().get(
AggregateFunctionGroupBitmapData<UInt32>::name(), argument_types, params_row, properties);
auto col_to = ColumnAggregateFunction::create(bitmap_function); auto col_to = ColumnAggregateFunction::create(bitmap_function);
col_to->reserve(offsets.size()); col_to->reserve(offsets.size());

View File

@ -97,7 +97,8 @@ DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName
getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params, getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
aggregate_function_name, params_row, "function " + getName()); aggregate_function_name, params_row, "function " + getName());
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row); AggregateFunctionProperties properties;
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
} }
return aggregate_function->getReturnType(); return aggregate_function->getReturnType();

View File

@ -115,7 +115,8 @@ DataTypePtr FunctionArrayReduceInRanges::getReturnTypeImpl(const ColumnsWithType
getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params, getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
aggregate_function_name, params_row, "function " + getName()); aggregate_function_name, params_row, "function " + getName());
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row); AggregateFunctionProperties properties;
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
} }
return std::make_shared<DataTypeArray>(aggregate_function->getReturnType()); return std::make_shared<DataTypeArray>(aggregate_function->getReturnType());

View File

@ -420,8 +420,9 @@ bool ExpressionAnalyzer::makeAggregateDescriptions(ExpressionActionsPtr & action
aggregate.argument_names[i] = name; aggregate.argument_names[i] = name;
} }
AggregateFunctionProperties properties;
aggregate.parameters = (node->parameters) ? getAggregateFunctionParametersArray(node->parameters) : Array(); aggregate.parameters = (node->parameters) ? getAggregateFunctionParametersArray(node->parameters) : Array();
aggregate.function = AggregateFunctionFactory::instance().get(node->name, types, aggregate.parameters); aggregate.function = AggregateFunctionFactory::instance().get(node->name, types, aggregate.parameters, properties);
aggregate_descriptions.push_back(aggregate); aggregate_descriptions.push_back(aggregate);
} }

View File

@ -34,11 +34,11 @@ target_include_directories (two_level_hash_map SYSTEM BEFORE PRIVATE ${SPARSEHAS
target_link_libraries (two_level_hash_map PRIVATE dbms) target_link_libraries (two_level_hash_map PRIVATE dbms)
add_executable (in_join_subqueries_preprocessor in_join_subqueries_preprocessor.cpp) add_executable (in_join_subqueries_preprocessor in_join_subqueries_preprocessor.cpp)
target_link_libraries (in_join_subqueries_preprocessor PRIVATE dbms clickhouse_parsers) target_link_libraries (in_join_subqueries_preprocessor PRIVATE clickhouse_aggregate_functions dbms clickhouse_parsers)
add_check(in_join_subqueries_preprocessor) add_check(in_join_subqueries_preprocessor)
add_executable (users users.cpp) add_executable (users users.cpp)
target_link_libraries (users PRIVATE dbms clickhouse_common_config) target_link_libraries (users PRIVATE clickhouse_aggregate_functions dbms clickhouse_common_config)
if (OS_LINUX) if (OS_LINUX)
add_executable (internal_iotop internal_iotop.cpp) add_executable (internal_iotop internal_iotop.cpp)

View File

@ -103,9 +103,10 @@ int main(int argc, char ** argv)
std::vector<Key> data(n); std::vector<Key> data(n);
Value value; Value value;
AggregateFunctionPtr func_count = factory.get("count", data_types_empty); AggregateFunctionProperties properties;
AggregateFunctionPtr func_avg = factory.get("avg", data_types_uint64); AggregateFunctionPtr func_count = factory.get("count", data_types_empty, {}, properties);
AggregateFunctionPtr func_uniq = factory.get("uniq", data_types_uint64); AggregateFunctionPtr func_avg = factory.get("avg", data_types_uint64, {}, properties);
AggregateFunctionPtr func_uniq = factory.get("uniq", data_types_uint64, {}, properties);
#define INIT \ #define INIT \
{ \ { \

View File

@ -47,7 +47,8 @@ struct SummingSortedAlgorithm::AggregateDescription
void init(const char * function_name, const DataTypes & argument_types) void init(const char * function_name, const DataTypes & argument_types)
{ {
function = AggregateFunctionFactory::instance().get(function_name, argument_types); AggregateFunctionProperties properties;
function = AggregateFunctionFactory::instance().get(function_name, argument_types, {}, properties);
add_function = function->getAddressOfAddFunction(); add_function = function->getAddressOfAddFunction();
state.reset(function->sizeOfData(), function->alignOfData()); state.reset(function->sizeOfData(), function->alignOfData());
} }

View File

@ -117,8 +117,9 @@ static void appendGraphitePattern(
aggregate_function_name, params_row, "GraphiteMergeTree storage initialization"); aggregate_function_name, params_row, "GraphiteMergeTree storage initialization");
/// TODO Not only Float64 /// TODO Not only Float64
pattern.function = AggregateFunctionFactory::instance().get(aggregate_function_name, {std::make_shared<DataTypeFloat64>()}, AggregateFunctionProperties properties;
params_row); pattern.function = AggregateFunctionFactory::instance().get(
aggregate_function_name, {std::make_shared<DataTypeFloat64>()}, params_row, properties);
} }
else if (startsWith(key, "retention")) else if (startsWith(key, "retention"))
{ {

View File

@ -15,7 +15,7 @@ namespace ErrorCodes
} }
void TableFunctionFactory::registerFunction(const std::string & name, Creator creator, CaseSensitiveness case_sensitiveness) void TableFunctionFactory::registerFunction(const std::string & name, Value creator, CaseSensitiveness case_sensitiveness)
{ {
if (!table_functions.emplace(name, creator).second) if (!table_functions.emplace(name, creator).second)
throw Exception("TableFunctionFactory: the table function name '" + name + "' is not unique", throw Exception("TableFunctionFactory: the table function name '" + name + "' is not unique",

View File

@ -24,12 +24,11 @@ using TableFunctionCreator = std::function<TableFunctionPtr()>;
class TableFunctionFactory final: private boost::noncopyable, public IFactoryWithAliases<TableFunctionCreator> class TableFunctionFactory final: private boost::noncopyable, public IFactoryWithAliases<TableFunctionCreator>
{ {
public: public:
static TableFunctionFactory & instance(); static TableFunctionFactory & instance();
/// Register a function by its name. /// Register a function by its name.
/// No locking, you must register all functions before usage of get. /// No locking, you must register all functions before usage of get.
void registerFunction(const std::string & name, Creator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); void registerFunction(const std::string & name, Value creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
template <typename Function> template <typename Function>
void registerFunction(CaseSensitiveness case_sensitiveness = CaseSensitive) void registerFunction(CaseSensitiveness case_sensitiveness = CaseSensitive)
@ -50,11 +49,11 @@ public:
bool isTableFunctionName(const std::string & name) const; bool isTableFunctionName(const std::string & name) const;
private: private:
using TableFunctions = std::unordered_map<std::string, Creator>; using TableFunctions = std::unordered_map<std::string, Value>;
const TableFunctions & getCreatorMap() const override { return table_functions; } const TableFunctions & getMap() const override { return table_functions; }
const TableFunctions & getCaseInsensitiveCreatorMap() const override { return case_insensitive_table_functions; } const TableFunctions & getCaseInsensitiveMap() const override { return case_insensitive_table_functions; }
String getFactoryName() const override { return "TableFunctionFactory"; } String getFactoryName() const override { return "TableFunctionFactory"; }

View File

@ -5,5 +5,5 @@
5 5
5 5
0 0
\N 0
\N 0

View File

@ -7,6 +7,5 @@ SELECT uniqExact(number >= 5 ? number : NULL) FROM numbers(10);
SELECT count(DISTINCT number >= 5 ? number : NULL) FROM numbers(10); SELECT count(DISTINCT number >= 5 ? number : NULL) FROM numbers(10);
SELECT count(NULL); SELECT count(NULL);
-- These two returns NULL for now, but we want to change them to return 0.
SELECT uniq(NULL); SELECT uniq(NULL);
SELECT count(DISTINCT NULL); SELECT count(DISTINCT NULL);

View File

@ -1,2 +1,2 @@
add_executable (convert-month-partitioned-parts main.cpp) add_executable (convert-month-partitioned-parts main.cpp)
target_link_libraries(convert-month-partitioned-parts PRIVATE dbms clickhouse_parsers boost::program_options) target_link_libraries(convert-month-partitioned-parts PRIVATE clickhouse_aggregate_functions dbms clickhouse_parsers boost::program_options)

View File

@ -1,3 +1,3 @@
add_executable (zookeeper-adjust-block-numbers-to-parts main.cpp ${SRCS}) add_executable (zookeeper-adjust-block-numbers-to-parts main.cpp ${SRCS})
target_compile_options(zookeeper-adjust-block-numbers-to-parts PRIVATE -Wno-format) target_compile_options(zookeeper-adjust-block-numbers-to-parts PRIVATE -Wno-format)
target_link_libraries (zookeeper-adjust-block-numbers-to-parts PRIVATE dbms clickhouse_common_zookeeper boost::program_options) target_link_libraries (zookeeper-adjust-block-numbers-to-parts PRIVATE clickhouse_aggregate_functions dbms clickhouse_common_zookeeper boost::program_options)