Rework expressions with window functions

This commit is contained in:
Dmitry Novik 2022-06-16 13:29:56 +00:00
parent eddd4ecaeb
commit 0663f07e67
12 changed files with 242 additions and 119 deletions

View File

@ -1,4 +1,5 @@
#include <memory>
#include "Common/logger_useful.h"
#include <Common/quoteString.h>
#include <Common/typeid_cast.h>
#include <Columns/ColumnArray.h>
@ -48,6 +49,7 @@
#include <Interpreters/DatabaseAndTableWithAlias.h>
#include <Interpreters/IdentifierSemantic.h>
#include <Interpreters/UserDefinedExecutableFunctionFactory.h>
#include <Poco/Logger.h>
namespace DB
@ -941,28 +943,14 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
bool is_aggregate_function = AggregateFunctionFactory::instance().isAggregateFunctionName(node.name);
if (data.window_function_called.has_value())
{
bool subtree_contains_window_call = false;
for (const auto & arg : node.arguments->children)
{
data.window_function_called = false;
visit(arg, data);
subtree_contains_window_call = subtree_contains_window_call || data.window_function_called.value();
}
data.window_function_called = subtree_contains_window_call
|| (!subtree_contains_window_call && is_aggregate_function);
if (data.window_function_called.value())
return;
}
else if (node.is_window_function)
LOG_DEBUG(&Poco::Logger::get("ActionVisitor"), "Processing function {}, with compute_after_window_functions={}", node.getColumnName(), node.compute_after_window_functions);
if (node.is_window_function)
{
// Also add columns from PARTITION BY and ORDER BY of window functions.
if (node.window_definition)
{
visit(node.window_definition, data);
}
data.window_function_called.emplace();
// Also manually add columns for arguments of the window function itself.
// ActionVisitor is written in such a way that this method must itself
// descend into all needed function children. Window functions can't have
@ -975,14 +963,44 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
{
visit(arg, data);
}
data.window_function_called.reset();
// Don't need to do anything more for window functions here -- the
// resulting column is added in ExpressionAnalyzer, similar to the
// aggregate functions.
if (data.window_function_called.has_value())
data.window_function_called = true;
return;
}
else if (node.compute_after_window_functions)
{
data.window_function_called.emplace();
bool subtree_contains_window_call = false;
for (const auto & arg : node.arguments->children)
{
LOG_DEBUG(&Poco::Logger::get("ActionVisitor"), "Processing arg: {}", arg->getColumnName());
data.window_function_called = false;
visit(arg, data);
LOG_DEBUG(&Poco::Logger::get("ActionVisitor"), "Processed arg: {}, result: {}", arg->getColumnName(), data.window_function_called.value());
subtree_contains_window_call = subtree_contains_window_call || data.window_function_called.value();
}
// assert(subtree_contains_window_call);
data.window_function_called.reset();
if (!data.build_expression_with_window_functions)
return;
}
else if (data.window_function_called.has_value())
{
bool subtree_contains_window_call = false;
for (const auto & arg : node.arguments->children)
{
data.window_function_called = false;
visit(arg, data);
subtree_contains_window_call = subtree_contains_window_call || data.window_function_called.value();
}
data.window_function_called = subtree_contains_window_call;
if (subtree_contains_window_call && !data.build_expression_with_window_functions)
return;
}
// An aggregate function can also be calculated as a window function, but we
// checked for it above, so no need to do anything more.

View File

