Merge pull request #23159 from ClickHouse/aku/merge-fusecount

merging sumCount fusion PR #21337
This commit is contained in:
Alexander Kuzmenkov 2021-04-19 16:47:13 +03:00 committed by GitHub
commit 2a4bcb6e3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 290 additions and 3 deletions

View File

@ -66,7 +66,12 @@ reportStageEnd('parse')
subst_elems = root.findall('substitutions/substitution')
available_parameters = {} # { 'table': ['hits_10m', 'hits_100m'], ... }
for e in subst_elems:
available_parameters[e.find('name').text] = [v.text for v in e.findall('values/value')]
name = e.find('name').text
values = [v.text for v in e.findall('values/value')]
if not values:
raise Exception(f'No values given for substitution {{{name}}}')
available_parameters[name] = values
# Takes parallel lists of templates, substitutes them with all combos of
# parameters. The set of parameters is determined based on the first list.

View File

@ -96,7 +96,7 @@ public:
UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
: Base(argument_types_, {}), num_scale(num_scale_), denom_scale(denom_scale_) {}
DataTypePtr getReturnType() const final { return std::make_shared<DataTypeNumber<Float64>>(); }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
bool allocatesMemoryInArena() const override { return false; }

View File

@ -0,0 +1,49 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionSumCount.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include "registerAggregateFunctions.h"
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
namespace
{
bool allowType(const DataTypePtr& type) noexcept
{
const WhichDataType t(type);
return t.isInt() || t.isUInt() || t.isFloat() || t.isDecimal();
}
AggregateFunctionPtr createAggregateFunctionSumCount(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (!allowType(data_type))
throw Exception("Illegal type " + data_type->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (isDecimal(data_type))
res.reset(createWithDecimalType<AggregateFunctionSumCount>(
*data_type, argument_types, getDecimalScale(*data_type)));
else
res.reset(createWithNumericType<AggregateFunctionSumCount>(*data_type, argument_types));
return res;
}
}
void registerAggregateFunctionSumCount(AggregateFunctionFactory & factory)
{
factory.registerFunction("sumCount", createAggregateFunctionSumCount);
}
}

View File

