mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
Merge pull request #23159 from ClickHouse/aku/merge-fusecount
merging sumCount fusion PR #21337
This commit is contained in:
commit
2a4bcb6e3f
@ -66,7 +66,12 @@ reportStageEnd('parse')
|
||||
subst_elems = root.findall('substitutions/substitution')
|
||||
available_parameters = {} # { 'table': ['hits_10m', 'hits_100m'], ... }
|
||||
for e in subst_elems:
|
||||
available_parameters[e.find('name').text] = [v.text for v in e.findall('values/value')]
|
||||
name = e.find('name').text
|
||||
values = [v.text for v in e.findall('values/value')]
|
||||
if not values:
|
||||
raise Exception(f'No values given for substitution {{{name}}}')
|
||||
|
||||
available_parameters[name] = values
|
||||
|
||||
# Takes parallel lists of templates, substitutes them with all combos of
|
||||
# parameters. The set of parameters is determined based on the first list.
|
||||
|
@ -96,7 +96,7 @@ public:
|
||||
UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
|
||||
: Base(argument_types_, {}), num_scale(num_scale_), denom_scale(denom_scale_) {}
|
||||
|
||||
DataTypePtr getReturnType() const final { return std::make_shared<DataTypeNumber<Float64>>(); }
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
|
49
src/AggregateFunctions/AggregateFunctionSumCount.cpp
Normal file
49
src/AggregateFunctions/AggregateFunctionSumCount.cpp
Normal file
@ -0,0 +1,49 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionSumCount.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include "registerAggregateFunctions.h"
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
bool allowType(const DataTypePtr& type) noexcept
|
||||
{
|
||||
const WhichDataType t(type);
|
||||
return t.isInt() || t.isUInt() || t.isFloat() || t.isDecimal();
|
||||
}
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionSumCount(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertUnary(name, argument_types);
|
||||
|
||||
AggregateFunctionPtr res;
|
||||
DataTypePtr data_type = argument_types[0];
|
||||
if (!allowType(data_type))
|
||||
throw Exception("Illegal type " + data_type->getName() + " of argument for aggregate function " + name,
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
if (isDecimal(data_type))
|
||||
res.reset(createWithDecimalType<AggregateFunctionSumCount>(
|
||||
*data_type, argument_types, getDecimalScale(*data_type)));
|
||||
else
|
||||
res.reset(createWithNumericType<AggregateFunctionSumCount>(*data_type, argument_types));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionSumCount(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("sumCount", createAggregateFunctionSumCount);
|
||||
}
|
||||
|
||||
}
|
55
src/AggregateFunctions/AggregateFunctionSumCount.h
Normal file
55
src/AggregateFunctions/AggregateFunctionSumCount.h
Normal file
@ -0,0 +1,55 @@
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <AggregateFunctions/AggregateFunctionAvg.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
template <typename T>
|
||||
using DecimalOrNumberDataType = std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<AvgFieldType<T>>, DataTypeNumber<AvgFieldType<T>>>;
|
||||
template <typename T>
|
||||
class AggregateFunctionSumCount final : public AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionSumCount<T>>
|
||||
{
|
||||
public:
|
||||
using Base = AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionSumCount<T>>;
|
||||
|
||||
AggregateFunctionSumCount(const DataTypes & argument_types_, UInt32 num_scale_ = 0)
|
||||
: Base(argument_types_, num_scale_), scale(num_scale_) {}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
DataTypes types;
|
||||
if constexpr (IsDecimalNumber<T>)
|
||||
types.emplace_back(std::make_shared<DecimalOrNumberDataType<T>>(DecimalOrNumberDataType<T>::maxPrecision(), scale));
|
||||
else
|
||||
types.emplace_back(std::make_shared<DecimalOrNumberDataType<T>>());
|
||||
|
||||
types.emplace_back(std::make_shared<DataTypeUInt64>());
|
||||
|
||||
return std::make_shared<DataTypeTuple>(types);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const final
|
||||
{
|
||||
assert_cast<DecimalOrVectorCol<AvgFieldType<T>> &>((assert_cast<ColumnTuple &>(to)).getColumn(0)).getData().push_back(
|
||||
this->data(place).numerator);
|
||||
|
||||
assert_cast<ColumnUInt64 &>((assert_cast<ColumnTuple &>(to)).getColumn(1)).getData().push_back(
|
||||
this->data(place).denominator);
|
||||
}
|
||||
|
||||
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final
|
||||
{
|
||||
this->data(place).numerator += static_cast<const DecimalOrVectorCol<T> &>(*columns[0]).getData()[row_num];
|
||||
++this->data(place).denominator;
|
||||
}
|
||||
|
||||
String getName() const final { return "sumCount"; }
|
||||
|
||||
private:
|
||||
UInt32 scale;
|
||||
};
|
||||
|
||||
}
|
@ -25,6 +25,7 @@ void registerAggregateFunctionsAny(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsStatisticsStable(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsStatisticsSimple(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionSum(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionSumCount(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionSumMap(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsUniq(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionUniqCombined(AggregateFunctionFactory &);
|
||||
@ -83,6 +84,7 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionsStatisticsStable(factory);
|
||||
registerAggregateFunctionsStatisticsSimple(factory);
|
||||
registerAggregateFunctionSum(factory);
|
||||
registerAggregateFunctionSumCount(factory);
|
||||
registerAggregateFunctionSumMap(factory);
|
||||
registerAggregateFunctionsUniq(factory);
|
||||
registerAggregateFunctionUniqCombined(factory);
|
||||
|
@ -424,6 +424,7 @@ class IColumn;
|
||||
M(Bool, allow_non_metadata_alters, true, "Allow to execute alters which affects not only tables metadata, but also data on disk", 0) \
|
||||
M(Bool, enable_global_with_statement, true, "Propagate WITH statements to UNION queries and all subqueries", 0) \
|
||||
M(Bool, aggregate_functions_null_for_empty, false, "Rewrite all aggregate functions in a query, adding -OrNull suffix to them", 0) \
|
||||
M(Bool, optimize_fuse_sum_count_avg, false, "Fuse aggregate functions sum(), avg(), count() with identical arguments into one sumCount() call, if the query has at least two different functions", 0) \
|
||||
M(Bool, flatten_nested, true, "If true, columns of type Nested will be flatten to separate array columns instead of one array of tuples", 0) \
|
||||
M(Bool, asterisk_include_materialized_columns, false, "Include MATERIALIZED columns for wildcard query", 0) \
|
||||
M(Bool, asterisk_include_alias_columns, false, "Include ALIAS columns for wildcard query", 0) \
|
||||
|
@ -26,6 +26,7 @@
|
||||
#include <Parsers/ASTSelectQuery.h>
|
||||
#include <Parsers/ASTTablesInSelectQuery.h>
|
||||
#include <Parsers/queryToString.h>
|
||||
#include <Parsers/ASTLiteral.h>
|
||||
|
||||
#include <DataTypes/NestedUtils.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
@ -181,8 +182,72 @@ struct CustomizeAggregateFunctionsMoveSuffixData
|
||||
}
|
||||
};
|
||||
|
||||
struct FuseSumCountAggregates
|
||||
{
|
||||
std::vector<ASTFunction *> sums {};
|
||||
std::vector<ASTFunction *> counts {};
|
||||
std::vector<ASTFunction *> avgs {};
|
||||
|
||||
void addFuncNode(ASTFunction * func)
|
||||
{
|
||||
if (func->name == "sum")
|
||||
sums.push_back(func);
|
||||
else if (func->name == "count")
|
||||
counts.push_back(func);
|
||||
else
|
||||
{
|
||||
assert(func->name == "avg");
|
||||
avgs.push_back(func);
|
||||
}
|
||||
}
|
||||
|
||||
bool canBeFused() const
|
||||
{
|
||||
// Need at least two different kinds of functions to fuse.
|
||||
if (sums.empty() && counts.empty())
|
||||
return false;
|
||||
if (sums.empty() && avgs.empty())
|
||||
return false;
|
||||
if (counts.empty() && avgs.empty())
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
struct FuseSumCountAggregatesVisitorData
|
||||
{
|
||||
using TypeToVisit = ASTFunction;
|
||||
|
||||
std::unordered_map<String, FuseSumCountAggregates> fuse_map;
|
||||
|
||||
void visit(ASTFunction & func, ASTPtr &)
|
||||
{
|
||||
if (func.name == "sum" || func.name == "avg" || func.name == "count")
|
||||
{
|
||||
if (func.arguments->children.empty())
|
||||
return;
|
||||
|
||||
// Probably we can extend it to match count() for non-nullable argument
|
||||
// to sum/avg with any other argument. Now we require strict match.
|
||||
const auto argument = func.arguments->children.at(0)->getColumnName();
|
||||
auto it = fuse_map.find(argument);
|
||||
if (it != fuse_map.end())
|
||||
{
|
||||
it->second.addFuncNode(&func);
|
||||
}
|
||||
else
|
||||
{
|
||||
FuseSumCountAggregates funcs{};
|
||||
funcs.addFuncNode(&func);
|
||||
fuse_map[argument] = funcs;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using CustomizeAggregateFunctionsOrNullVisitor = InDepthNodeVisitor<OneTypeMatcher<CustomizeAggregateFunctionsSuffixData>, true>;
|
||||
using CustomizeAggregateFunctionsMoveOrNullVisitor = InDepthNodeVisitor<OneTypeMatcher<CustomizeAggregateFunctionsMoveSuffixData>, true>;
|
||||
using FuseSumCountAggregatesVisitor = InDepthNodeVisitor<OneTypeMatcher<FuseSumCountAggregatesVisitorData>, true>;
|
||||
|
||||
/// Translate qualified names such as db.table.column, table.column, table_alias.column to names' normal form.
|
||||
/// Expand asterisks and qualified asterisks with column names.
|
||||
@ -200,6 +265,49 @@ void translateQualifiedNames(ASTPtr & query, const ASTSelectQuery & select_query
|
||||
throw Exception("Empty list of columns in SELECT query", ErrorCodes::EMPTY_LIST_OF_COLUMNS_QUERIED);
|
||||
}
|
||||
|
||||
// Replaces one avg/sum/count function with an appropriate expression with
|
||||
// sumCount().
|
||||
void replaceWithSumCount(String column_name, ASTFunction & func)
|
||||
{
|
||||
auto func_base = makeASTFunction("sumCount", std::make_shared<ASTIdentifier>(column_name));
|
||||
auto exp_list = std::make_shared<ASTExpressionList>();
|
||||
if (func.name == "sum" || func.name == "count")
|
||||
{
|
||||
/// Rewrite "sum" to sumCount().1, rewrite "count" to sumCount().2
|
||||
UInt8 idx = (func.name == "sum" ? 1 : 2);
|
||||
func.name = "tupleElement";
|
||||
exp_list->children.push_back(func_base);
|
||||
exp_list->children.push_back(std::make_shared<ASTLiteral>(idx));
|
||||
}
|
||||
else
|
||||
{
|
||||
/// Rewrite "avg" to sumCount().1 / sumCount().2
|
||||
auto new_arg1 = makeASTFunction("tupleElement", func_base, std::make_shared<ASTLiteral>(UInt8(1)));
|
||||
auto new_arg2 = makeASTFunction("tupleElement", func_base, std::make_shared<ASTLiteral>(UInt8(2)));
|
||||
func.name = "divide";
|
||||
exp_list->children.push_back(new_arg1);
|
||||
exp_list->children.push_back(new_arg2);
|
||||
}
|
||||
func.arguments = exp_list;
|
||||
func.children.push_back(func.arguments);
|
||||
}
|
||||
|
||||
void fuseSumCountAggregates(std::unordered_map<String, FuseSumCountAggregates> & fuse_map)
|
||||
{
|
||||
for (auto & it : fuse_map)
|
||||
{
|
||||
if (it.second.canBeFused())
|
||||
{
|
||||
for (auto & func: it.second.sums)
|
||||
replaceWithSumCount(it.first, *func);
|
||||
for (auto & func: it.second.avgs)
|
||||
replaceWithSumCount(it.first, *func);
|
||||
for (auto & func: it.second.counts)
|
||||
replaceWithSumCount(it.first, *func);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool hasArrayJoin(const ASTPtr & ast)
|
||||
{
|
||||
if (const ASTFunction * function = ast->as<ASTFunction>())
|
||||
@ -910,7 +1018,18 @@ void TreeRewriter::normalize(ASTPtr & query, Aliases & aliases, const NameSet &
|
||||
CustomizeGlobalNotInVisitor(data_global_not_null_in).visit(query);
|
||||
}
|
||||
|
||||
// Rewrite all aggregate functions to add -OrNull suffix to them
|
||||
// Try to fuse sum/avg/count with identical arguments to one sumCount call,
|
||||
// if we have at least two different functions. E.g. we will replace sum(x)
|
||||
// and count(x) with sumCount(x).1 and sumCount(x).2, and sumCount() will
|
||||
// be calculated only once because of CSE.
|
||||
if (settings.optimize_fuse_sum_count_avg)
|
||||
{
|
||||
FuseSumCountAggregatesVisitor::Data data;
|
||||
FuseSumCountAggregatesVisitor(data).visit(query);
|
||||
fuseSumCountAggregates(data.fuse_map);
|
||||
}
|
||||
|
||||
/// Rewrite all aggregate functions to add -OrNull suffix to them
|
||||
if (settings.aggregate_functions_null_for_empty)
|
||||
{
|
||||
CustomizeAggregateFunctionsOrNullVisitor::Data data_or_null{"OrNull"};
|
||||
|
33
tests/performance/fuse_sumcount.xml
Normal file
33
tests/performance/fuse_sumcount.xml
Normal file
@ -0,0 +1,33 @@
|
||||
<test>
|
||||
<!-- We test rewriting sum(), avg(), count() to a single call of sumCount() here.
|
||||
As a reference, we use the same queries with the optimization disabled.
|
||||
sum() has a highly optimized algorithm, so alone it will be faster than sumCount(),
|
||||
but when we add count() or avg(), the sumCount() should win.
|
||||
Also test GROUP BY with and without keys, because they might have different
|
||||
optimizations. -->
|
||||
<settings>
|
||||
<optimize_fuse_sum_count_avg>1</optimize_fuse_sum_count_avg>
|
||||
</settings>
|
||||
|
||||
<substitutions>
|
||||
<substitution>
|
||||
<name>key</name>
|
||||
<values>
|
||||
<value>1</value>
|
||||
<value>intHash32(number) % 1000</value>
|
||||
</values>
|
||||
</substitution>
|
||||
</substitutions>
|
||||
|
||||
<query>SELECT sum(number) FROM numbers(1000000000) FORMAT Null</query>
|
||||
<query>SELECT sum(number), count(number) FROM numbers(1000000000) FORMAT Null</query>
|
||||
<query>SELECT sum(number), count(number) FROM numbers(1000000000) SETTINGS optimize_fuse_sum_count_avg = 0 FORMAT Null</query>
|
||||
<query>SELECT sum(number), avg(number), count(number) FROM numbers(1000000000) FORMAT Null</query>
|
||||
<query>SELECT sum(number), avg(number), count(number) FROM numbers(1000000000) SETTINGS optimize_fuse_sum_count_avg = 0 FORMAT Null</query>
|
||||
|
||||
<query>SELECT sum(number) FROM numbers(100000000) GROUP BY intHash32(number) % 1000 FORMAT Null</query>
|
||||
<query>SELECT sum(number), count(number) FROM numbers(100000000) GROUP BY intHash32(number) % 1000 FORMAT Null</query>
|
||||
<query>SELECT sum(number), count(number) FROM numbers(100000000) GROUP BY intHash32(number) % 1000 SETTINGS optimize_fuse_sum_count_avg = 0 FORMAT Null</query>
|
||||
<query>SELECT sum(number), avg(number), count(number) FROM numbers(100000000) GROUP BY intHash32(number) % 1000 FORMAT Null</query>
|
||||
<query>SELECT sum(number), avg(number), count(number) FROM numbers(100000000) GROUP BY intHash32(number) % 1000 SETTINGS optimize_fuse_sum_count_avg = 0 FORMAT Null</query>
|
||||
</test>
|
@ -0,0 +1,12 @@
|
||||
210 230 20
|
||||
SELECT
|
||||
sum(a),
|
||||
sumCount(b).1,
|
||||
sumCount(b).2
|
||||
FROM fuse_tbl
|
||||
---------NOT trigger fuse--------
|
||||
210 11.5
|
||||
SELECT
|
||||
sum(a),
|
||||
avg(b)
|
||||
FROM fuse_tbl
|
11
tests/queries/0_stateless/01744_fuse_sum_count_aggregate.sql
Normal file
11
tests/queries/0_stateless/01744_fuse_sum_count_aggregate.sql
Normal file
@ -0,0 +1,11 @@
|
||||
DROP TABLE IF EXISTS fuse_tbl;
|
||||
CREATE TABLE fuse_tbl(a Int8, b Int8) Engine = Log;
|
||||
INSERT INTO fuse_tbl SELECT number, number + 1 FROM numbers(1, 20);
|
||||
|
||||
SET optimize_fuse_sum_count_avg = 1;
|
||||
SELECT sum(a), sum(b), count(b) from fuse_tbl;
|
||||
EXPLAIN SYNTAX SELECT sum(a), sum(b), count(b) from fuse_tbl;
|
||||
SELECT '---------NOT trigger fuse--------';
|
||||
SELECT sum(a), avg(b) from fuse_tbl;
|
||||
EXPLAIN SYNTAX SELECT sum(a), avg(b) from fuse_tbl;
|
||||
DROP TABLE fuse_tbl;
|
Loading…
Reference in New Issue
Block a user