@ -879,8 +879,48 @@ void ExpressionAnalyzer::makeWindowDescriptions(ActionsDAGPtr actions)
// Find the window corresponding to this function. It may be either
// referenced by name and previously defined in WINDOW clause, or it
// may be defined inline.
WindowFunctionDescription window_function;
window_function.function_node = function_node;
window_function.column_name
= window_function.function_node->getColumnName();
window_function.function_parameters
= window_function.function_node->parameters
? getAggregateFunctionParametersArray(
window_function.function_node->parameters, "", getContext())
: Array();
// Requiring a constant reference to a shared pointer to non-const AST
// doesn't really look sane, but the visitor does indeed require it.
// Hence we clone the node (not very sane either, I know).
getRootActionsNoMakeSet(window_function.function_node->clone(), actions);
const ASTs & arguments
= window_function.function_node->arguments->children;
window_function.argument_types.resize(arguments.size());
window_function.argument_names.resize(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i)
{
const std::string & name = arguments[i]->getColumnName();
const auto * node = actions->tryFindInIndex(name);
if (!node)
{
throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER,
"Unknown identifier '{}' in window function '{}'",
name, window_function.function_node->formatForErrorMessage());
}
window_function.argument_types[i] = node->result_type;
window_function.argument_names[i] = name;
}
AggregateFunctionProperties properties;
window_function.aggregate_function
= AggregateFunctionFactory::instance().get(
window_function.function_node->name,
window_function.argument_types,
window_function.function_parameters, properties);
WindowDescription * window_description;
if (!function_node->window_name.empty())
{
auto it = window_descriptions.find(function_node->window_name);
@ -892,8 +932,7 @@ void ExpressionAnalyzer::makeWindowDescriptions(ActionsDAGPtr actions)
function_node->formatForErrorMessage());
}
window_description = &it->second;
// it->second.window_functions.push_back(window_function);
it->second.window_functions.push_back(window_function);
}
else
{
@ -913,66 +952,7 @@ void ExpressionAnalyzer::makeWindowDescriptions(ActionsDAGPtr actions)
== desc.full_sort_description);
}
window_description = &it->second;
// it->second.window_functions.push_back(window_function);
}
WindowFunctionList functions;
WindowFunctionsUtils::collectWindowFunctionsFromExpression(function_node, functions);
if (functions.size() != 1 || functions.front() != function_node)
{
window_description->expressions_with_window_functions.push_back(function_node);
if (function_node->window_definition)
{
getRootActionsNoMakeSet(function_node->window_definition, actions);
}
}
for (const auto * function : functions)
{
WindowFunctionDescription window_function;
window_function.function_node = function;
window_function.column_name
= window_function.function_node->getColumnName();
window_function.function_parameters
= window_function.function_node->parameters
? getAggregateFunctionParametersArray(
window_function.function_node->parameters, "", getContext())
: Array();
// Requiring a constant reference to a shared pointer to non-const AST
// doesn't really look sane, but the visitor does indeed require it.
// Hence we clone the node (not very sane either, I know).
getRootActionsNoMakeSet(window_function.function_node->clone(), actions);
const ASTs & arguments
= window_function.function_node->arguments->children;
window_function.argument_types.resize(arguments.size());
window_function.argument_names.resize(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i)
{
const std::string & name = arguments[i]->getColumnName();
const auto * node = actions->tryFindInIndex(name);
if (!node)
{
throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER,
"Unknown identifier '{}' in window function '{}'",
name, window_function.function_node->formatForErrorMessage());
}
window_function.argument_types[i] = node->result_type;
window_function.argument_names[i] = name;
}
AggregateFunctionProperties properties;
window_function.aggregate_function
= AggregateFunctionFactory::instance().get(
window_function.function_node->name,
window_function.argument_types,
window_function.function_parameters, properties);
window_description->window_functions.push_back(window_function);
it->second.window_functions.push_back(window_function);
}
}
@ -1433,12 +1413,9 @@ void SelectQueryExpressionAnalyzer::appendWindowFunctionsArguments(
void SelectQueryExpressionAnalyzer::appendExpressionsAfterWindowFunctions(ExpressionActionsChain & chain, bool /* only_types */)
{
ExpressionActionsChain::Step & step = chain.lastStep(columns_after_window);
for (const auto & [_, w] : window_descriptions)
for (const auto & expression : syntax->expressions_with_window_function)
{
for (const auto & expression : w.expressions_with_window_functions)
{
getRootActionsForWindowFunctions(expression->clone(), true, step.actions());
}
getRootActionsForWindowFunctions(expression->clone(), true, step.actions());
}
}
@ -1469,7 +1446,7 @@ void SelectQueryExpressionAnalyzer::appendSelect(ExpressionActionsChain & chain,
{
if (const auto * function = typeid_cast<const ASTFunction *>(child.get());
function
&& function->is_window_function)
&& (function->is_window_function || function->compute_after_window_functions))
{
// Skip window function columns here -- they are calculated after
// other SELECT expressions by a special step.

View File

@ -0,0 +1,84 @@
#include <Interpreters/GetAggregatesVisitor.h>
namespace DB
{
// Per-subtree summary that the window-expression collector passes upwards
// while walking the AST.
struct WindowExpressionsCollectorChildInfo
{
    // Set once a window function call has been seen in the summarised subtree.
    bool window_function_in_subtree = false;

    // Merge the summary of another subtree into this one. The flag is sticky:
    // once it is set it is never cleared by a later merge.
    void update(const WindowExpressionsCollectorChildInfo & other)
    {
        if (other.window_function_in_subtree)
            window_function_in_subtree = true;
    }
};
// Matcher that finds expressions which have a window function somewhere
// among their (possibly nested) arguments. The outermost such function of
// each expression is recorded in `expressions_with_window_functions` and gets
// its `compute_after_window_functions` flag set.
struct WindowExpressionsCollectorMatcher
{
    using ChildInfo = WindowExpressionsCollectorChildInfo;

    // Tells the generic traversal whether it should descend from `node` into `child`.
    static bool needVisitChild(ASTPtr & node, const ASTPtr & child)
    {
        // Nested queries are not descended into.
        const bool child_is_nested_query = child->as<ASTSubquery>() || child->as<ASTSelectQuery>();
        if (child_is_nested_query)
            return false;

        // The WITH clause is not analysed because it might contain useless aggregates.
        auto * select_query = node->as<ASTSelectQuery>();
        if (select_query && child == select_query->with())
            return false;

        // Function arguments are walked manually in visitNode(), so the generic
        // traversal must not descend below a function node.
        return node->as<ASTFunction>() == nullptr;
    }

    // Entry point used by the generic traversal. The info accumulated from the
    // children is ignored; the result is computed by the two-argument overload.
    WindowExpressionsCollectorChildInfo visitNode(
        ASTPtr & ast,
        const ASTPtr & parent,
        WindowExpressionsCollectorChildInfo const &)
    {
        return visitNode(ast, parent);
    }

    // Reports whether `ast` is, or contains in its arguments, a window function.
    // Nodes that are not functions always yield a default (false) result.
    WindowExpressionsCollectorChildInfo visitNode(
        ASTPtr & ast,
        const ASTPtr & parent)
    {
        auto * function = ast->as<ASTFunction>();
        if (!function)
            return {};

        if (function->is_window_function)
            return { .window_function_in_subtree = true };

        // Recurse into every argument and merge what was found below.
        ChildInfo arguments_info;
        for (auto & argument : function->arguments->children)
            arguments_info.update(visitNode(argument, ast));

        // Only the function on top of the expression is marked, i.e. the one
        // whose parent is absent or is not a function itself.
        const bool is_top_level = !parent || !parent->as<ASTFunction>();
        if (is_top_level && arguments_info.window_function_in_subtree)
        {
            expressions_with_window_functions.push_back(function);
            function->compute_after_window_functions = true;
        }

        return arguments_info;
    }

    // Top-level functions found so far, in visiting order.
    std::vector<const ASTFunction *> expressions_with_window_functions {};
};
using WindowExpressionsCollectorVisitor = InDepthNodeVisitorWithChildInfo<WindowExpressionsCollectorMatcher>;

// Walks `ast` with WindowExpressionsCollectorVisitor and returns the functions
// it collected. The visit also sets `compute_after_window_functions` on each
// returned function, which is why `ast` is taken by non-const reference.
std::vector<const ASTFunction *> getExpressionsWithWindowFunctions(ASTPtr & ast)
{
    WindowExpressionsCollectorVisitor collector;
    collector.visit(ast);

    // The collector is about to be destroyed, so steal its result instead of copying it.
    auto collected = std::move(collector.expressions_with_window_functions);
    return collected;
}
}

View File

@ -42,7 +42,7 @@ public:
}
if (auto * func = node->as<ASTFunction>())
{
if (isAggregateFunction(*func) || func->is_window_function)
if (isAggregateFunction(*func))
{
return false;
}
@ -82,9 +82,6 @@ private:
throw Exception("Window function " + node.getColumnName() + " is found " + String(data.assert_no_windows) + " in query",
ErrorCodes::ILLEGAL_AGGREGATION);
if (node.window_definition)
visit(node.window_definition, data);
String column_name = node.getColumnName();
if (data.uniq_names.count(column_name))
return;
@ -119,4 +116,6 @@ inline void assertNoAggregates(const ASTPtr & ast, const char * description)
GetAggregatesVisitor(data).visit(ast);
}
std::vector<const ASTFunction *> getExpressionsWithWindowFunctions(ASTPtr & ast);
}

View File

@ -95,4 +95,33 @@ public:
template <typename Data, NeedChild::Condition need_child = NeedChild::all>
using ConstOneTypeMatcher = OneTypeMatcher<Data, need_child, const ASTPtr>;
// Depth-first visitor that processes a node after its children and hands the
// node the merged information produced by those children.
//
// `Visitor` must provide:
//   - a `ChildInfo` type with an `update(child_info)` method;
//   - `needVisitChild(ast, child)`, which selects the children to descend into;
//   - `visitNode(ast, parent, children_info)`, which returns the `ChildInfo` for `ast`.
template <typename Visitor, typename T = ASTPtr>
struct InDepthNodeVisitorWithChildInfo : Visitor
{
    using ChildInfo = typename Visitor::ChildInfo;

    // Visits the subtree rooted at `ast`; `parent` is empty for the root.
    ChildInfo visit(T & ast, const T & parent = {})
    {
        ChildInfo merged_info;
        for (auto & child : ast->children)
        {
            if (!Visitor::needVisitChild(ast, child))
                continue;

            ChildInfo info_from_child = visit(child, ast);
            merged_info.update(info_from_child);
        }

        try
        {
            return Visitor::visitNode(ast, parent, merged_info);
        }
        catch (Exception & e)
        {
            // Attach the node being processed to the error before rethrowing.
            e.addMessage("While processing {}", ast->formatForErrorMessage());
            throw;
        }
    }
};
}