@ -0,0 +1,55 @@
#pragma once
#include <type_traits>
#include <DataTypes/DataTypeTuple.h>
#include <AggregateFunctions/AggregateFunctionAvg.h>
namespace DB
{
template <typename T>
using DecimalOrNumberDataType = std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<AvgFieldType<T>>, DataTypeNumber<AvgFieldType<T>>>;
template <typename T>
class AggregateFunctionSumCount final : public AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionSumCount<T>>
{
public:
using Base = AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionSumCount<T>>;
AggregateFunctionSumCount(const DataTypes & argument_types_, UInt32 num_scale_ = 0)
: Base(argument_types_, num_scale_), scale(num_scale_) {}
DataTypePtr getReturnType() const override
{
DataTypes types;
if constexpr (IsDecimalNumber<T>)
types.emplace_back(std::make_shared<DecimalOrNumberDataType<T>>(DecimalOrNumberDataType<T>::maxPrecision(), scale));
else
types.emplace_back(std::make_shared<DecimalOrNumberDataType<T>>());
types.emplace_back(std::make_shared<DataTypeUInt64>());
return std::make_shared<DataTypeTuple>(types);
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const final
{
assert_cast<DecimalOrVectorCol<AvgFieldType<T>> &>((assert_cast<ColumnTuple &>(to)).getColumn(0)).getData().push_back(
this->data(place).numerator);
assert_cast<ColumnUInt64 &>((assert_cast<ColumnTuple &>(to)).getColumn(1)).getData().push_back(
this->data(place).denominator);
}
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final
{
this->data(place).numerator += static_cast<const DecimalOrVectorCol<T> &>(*columns[0]).getData()[row_num];
++this->data(place).denominator;
}
String getName() const final { return "sumCount"; }
private:
UInt32 scale;
};
}

View File

@ -25,6 +25,7 @@ void registerAggregateFunctionsAny(AggregateFunctionFactory &);
void registerAggregateFunctionsStatisticsStable(AggregateFunctionFactory &);
void registerAggregateFunctionsStatisticsSimple(AggregateFunctionFactory &);
void registerAggregateFunctionSum(AggregateFunctionFactory &);
void registerAggregateFunctionSumCount(AggregateFunctionFactory &);
void registerAggregateFunctionSumMap(AggregateFunctionFactory &);
void registerAggregateFunctionsUniq(AggregateFunctionFactory &);
void registerAggregateFunctionUniqCombined(AggregateFunctionFactory &);
@ -83,6 +84,7 @@ void registerAggregateFunctions()
registerAggregateFunctionsStatisticsStable(factory);
registerAggregateFunctionsStatisticsSimple(factory);
registerAggregateFunctionSum(factory);
registerAggregateFunctionSumCount(factory);
registerAggregateFunctionSumMap(factory);
registerAggregateFunctionsUniq(factory);
registerAggregateFunctionUniqCombined(factory);

View File

@ -424,6 +424,7 @@ class IColumn;
M(Bool, allow_non_metadata_alters, true, "Allow to execute alters which affects not only tables metadata, but also data on disk", 0) \
M(Bool, enable_global_with_statement, true, "Propagate WITH statements to UNION queries and all subqueries", 0) \
M(Bool, aggregate_functions_null_for_empty, false, "Rewrite all aggregate functions in a query, adding -OrNull suffix to them", 0) \
M(Bool, optimize_fuse_sum_count_avg, false, "Fuse aggregate functions sum(), avg(), count() with identical arguments into one sumCount() call, if the query has at least two different functions", 0) \
M(Bool, flatten_nested, true, "If true, columns of type Nested will be flatten to separate array columns instead of one array of tuples", 0) \
M(Bool, asterisk_include_materialized_columns, false, "Include MATERIALIZED columns for wildcard query", 0) \
M(Bool, asterisk_include_alias_columns, false, "Include ALIAS columns for wildcard query", 0) \

View File

@ -26,6 +26,7 @@
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/queryToString.h>
#include <Parsers/ASTLiteral.h>
#include <DataTypes/NestedUtils.h>
#include <DataTypes/DataTypeNullable.h>
@ -181,8 +182,72 @@ struct CustomizeAggregateFunctionsMoveSuffixData
}
};
struct FuseSumCountAggregates
{
std::vector<ASTFunction *> sums {};
std::vector<ASTFunction *> counts {};
std::vector<ASTFunction *> avgs {};
void addFuncNode(ASTFunction * func)
{
if (func->name == "sum")
sums.push_back(func);
else if (func->name == "count")
counts.push_back(func);
else
{
assert(func->name == "avg");
avgs.push_back(func);
}
}
bool canBeFused() const
{
// Need at least two different kinds of functions to fuse.
if (sums.empty() && counts.empty())
return false;
if (sums.empty() && avgs.empty())
return false;
if (counts.empty() && avgs.empty())
return false;
return true;
}
};
struct FuseSumCountAggregatesVisitorData
{
using TypeToVisit = ASTFunction;
std::unordered_map<String, FuseSumCountAggregates> fuse_map;
void visit(ASTFunction & func, ASTPtr &)
{
if (func.name == "sum" || func.name == "avg" || func.name == "count")
{
if (func.arguments->children.empty())
return;
// Probably we can extend it to match count() for non-nullable argument
// to sum/avg with any other argument. Now we require strict match.
const auto argument = func.arguments->children.at(0)->getColumnName();
auto it = fuse_map.find(argument);
if (it != fuse_map.end())
{
it->second.addFuncNode(&func);
}
else
{
FuseSumCountAggregates funcs{};
funcs.addFuncNode(&func);
fuse_map[argument] = funcs;
}
}
}
};
using CustomizeAggregateFunctionsOrNullVisitor = InDepthNodeVisitor<OneTypeMatcher<CustomizeAggregateFunctionsSuffixData>, true>;
using CustomizeAggregateFunctionsMoveOrNullVisitor = InDepthNodeVisitor<OneTypeMatcher<CustomizeAggregateFunctionsMoveSuffixData>, true>;
using FuseSumCountAggregatesVisitor = InDepthNodeVisitor<OneTypeMatcher<FuseSumCountAggregatesVisitorData>, true>;
/// Translate qualified names such as db.table.column, table.column, table_alias.column to names' normal form.
/// Expand asterisks and qualified asterisks with column names.
@ -200,6 +265,49 @@ void translateQualifiedNames(ASTPtr & query, const ASTSelectQuery & select_query
throw Exception("Empty list of columns in SELECT query", ErrorCodes::EMPTY_LIST_OF_COLUMNS_QUERIED);
}
// Replaces one avg/sum/count function with an appropriate expression with
// sumCount().
void replaceWithSumCount(String column_name, ASTFunction & func)
{
auto func_base = makeASTFunction("sumCount", std::make_shared<ASTIdentifier>(column_name));
auto exp_list = std::make_shared<ASTExpressionList>();
if (func.name == "sum" || func.name == "count")
{
/// Rewrite "sum" to sumCount().1, rewrite "count" to sumCount().2
UInt8 idx = (func.name == "sum" ? 1 : 2);
func.name = "tupleElement";
exp_list->children.push_back(func_base);
exp_list->children.push_back(std::make_shared<ASTLiteral>(idx));
}
else
{
/// Rewrite "avg" to sumCount().1 / sumCount().2
auto new_arg1 = makeASTFunction("tupleElement", func_base, std::make_shared<ASTLiteral>(UInt8(1)));
auto new_arg2 = makeASTFunction("tupleElement", func_base, std::make_shared<ASTLiteral>(UInt8(2)));
func.name = "divide";
exp_list->children.push_back(new_arg1);
exp_list->children.push_back(new_arg2);
}
func.arguments = exp_list;
func.children.push_back(func.arguments);
}
void fuseSumCountAggregates(std::unordered_map<String, FuseSumCountAggregates> & fuse_map)
{
for (auto & it : fuse_map)
{
if (it.second.canBeFused())
{
for (auto & func: it.second.sums)
replaceWithSumCount(it.first, *func);
for (auto & func: it.second.avgs)
replaceWithSumCount(it.first, *func);
for (auto & func: it.second.counts)
replaceWithSumCount(it.first, *func);
}
}
}
bool hasArrayJoin(const ASTPtr & ast)
{
if (const ASTFunction * function = ast->as<ASTFunction>())
@ -910,7 +1018,18 @@ void TreeRewriter::normalize(ASTPtr & query, Aliases & aliases, const NameSet &
CustomizeGlobalNotInVisitor(data_global_not_null_in).visit(query);
}
// Rewrite all aggregate functions to add -OrNull suffix to them
// Try to fuse sum/avg/count with identical arguments to one sumCount call,
// if we have at least two different functions. E.g. we will replace sum(x)
// and count(x) with sumCount(x).1 and sumCount(x).2, and sumCount() will
// be calculated only once because of CSE.
if (settings.optimize_fuse_sum_count_avg)
{
FuseSumCountAggregatesVisitor::Data data;
FuseSumCountAggregatesVisitor(data).visit(query);
fuseSumCountAggregates(data.fuse_map);
}
/// Rewrite all aggregate functions to add -OrNull suffix to them
if (settings.aggregate_functions_null_for_empty)
{
CustomizeAggregateFunctionsOrNullVisitor::Data data_or_null{"OrNull"};

View File

@ -0,0 +1,33 @@
<test>
<!-- We test rewriting sum(), avg(), count() to a single call of sumCount() here.
As a reference, we use the same queries with the optimization disabled.
sum() has a highly optimized algorithm, so alone it will be faster than sumCount(),
but when we add count() or avg(), the sumCount() should win.
Also test GROUP BY with and without keys, because they might have different
optimizations. -->
<settings>
<optimize_fuse_sum_count_avg>1</optimize_fuse_sum_count_avg>
</settings>
<substitutions>
<substitution>
<name>key</name>
<values>
<value>1</value>
<value>intHash32(number) % 1000</value>
</values>
</substitution>
</substitutions>
<query>SELECT sum(number) FROM numbers(1000000000) FORMAT Null</query>
<query>SELECT sum(number), count(number) FROM numbers(1000000000) FORMAT Null</query>
<query>SELECT sum(number), count(number) FROM numbers(1000000000) SETTINGS optimize_fuse_sum_count_avg = 0 FORMAT Null</query>
<query>SELECT sum(number), avg(number), count(number) FROM numbers(1000000000) FORMAT Null</query>
<query>SELECT sum(number), avg(number), count(number) FROM numbers(1000000000) SETTINGS optimize_fuse_sum_count_avg = 0 FORMAT Null</query>
<query>SELECT sum(number) FROM numbers(100000000) GROUP BY intHash32(number) % 1000 FORMAT Null</query>
<query>SELECT sum(number), count(number) FROM numbers(100000000) GROUP BY intHash32(number) % 1000 FORMAT Null</query>
<query>SELECT sum(number), count(number) FROM numbers(100000000) GROUP BY intHash32(number) % 1000 SETTINGS optimize_fuse_sum_count_avg = 0 FORMAT Null</query>
<query>SELECT sum(number), avg(number), count(number) FROM numbers(100000000) GROUP BY intHash32(number) % 1000 FORMAT Null</query>
<query>SELECT sum(number), avg(number), count(number) FROM numbers(100000000) GROUP BY intHash32(number) % 1000 SETTINGS optimize_fuse_sum_count_avg = 0 FORMAT Null</query>
</test>

View File

@ -0,0 +1,12 @@
210 230 20
SELECT
sum(a),
sumCount(b).1,
sumCount(b).2
FROM fuse_tbl
---------NOT trigger fuse--------
210 11.5
SELECT
sum(a),
avg(b)
FROM fuse_tbl

View File

@ -0,0 +1,11 @@
DROP TABLE IF EXISTS fuse_tbl;
CREATE TABLE fuse_tbl(a Int8, b Int8) Engine = Log;
INSERT INTO fuse_tbl SELECT number, number + 1 FROM numbers(1, 20);
SET optimize_fuse_sum_count_avg = 1;
SELECT sum(a), sum(b), count(b) from fuse_tbl;
EXPLAIN SYNTAX SELECT sum(a), sum(b), count(b) from fuse_tbl;
SELECT '---------NOT trigger fuse--------';
SELECT sum(a), avg(b) from fuse_tbl;
EXPLAIN SYNTAX SELECT sum(a), avg(b) from fuse_tbl;
DROP TABLE fuse_tbl;