Merge pull request #47716 from ClickHouse/prevent-slow-aggregate-combinators

Prevent too long (slow) aggregate function combinators
This commit is contained in:
Alexey Milovidov 2023-03-19 17:38:53 +03:00 committed by GitHub
commit 02b8d2bbf8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 46 deletions

View File

@ -2,12 +2,10 @@
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context.h>
@ -21,6 +19,9 @@
#include <Functions/FunctionFactory.h>
static constexpr size_t MAX_AGGREGATE_FUNCTION_NAME_LENGTH = 1000;
namespace DB
{
struct Settings;
@ -30,6 +31,7 @@ namespace ErrorCodes
extern const int UNKNOWN_AGGREGATE_FUNCTION;
extern const int LOGICAL_ERROR;
extern const int ILLEGAL_AGGREGATION;
extern const int TOO_LARGE_STRING_SIZE;
}
const String & getAggregateFunctionCanonicalNameIfAny(const String & name)
@ -70,12 +72,17 @@ static DataTypes convertLowCardinalityTypesToNested(const DataTypes & types)
AggregateFunctionPtr AggregateFunctionFactory::get(
const String & name, const DataTypes & argument_types, const Array & parameters, AggregateFunctionProperties & out_properties) const
{
/// This to prevent costly string manipulation in parsing the aggregate function combinators.
/// Example: avgArrayArrayArrayArray...(1000 times)...Array
if (name.size() > MAX_AGGREGATE_FUNCTION_NAME_LENGTH)
throw Exception(ErrorCodes::TOO_LARGE_STRING_SIZE, "Too long name of aggregate function, maximum: {}", MAX_AGGREGATE_FUNCTION_NAME_LENGTH);
auto types_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types);
/// If one of the types is Nullable, we apply aggregate function combinator "Null" if it's not window function.
/// Window functions are not real aggregate functions. Applying combinators doesn't make sense for them,
/// they must handle the nullability themselves
auto properties = tryGetPropertiesImpl(name);
auto properties = tryGetProperties(name);
bool is_window_function = properties.has_value() && properties->is_window_function;
if (!is_window_function && std::any_of(types_without_low_cardinality.begin(), types_without_low_cardinality.end(),
[](const auto & type) { return type->isNullable(); }))
@ -216,61 +223,67 @@ AggregateFunctionPtr AggregateFunctionFactory::tryGet(
}
std::optional<AggregateFunctionProperties> AggregateFunctionFactory::tryGetPropertiesImpl(const String & name_param) const
std::optional<AggregateFunctionProperties> AggregateFunctionFactory::tryGetProperties(String name) const
{
String name = getAliasToOrName(name_param);
Value found;
if (name.size() > MAX_AGGREGATE_FUNCTION_NAME_LENGTH)
throw Exception(ErrorCodes::TOO_LARGE_STRING_SIZE, "Too long name of aggregate function, maximum: {}", MAX_AGGREGATE_FUNCTION_NAME_LENGTH);
/// Find by exact match.
if (auto it = aggregate_functions.find(name); it != aggregate_functions.end())
while (true)
{
found = it->second;
}
name = getAliasToOrName(name);
Value found;
if (auto jt = case_insensitive_aggregate_functions.find(Poco::toLower(name)); jt != case_insensitive_aggregate_functions.end())
found = jt->second;
/// Find by exact match.
if (auto it = aggregate_functions.find(name); it != aggregate_functions.end())
{
found = it->second;
}
if (found.creator)
return found.properties;
if (auto jt = case_insensitive_aggregate_functions.find(Poco::toLower(name)); jt != case_insensitive_aggregate_functions.end())
found = jt->second;
/// Combinators of aggregate functions.
/// For every aggregate function 'agg' and combiner '-Comb' there is a combined aggregate function with the name 'aggComb',
/// that can have different number and/or types of arguments, different result type and different behaviour.
if (found.creator)
return found.properties;
if (AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix(name))
{
if (combinator->isForInternalUsageOnly())
/// Combinators of aggregate functions.
/// For every aggregate function 'agg' and combiner '-Comb' there is a combined aggregate function with the name 'aggComb',
/// that can have different number and/or types of arguments, different result type and different behaviour.
if (AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix(name))
{
if (combinator->isForInternalUsageOnly())
return {};
/// NOTE: It's reasonable to also allow to transform properties by combinator.
name = name.substr(0, name.size() - combinator->getName().size());
}
else
return {};
String nested_name = name.substr(0, name.size() - combinator->getName().size());
/// NOTE: It's reasonable to also allow to transform properties by combinator.
return tryGetPropertiesImpl(nested_name);
}
return {};
}
std::optional<AggregateFunctionProperties> AggregateFunctionFactory::tryGetProperties(const String & name) const
bool AggregateFunctionFactory::isAggregateFunctionName(String name) const
{
return tryGetPropertiesImpl(name);
}
if (name.size() > MAX_AGGREGATE_FUNCTION_NAME_LENGTH)
throw Exception(ErrorCodes::TOO_LARGE_STRING_SIZE, "Too long name of aggregate function, maximum: {}", MAX_AGGREGATE_FUNCTION_NAME_LENGTH);
while (true)
{
if (aggregate_functions.contains(name) || isAlias(name))
return true;
bool AggregateFunctionFactory::isAggregateFunctionName(const String & name) const
{
if (aggregate_functions.contains(name) || isAlias(name))
return true;
String name_lowercase = Poco::toLower(name);
if (case_insensitive_aggregate_functions.contains(name_lowercase) || isAlias(name_lowercase))
return true;
String name_lowercase = Poco::toLower(name);
if (case_insensitive_aggregate_functions.contains(name_lowercase) || isAlias(name_lowercase))
return true;
if (AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix(name))
return isAggregateFunctionName(name.substr(0, name.size() - combinator->getName().size()));
return false;
if (AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix(name))
{
name = name.substr(0, name.size() - combinator->getName().size());
}
else
return false;
}
}
AggregateFunctionFactory & AggregateFunctionFactory::instance()

View File

@ -77,9 +77,9 @@ public:
AggregateFunctionProperties & out_properties) const;
/// Get properties if the aggregate function exists.
std::optional<AggregateFunctionProperties> tryGetProperties(const String & name) const;
std::optional<AggregateFunctionProperties> tryGetProperties(String name) const;
bool isAggregateFunctionName(const String & name) const;
bool isAggregateFunctionName(String name) const;
private:
AggregateFunctionPtr getImpl(
@ -89,8 +89,6 @@ private:
AggregateFunctionProperties & out_properties,
bool has_null_arguments) const;
std::optional<AggregateFunctionProperties> tryGetPropertiesImpl(const String & name) const;
using AggregateFunctions = std::unordered_map<String, Value>;
AggregateFunctions aggregate_functions;

File diff suppressed because one or more lines are too long