View File

@ -1248,6 +1248,7 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect(
result.aggregates = getAggregates(query, *select_query);
result.window_function_asts = getWindowFunctions(query, *select_query);
result.expressions_with_window_function = getExpressionsWithWindowFunctions(query);
result.collectUsedColumns(query, true);
result.required_source_columns_before_expanding_alias_columns = result.required_source_columns.getNames();
@ -1271,6 +1272,7 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect(
{
result.aggregates = getAggregates(query, *select_query);
result.window_function_asts = getWindowFunctions(query, *select_query);
result.expressions_with_window_function = getExpressionsWithWindowFunctions(query);
result.collectUsedColumns(query, true);
}
}

View File

@ -44,6 +44,8 @@ struct TreeRewriterResult
std::vector<const ASTFunction *> window_function_asts;
std::vector<const ASTFunction *> expressions_with_window_function;
/// Which column is needed to be ARRAY-JOIN'ed to get the specified.
/// For example, for `SELECT s.v ... ARRAY JOIN a AS s` will get "s.v" -> "a.v".
NameToNameMap array_join_result_to_source;

View File

@ -22,6 +22,8 @@ public:
bool is_window_function = false;
bool compute_after_window_functions = false;
// We have to make these fields ASTPtr because this is what the visitors
// expect. Some of them take const ASTPtr & (makes no sense), and some
// take ASTPtr & and modify it. I don't understand how the latter is

View File

@ -1034,6 +1034,33 @@ bool ParserFunction::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
function_node->children.push_back(function_node->parameters);
}
ParserKeyword filter("FILTER");
ParserKeyword over("OVER");
if (filter.ignore(pos, expected))
{
// We are slightly breaking the parser interface by parsing the window
// definition into an existing ASTFunction. Normally it would take a
// reference to ASTPtr and assign it the new node. We only have a pointer
// of a different type, hence this workaround with a temporary pointer.
ASTPtr function_node_as_iast = function_node;
ParserFilterClause filter_parser;
if (!filter_parser.parse(pos, function_node_as_iast, expected))
return false;
}
if (over.ignore(pos, expected))
{
function_node->is_window_function = true;
ASTPtr function_node_as_iast = function_node;
ParserWindowReference window_reference;
if (!window_reference.parse(pos, function_node_as_iast, expected))
return false;
}
node = function_node;
return true;
}

