mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-24 08:32:02 +00:00
Rewriter fuse logical
This commit is contained in:
parent
4fe75ad168
commit
9dab51698f
@ -25,7 +25,6 @@ public:
|
|||||||
std::unordered_set<String> uniq_names {};
|
std::unordered_set<String> uniq_names {};
|
||||||
std::vector<const ASTFunction *> aggregates {};
|
std::vector<const ASTFunction *> aggregates {};
|
||||||
std::vector<const ASTFunction *> window_functions {};
|
std::vector<const ASTFunction *> window_functions {};
|
||||||
std::map<String, UInt8> fuse_sum_count_avg {};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child)
|
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child)
|
||||||
@ -73,14 +72,6 @@ private:
|
|||||||
|
|
||||||
data.uniq_names.insert(column_name);
|
data.uniq_names.insert(column_name);
|
||||||
data.aggregates.push_back(&node);
|
data.aggregates.push_back(&node);
|
||||||
if ((node.name == "sum" || node.name == "avg" || node.name == "count") && node.arguments->children.size() == 1)
|
|
||||||
{
|
|
||||||
const auto & argument = node.arguments->children.at(0);
|
|
||||||
ASTIdentifier * column = argument->as<ASTIdentifier>();
|
|
||||||
if (column)
|
|
||||||
data.fuse_sum_count_avg[column->name()] |= ((node.name == "sum") ? 0x1 : ((node.name == "count") ? 0x2 : 0x4));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
else if (node.is_window_function)
|
else if (node.is_window_function)
|
||||||
{
|
{
|
||||||
|
@ -182,29 +182,53 @@ struct CustomizeAggregateFunctionsMoveSuffixData
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct FuseFunctions
|
||||||
|
{
|
||||||
|
std::vector<const ASTFunction *> sums {};
|
||||||
|
std::vector<const ASTFunction *> counts {};
|
||||||
|
std::vector<const ASTFunction *> avgs {};
|
||||||
|
|
||||||
|
void addFuncNode(const ASTFunction * func)
|
||||||
|
{
|
||||||
|
if (func->name == "sum")
|
||||||
|
sums.push_back(func);
|
||||||
|
else if (func->name == "count")
|
||||||
|
counts.push_back(func);
|
||||||
|
else if (func->name == "avg")
|
||||||
|
avgs.push_back(func);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool canBeFused() const
|
||||||
|
{
|
||||||
|
if (sums.empty() && counts.empty())
|
||||||
|
return false;
|
||||||
|
if (sums.empty() && avgs.empty())
|
||||||
|
return false;
|
||||||
|
if (counts.empty() && avgs.empty())
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct CustomizeFuseAggregateFunctionsData
|
struct CustomizeFuseAggregateFunctionsData
|
||||||
{
|
{
|
||||||
using TypeToVisit = ASTFunction;
|
using TypeToVisit = ASTFunction;
|
||||||
|
|
||||||
std::map<String, UInt8> fuse_info;
|
std::unordered_map<String, DB::FuseFunctions> fuse_map {};
|
||||||
|
|
||||||
static inline UInt8 bitCount(UInt8 n)
|
|
||||||
{
|
|
||||||
UInt8 c = 0;
|
|
||||||
for (c = 0; n; n >>= 1)
|
|
||||||
c += n & 1;
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
|
|
||||||
void visit(ASTFunction & func, ASTPtr &) const
|
void visit(ASTFunction & func, ASTPtr &) const
|
||||||
{
|
{
|
||||||
if (func.name == "sum" || func.name == "avg" || func.name == "count")
|
if (func.name == "sum" || func.name == "avg" || func.name == "count")
|
||||||
{
|
{
|
||||||
|
if (func.arguments->children.size() == 0)
|
||||||
|
return;
|
||||||
|
|
||||||
ASTIdentifier * ident = func.arguments->children.at(0)->as<ASTIdentifier>();
|
ASTIdentifier * ident = func.arguments->children.at(0)->as<ASTIdentifier>();
|
||||||
if (!ident)
|
if (!ident)
|
||||||
return;
|
return;
|
||||||
auto column = fuse_info.find(ident->name());
|
|
||||||
if (column != fuse_info.end() && bitCount(column->second) > 1)
|
auto it = fuse_map.find(ident->name());
|
||||||
|
if (it != fuse_map.end() && it->second.canBeFused())
|
||||||
{
|
{
|
||||||
auto func_base = makeASTFunction("sumCount", func.arguments->children.at(0)->clone());
|
auto func_base = makeASTFunction("sumCount", func.arguments->children.at(0)->clone());
|
||||||
auto exp_list = std::make_shared<ASTExpressionList>();
|
auto exp_list = std::make_shared<ASTExpressionList>();
|
||||||
@ -253,6 +277,38 @@ void translateQualifiedNames(ASTPtr & query, const ASTSelectQuery & select_query
|
|||||||
throw Exception("Empty list of columns in SELECT query", ErrorCodes::EMPTY_LIST_OF_COLUMNS_QUERIED);
|
throw Exception("Empty list of columns in SELECT query", ErrorCodes::EMPTY_LIST_OF_COLUMNS_QUERIED);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void gatherFuseFunctions(std::unordered_map<String, DB::FuseFunctions> &fuse_map, std::vector<const ASTFunction *> &aggregates)
|
||||||
|
{
|
||||||
|
for (auto & func : aggregates)
|
||||||
|
{
|
||||||
|
if ((func->name == "sum" || func->name == "avg" || func->name == "count") && func->arguments->children.size() == 1)
|
||||||
|
{
|
||||||
|
if (func->arguments->children.size() == 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
ASTIdentifier * ident = func->arguments->children.at(0)->as<ASTIdentifier>();
|
||||||
|
if (!ident)
|
||||||
|
return;
|
||||||
|
|
||||||
|
ASTIdentifier * column = (func->arguments->children.at(0))->as<ASTIdentifier>();
|
||||||
|
if (!column)
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto it = fuse_map.find(column->name());
|
||||||
|
if (it != fuse_map.end())
|
||||||
|
{
|
||||||
|
it->second.addFuncNode(func);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
DB::FuseFunctions funcs{};
|
||||||
|
funcs.addFuncNode(func);
|
||||||
|
fuse_map.emplace(column->name(), funcs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool hasArrayJoin(const ASTPtr & ast)
|
bool hasArrayJoin(const ASTPtr & ast)
|
||||||
{
|
{
|
||||||
if (const ASTFunction * function = ast->as<ASTFunction>())
|
if (const ASTFunction * function = ast->as<ASTFunction>())
|
||||||
@ -969,19 +1025,21 @@ void TreeRewriter::normalize(ASTPtr & query, Aliases & aliases, const Settings &
|
|||||||
CustomizeGlobalNotInVisitor(data_global_not_null_in).visit(query);
|
CustomizeGlobalNotInVisitor(data_global_not_null_in).visit(query);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to fuse sum/avg/count with identical column(at least two functions exist) to sumCount()
|
/// Try to fuse sum/avg/count with identical column(at least two functions exist) to sumCount()
|
||||||
if (settings.optimize_fuse_sum_count_avg)
|
if (settings.optimize_fuse_sum_count_avg)
|
||||||
{
|
{
|
||||||
// Get statistics about sum/avg/count
|
/// Get statistics about sum/avg/count
|
||||||
GetAggregatesVisitor::Data data;
|
GetAggregatesVisitor::Data data;
|
||||||
GetAggregatesVisitor(data).visit(query);
|
GetAggregatesVisitor(data).visit(query);
|
||||||
|
std::unordered_map<String, DB::FuseFunctions> fuse_map;
|
||||||
|
gatherFuseFunctions(fuse_map, data.aggregates);
|
||||||
|
|
||||||
// Try to fuse
|
/// Try to fuse
|
||||||
CustomizeFuseAggregateFunctionsVisitor::Data data_fuse{.fuse_info = data.fuse_sum_count_avg};
|
CustomizeFuseAggregateFunctionsVisitor::Data data_fuse{.fuse_map = fuse_map};
|
||||||
CustomizeFuseAggregateFunctionsVisitor(data_fuse).visit(query);
|
CustomizeFuseAggregateFunctionsVisitor(data_fuse).visit(query);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewrite all aggregate functions to add -OrNull suffix to them
|
/// Rewrite all aggregate functions to add -OrNull suffix to them
|
||||||
if (settings.aggregate_functions_null_for_empty)
|
if (settings.aggregate_functions_null_for_empty)
|
||||||
{
|
{
|
||||||
CustomizeAggregateFunctionsOrNullVisitor::Data data_or_null{"OrNull"};
|
CustomizeAggregateFunctionsOrNullVisitor::Data data_or_null{"OrNull"};
|
||||||
|
@ -4,3 +4,9 @@ SELECT
|
|||||||
sumCount(b).1,
|
sumCount(b).1,
|
||||||
sumCount(b).2
|
sumCount(b).2
|
||||||
FROM fuse_tbl
|
FROM fuse_tbl
|
||||||
|
---------NOT trigger fuse--------
|
||||||
|
210 11.5
|
||||||
|
SELECT
|
||||||
|
sum(a),
|
||||||
|
avg(b)
|
||||||
|
FROM fuse_tbl
|
||||||
|
@ -2,7 +2,10 @@ DROP TABLE IF EXISTS fuse_tbl;
|
|||||||
CREATE TABLE fuse_tbl(a Int8, b Int8) Engine = Log;
|
CREATE TABLE fuse_tbl(a Int8, b Int8) Engine = Log;
|
||||||
INSERT INTO fuse_tbl SELECT number, number + 1 FROM numbers(1, 20);
|
INSERT INTO fuse_tbl SELECT number, number + 1 FROM numbers(1, 20);
|
||||||
|
|
||||||
set optimize_fuse_sum_count_avg = 1;
|
SET optimize_fuse_sum_count_avg = 1;
|
||||||
SELECT sum(a), sum(b), count(b) from fuse_tbl;
|
SELECT sum(a), sum(b), count(b) from fuse_tbl;
|
||||||
EXPLAIN SYNTAX SELECT sum(a), sum(b), count(b) from fuse_tbl;
|
EXPLAIN SYNTAX SELECT sum(a), sum(b), count(b) from fuse_tbl;
|
||||||
|
SELECT '---------NOT trigger fuse--------';
|
||||||
|
SELECT sum(a), avg(b) from fuse_tbl;
|
||||||
|
EXPLAIN SYNTAX SELECT sum(a), avg(b) from fuse_tbl;
|
||||||
DROP TABLE fuse_tbl;
|
DROP TABLE fuse_tbl;
|
||||||
|
Loading…
Reference in New Issue
Block a user