View File

@ -533,31 +533,6 @@ bool ParserTernaryOperatorExpression::parseImpl(Pos & pos, ASTPtr & node, Expect
node = function;
}
ParserKeyword filter("FILTER");
ParserKeyword over("OVER");
if (filter.ignore(pos, expected))
{
// We are slightly breaking the parser interface by parsing the window
// definition into an existing ASTFunction. Normally it would take a
// reference to ASTPtr and assign it the new node. We only have a pointer
// of a different type, hence this workaround with a temporary pointer.
ParserFilterClause filter_parser;
if (!filter_parser.parse(pos, node, expected))
return false;
}
if (over.ignore(pos, expected))
{
auto function_node = typeid_cast<std::shared_ptr<ASTFunction>>(node);
function_node->is_window_function = true;
ParserWindowReference window_reference;
if (!window_reference.parse(pos, node, expected))
return false;
}
return true;
}

View File

@ -1,5 +1,9 @@
-- { echoOn }
SELECT number, sum(number) + 1 OVER (PARTITION BY number % 10)
-- SELECT number, sum(number) + 1 OVER (PARTITION BY (number % 10))
-- FROM numbers(100)
-- ORDER BY number; -- { clientError SYNTAX_ERROR }
SELECT number, 1 + sum(number) OVER (PARTITION BY number % 10)
FROM numbers(100)
ORDER BY number;
0 451
@ -118,7 +122,7 @@ ORDER BY x;
541
SELECT
number,
sum(number) / count() OVER (PARTITION BY number % 10),
sum(number) OVER (PARTITION BY number % 10) / count() OVER (PARTITION BY number % 10),
avg(number) OVER (PARTITION BY number % 10)
FROM numbers(100)
ORDER BY number ASC;

View File

@ -1,5 +1,9 @@
-- { echoOn }
SELECT number, sum(number) + 1 OVER (PARTITION BY number % 10)
-- SELECT number, sum(number) + 1 OVER (PARTITION BY (number % 10))
-- FROM numbers(100)
-- ORDER BY number; -- { clientError SYNTAX_ERROR }
SELECT number, 1 + sum(number) OVER (PARTITION BY number % 10)
FROM numbers(100)
ORDER BY number;
@ -10,7 +14,7 @@ ORDER BY x;
SELECT
number,
sum(number) / count() OVER (PARTITION BY number % 10),
sum(number) OVER (PARTITION BY number % 10) / count() OVER (PARTITION BY number % 10),
avg(number) OVER (PARTITION BY number % 10)
FROM numbers(100)
ORDER BY number ASC;