expression analyzer part

This commit is contained in:
Alexander Kuzmenkov 2020-12-09 14:14:40 +03:00
parent 8ee86e35d2
commit eb0c817bf2
11 changed files with 377 additions and 18 deletions

View File

@ -58,6 +58,9 @@ static DataTypes convertLowCardinalityTypesToNested(const DataTypes & types)
AggregateFunctionPtr AggregateFunctionFactory::get(
const String & name, const DataTypes & argument_types, const Array & parameters, AggregateFunctionProperties & out_properties) const
{
fmt::print(stderr, "get aggregate function {} at \n{}\n",
name, StackTrace().toString());
auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types);
/// If one of the types is Nullable, we apply aggregate function combinator "Null".

View File

@ -738,6 +738,14 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
if (AggregateFunctionFactory::instance().isAggregateFunctionName(node.name))
return;
/// FIXME need proper grammar for window functions. For now, ignore it --
/// the resulting column is added in ExpressionAnalyzer, similar to the
/// aggregate functions.
if (node.name == "window")
{
return;
}
FunctionOverloadResolverPtr function_builder;
try
{

View File

@ -1,13 +1,18 @@
#pragma once
#include <AggregateFunctions/IAggregateFunction.h>
#include <DataTypes/IDataType.h>
#include <Core/ColumnNumbers.h>
#include <Core/Names.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Core/SortDescription.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class ASTFunction;
struct AggregateDescription
{
AggregateFunctionPtr function;
@ -21,4 +26,32 @@ struct AggregateDescription
using AggregateDescriptions = std::vector<AggregateDescription>;
/// Describes a single window function call found in a query: the AST nodes it
/// came from and the aggregate function that was resolved for it.
struct WindowFunctionDescription
{
    /// Name of the window this function is evaluated over; the same name is
    /// used as the key of the corresponding entry in WindowDescriptions.
    std::string window_name;

    /// Column name of the wrapped function call (function_node->getColumnName()).
    std::string column_name;

    /// The `window(...)` wrapper call in the query AST. Raw pointer into the
    /// AST; defaults to nullptr so that a default-constructed description
    /// never carries an indeterminate pointer.
    const IAST * wrapper_node = nullptr;

    /// The function call passed as the first argument of the wrapper.
    /// Raw pointer into the AST, nullptr until filled in.
    const ASTFunction * function_node = nullptr;

    /// Aggregate function instance resolved by name, argument types and parameters.
    AggregateFunctionPtr aggregate_function;

    /// Parameters of a parametric aggregate function; empty when there are none.
    Array function_parameters;

    /// Types and column names of the function arguments, index-aligned.
    DataTypes argument_types;
    Names argument_names;

    /// Human-readable multi-line description, intended for debug output.
    std::string dump() const;
};
/// Describes one window: its generated name and the columns that define it.
/// Stored in WindowDescriptions keyed by window_name.
struct WindowDescription
{
/// Name of the window; window functions refer to their window by this name.
std::string window_name;
// Always ASC for now.
/// Column names of the PARTITION BY expressions.
std::vector<std::string> partition_by;
/// Column names of the ORDER BY expressions.
/// NOTE(review): nothing in the visible analyzer code fills this yet -- confirm.
std::vector<std::string> order_by;
// No frame info as of yet.
};
using WindowFunctionDescriptions = std::vector<WindowFunctionDescription>;
using WindowDescriptions = std::unordered_map<std::string, WindowDescription>;
}

View File

@ -1,5 +1,7 @@
#include <Core/Block.h>
#include <iostream>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
@ -57,12 +59,13 @@ using LogAST = DebugASTLog<false>; /// set to true to enable logs
namespace ErrorCodes
{
extern const int UNKNOWN_TYPE_OF_AST_NODE;
extern const int UNKNOWN_IDENTIFIER;
extern const int BAD_ARGUMENTS;
extern const int ILLEGAL_PREWHERE;
extern const int LOGICAL_ERROR;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_TYPE_OF_COLUMN_FOR_FILTER;
extern const int LOGICAL_ERROR;
extern const int UNKNOWN_IDENTIFIER;
extern const int UNKNOWN_TYPE_OF_AST_NODE;
}
namespace
@ -277,6 +280,8 @@ void ExpressionAnalyzer::analyzeAggregation()
{
aggregated_columns = temp_actions->getNamesAndTypesList();
}
has_window = makeWindowDescriptions(temp_actions);
}
@ -438,7 +443,11 @@ bool ExpressionAnalyzer::makeAggregateDescriptions(ActionsDAGPtr & actions)
auto it = index.find(name);
if (it == index.end())
throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER, "Unknown identifier (in aggregate function '{}'): {}", node->name, name);
{
throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER,
"Unknown identifier '{}' in aggregate function '{}'",
name, node->formatForErrorMessage());
}
types[i] = (*it)->result_type;
aggregate.argument_names[i] = name;
@ -455,6 +464,125 @@ bool ExpressionAnalyzer::makeAggregateDescriptions(ActionsDAGPtr & actions)
}
/// Fills window_descriptions and window_functions from the `window(...)`
/// pseudo-function calls returned by windowFunctions().
///
/// Each wrapper call is expected to look like
///     window(<function call>[, <PARTITION BY column or tuple of columns>])
/// The first argument is the function to compute, resolved through
/// AggregateFunctionFactory. The optional second argument lists the
/// PARTITION BY columns.
///
/// @param actions  DAG that receives the actions for the wrapper arguments and
///                 is then used to look up the argument types by column name.
/// @return true if the query contains at least one window function.
/// @throws Exception(BAD_ARGUMENTS) if the wrapper has no usable first argument
///         or the PARTITION BY argument is not a column / tuple of columns.
/// @throws Exception(UNKNOWN_IDENTIFIER) if a function argument is not found
///         in the actions index.
bool ExpressionAnalyzer::makeWindowDescriptions(ActionsDAGPtr & actions)
{
    for (const ASTFunction * wrapper_node : windowFunctions())
    {
        fmt::print(stderr, "window function ast: {}\n", wrapper_node->dumpTree());

        // Validate the wrapper shape before touching its arguments. Without
        // this check a missing or malformed argument list led to a null
        // pointer dereference when the wrapped function was extracted below.
        const auto * elist = wrapper_node->arguments
            ? wrapper_node->arguments->as<const ASTExpressionList>()
            : nullptr;

        if (!elist || elist->children.empty())
        {
            throw Exception(ErrorCodes::BAD_ARGUMENTS,
                "Expected a function call as the first argument of"
                " window function '{}'",
                wrapper_node->formatForErrorMessage());
        }

        // Not sure why NoMakeSet, copied from aggregate functions.
        getRootActionsNoMakeSet(wrapper_node->arguments, true /* no subqueries */,
            actions);

        // FIXME not thread-safe, should use a per-query counter.
        static int window_index = 1;

        WindowDescription window_description;
        window_description.window_name = fmt::format("window_{}", window_index++);

        if (elist->children.size() >= 2)
        {
            const auto & partition_by_ast = elist->children[1];
            fmt::print(stderr, "partition by ast {}\n",
                partition_by_ast->dumpTree());

            if (const auto * as_tuple = partition_by_ast->as<ASTFunction>();
                as_tuple
                    && as_tuple->name == "tuple"
                    && as_tuple->arguments)
            {
                // untuple it
                for (const auto & element_ast
                    : as_tuple->arguments->children)
                {
                    const auto * with_alias = dynamic_cast<
                        const ASTWithAlias *>(element_ast.get());
                    if (!with_alias)
                    {
                        throw Exception(ErrorCodes::BAD_ARGUMENTS,
                            "(1) Expected column in PARTITION BY"
                            " for window '{}', got '{}'",
                            window_description.window_name,
                            element_ast->formatForErrorMessage());
                    }
                    window_description.partition_by.push_back(
                        with_alias->getColumnName());
                }
            }
            else if (const auto * with_alias
                = dynamic_cast<const ASTWithAlias *>(partition_by_ast.get()))
            {
                window_description.partition_by.push_back(
                    with_alias->getColumnName());
            }
            else
            {
                throw Exception(ErrorCodes::BAD_ARGUMENTS,
                    "(2) Expected tuple or column in PARTITION BY"
                    " for window '{}', got '{}'",
                    window_description.window_name,
                    partition_by_ast->formatForErrorMessage());
            }
        }

        WindowFunctionDescription window_function;
        window_function.window_name = window_description.window_name;
        window_function.wrapper_node = wrapper_node;
        // children is known to be non-empty here, see the check above.
        window_function.function_node
            = &elist->children.front()->as<ASTFunction &>();
        window_function.column_name
            = window_function.function_node->getColumnName();
        window_function.function_parameters
            = window_function.function_node->parameters
                ? getAggregateFunctionParametersArray(
                    window_function.function_node->parameters)
                : Array();

        const ASTs & arguments
            = window_function.function_node->arguments->children;
        window_function.argument_types.resize(arguments.size());
        window_function.argument_names.resize(arguments.size());

        // Resolve every argument to a column already present in the actions.
        const auto & index = actions->getIndex();
        for (size_t i = 0; i < arguments.size(); ++i)
        {
            const std::string & name = arguments[i]->getColumnName();

            auto it = index.find(name);
            if (it == index.end())
            {
                throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER,
                    "Unknown identifier '{}' in window function '{}'",
                    name, window_function.function_node->formatForErrorMessage());
            }

            window_function.argument_types[i] = (*it)->result_type;
            window_function.argument_names[i] = name;
        }

        AggregateFunctionProperties properties;
        window_function.aggregate_function
            = AggregateFunctionFactory::instance().get(
                window_function.function_node->name,
                window_function.argument_types,
                window_function.function_parameters, properties);

        window_descriptions.insert({window_description.window_name,
            window_description});
        window_functions.push_back(window_function);

        fmt::print(stderr, "{}\n", window_function.dump());
    }

    return !windowFunctions().empty();
}
const ASTSelectQuery * ExpressionAnalyzer::getSelectQuery() const
{
const auto * select_query = query->as<ASTSelectQuery>();
@ -822,6 +950,37 @@ void SelectQueryExpressionAnalyzer::appendAggregateFunctionsArguments(Expression
for (const ASTFunction * node : data.aggregates)
for (auto & argument : node->arguments->children)
getRootActions(argument, only_types, step.actions());
fmt::print(stderr, "actions after appendAggregateFunctionsArguments: \n{} at \n{}\n", chain.dumpChain(), StackTrace().toString());
}
/// Adds the actions needed to compute the arguments of every window function
/// found in the SELECT list to the last step of the chain.
/// The chain is dumped to stderr before and after, for debugging.
void SelectQueryExpressionAnalyzer::appendWindowFunctionsArguments(
    ExpressionActionsChain & chain, bool only_types)
{
    fmt::print(stderr, "actions before window: {}\n", chain.dumpChain());

    const auto * select = getSelectQuery();
    ExpressionActionsChain::Step & last_step = chain.lastStep(aggregated_columns);

    /// Window functions are re-collected from the SELECT list here; the
    /// visitor removes duplicates by column name.
    /// TODO: use the list collected for the whole query (windowFunctions()).
    GetAggregatesVisitor::Data collected;
    GetAggregatesVisitor(collected).visit(select->select());

    for (const ASTFunction * window_node : collected.window_functions)
    {
        for (const auto & argument_ast : window_node->arguments->children)
        {
            getRootActions(argument_ast, only_types, last_step.actions());
        }
    }

    fmt::print(stderr, "actions after window: {}\n", chain.dumpChain());
}
bool SelectQueryExpressionAnalyzer::appendHaving(ExpressionActionsChain & chain, bool only_types)
@ -831,6 +990,8 @@ bool SelectQueryExpressionAnalyzer::appendHaving(ExpressionActionsChain & chain,
if (!select_query->having())
return false;
fmt::print(stderr, "has having\n");
ExpressionActionsChain::Step & step = chain.lastStep(aggregated_columns);
getRootActionsForHaving(select_query->having(), only_types, step.actions());
@ -848,7 +1009,16 @@ void SelectQueryExpressionAnalyzer::appendSelect(ExpressionActionsChain & chain,
getRootActions(select_query->select(), only_types, step.actions());
for (const auto & child : select_query->select()->children)
{
/// FIXME add proper grammar for window functions
if (const auto * as_function = child->as<ASTFunction>();
as_function && as_function->name == "window")
{
continue;
}
step.required_output.push_back(child->getColumnName());
}
}
ActionsDAGPtr SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order,
@ -943,6 +1113,13 @@ ActionsDAGPtr SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActio
ASTs asts = select_query->select()->children;
for (const auto & ast : asts)
{
/// FIXME add proper grammar for window functions
if (const auto * as_function = ast->as<ASTFunction>();
as_function && as_function->name == "window")
{
continue;
}
String result_name = ast->getAliasOrColumnName();
if (required_result_columns.empty() || required_result_columns.count(result_name))
{
@ -1069,6 +1246,7 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
: first_stage(first_stage_)
, second_stage(second_stage_)
, need_aggregate(query_analyzer.hasAggregation())
, has_window(query_analyzer.hasWindow())
{
/// first_stage: Do I need to perform the first part of the pipeline - running on remote servers during distributed processing.
/// second_stage: Do I need to execute the second part of the pipeline - running on the initiating server during distributed processing.
@ -1181,6 +1359,8 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
chain.addStep();
}
fmt::print(stderr, "chain before aggregate: {}\n", chain.dumpChain());
if (need_aggregate)
{
/// TODO correct conditions
@ -1189,18 +1369,29 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
&& storage && query.groupBy();
query_analyzer.appendGroupBy(chain, only_types || !first_stage, optimize_aggregation_in_order, group_by_elements_actions);
fmt::print(stderr, "chain after appendGroupBy: {}\n", chain.dumpChain());
query_analyzer.appendAggregateFunctionsArguments(chain, only_types || !first_stage);
fmt::print(stderr, "chain after appendAggregateFunctionsArguments: {}\n", chain.dumpChain());
before_aggregation = chain.getLastActions();
finalize_chain(chain);
fmt::print(stderr, "chain after finalize_chain: {}\n", chain.dumpChain());
if (query_analyzer.appendHaving(chain, only_types || !second_stage))
{
fmt::print(stderr, "chain after appendHaving: {}\n", chain.dumpChain());
before_having = chain.getLastActions();
chain.addStep();
}
}
fmt::print(stderr, "chain after aggregate: {}\n", chain.dumpChain());
bool join_allow_read_in_order = true;
if (hasJoin())
{
@ -1216,8 +1407,18 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
&& !query.final()
&& join_allow_read_in_order;
if (has_window)
{
query_analyzer.appendWindowFunctionsArguments(chain, only_types || !first_stage);
before_window = chain.getLastActions();
finalize_chain(chain);
}
/// If there is aggregation, we execute expressions in SELECT and ORDER BY on the initiating server, otherwise on the source servers.
query_analyzer.appendSelect(chain, only_types || (need_aggregate ? !second_stage : !first_stage));
fmt::print(stderr, "chain after select: {}\n", chain.dumpChain());
selected_columns = chain.getLastStep().required_output;
has_order_by = query.orderBy() != nullptr;
before_order_and_select = query_analyzer.appendOrderBy(
@ -1226,6 +1427,10 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
optimize_read_in_order,
order_by_elements_actions);
fmt::print(stderr, "chain after order by: {}\n", chain.dumpChain());
//if (h
if (query_analyzer.appendLimitBy(chain, only_types || !second_stage))
{
before_limit_by = chain.getLastActions();
@ -1242,6 +1447,8 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
checkActions();
fmt::print(stderr, "ExpressionAnalysisResult created at \n{}\n",
StackTrace().toString());
fmt::print(stderr, "ExpressionAnalysisResult: \n{}\n", dump());
}
@ -1320,9 +1527,9 @@ std::string ExpressionAnalysisResult::dump() const
{
std::stringstream ss;
ss << "ExpressionAnalysisResult\n";
ss << "need_aggregate " << need_aggregate << "\n";
ss << "has_order_by " << has_order_by << "\n";
ss << "has_window " << has_window << "\n";
if (before_array_join)
{
@ -1364,6 +1571,11 @@ std::string ExpressionAnalysisResult::dump() const
ss << "before_having " << before_having->dumpDAG() << "\n";
}
if (before_window)
{
ss << "before_window " << before_window->dumpDAG() << "\n";
}
if (before_order_and_select)
{
ss << "before_order_and_select " << before_order_and_select->dumpDAG() << "\n";
@ -1382,4 +1594,19 @@ std::string ExpressionAnalysisResult::dump() const
return ss.str();
}
/// Renders a multi-line, human-readable description of the window function:
/// its column and window names, both AST nodes, the resolved aggregate
/// function and, if present, the function parameters.
/// Requires wrapper_node, function_node and aggregate_function to be set.
std::string WindowFunctionDescription::dump() const
{
    std::stringstream ss;

    // The quote around the window name is closed here; it was left open before.
    ss << "window function '" << column_name << "' over '" << window_name << "'\n";
    ss << "wrapper node " << wrapper_node->dumpTree() << "\n";
    ss << "function node " << function_node->dumpTree() << "\n";
    ss << "aggregate function '" << aggregate_function->getName() << "'\n";
    if (!function_parameters.empty())
    {
        ss << "parameters " << toString(function_parameters) << "\n";
    }

    return ss.str();
}
}

View File

@ -60,6 +60,11 @@ struct ExpressionAnalyzerData
NamesAndTypesList aggregation_keys;
AggregateDescriptions aggregate_descriptions;
bool has_window = false;
WindowDescriptions window_descriptions;
WindowFunctionDescriptions window_functions;
NamesAndTypesList window_columns;
bool has_global_subqueries = false;
/// All new temporary tables obtained by performing the GLOBAL IN/JOIN subqueries.
@ -139,6 +144,7 @@ protected:
const TableJoin & analyzedJoin() const { return *syntax->analyzed_join; }
const NamesAndTypesList & sourceColumns() const { return syntax->required_source_columns; }
const std::vector<const ASTFunction *> & aggregates() const { return syntax->aggregates; }
const std::vector<const ASTFunction *> & windowFunctions() const { return syntax->window_functions; }
/// Find global subqueries in the GLOBAL IN/JOIN sections. Fills in external_tables.
void initGlobalSubqueriesAndExternalTables(bool do_global);
@ -162,6 +168,8 @@ protected:
void analyzeAggregation();
bool makeAggregateDescriptions(ActionsDAGPtr & actions);
bool makeWindowDescriptions(ActionsDAGPtr & actions);
const ASTSelectQuery * getSelectQuery() const;
bool isRemoteStorage() const { return syntax->is_remote_storage; }
@ -181,6 +189,7 @@ struct ExpressionAnalysisResult
bool need_aggregate = false;
bool has_order_by = false;
bool has_window = false;
bool remove_where_filter = false;
bool optimize_read_in_order = false;
@ -194,6 +203,7 @@ struct ExpressionAnalysisResult
ActionsDAGPtr before_where;
ActionsDAGPtr before_aggregation;
ActionsDAGPtr before_having;
ActionsDAGPtr before_window;
ActionsDAGPtr before_order_and_select;
ActionsDAGPtr before_limit_by;
ActionsDAGPtr final_projection;
@ -261,6 +271,7 @@ public:
/// Does the expression have aggregate functions or a GROUP BY or HAVING section.
bool hasAggregation() const { return has_aggregation; }
bool hasWindow() const { return has_window; }
bool hasGlobalSubqueries() { return has_global_subqueries; }
bool hasTableJoin() const { return syntax->ast_join; }
@ -331,6 +342,7 @@ private:
bool appendWhere(ExpressionActionsChain & chain, bool only_types);
bool appendGroupBy(ExpressionActionsChain & chain, bool only_types, bool optimize_aggregation_in_order, ManyExpressionActions &);
void appendAggregateFunctionsArguments(ExpressionActionsChain & chain, bool only_types);
void appendWindowFunctionsArguments(ExpressionActionsChain & chain, bool only_types);
/// After aggregation:
bool appendHaving(ExpressionActionsChain & chain, bool only_types);

View File

@ -19,8 +19,10 @@ public:
/// Input settings and collected results of one visitor traversal.
struct Data
{
/// If set, finding an aggregate function throws; the text names the place
/// where it was found and is inserted into the error message.
const char * assert_no_aggregates = nullptr;
/// Same as above, but for window functions.
const char * assert_no_windows = nullptr;
/// Column names of the functions collected so far, used to skip duplicates.
/// Shared between aggregate and window functions.
std::unordered_set<String> uniq_names;
/// Collected aggregate function calls, unique by column name.
std::vector<const ASTFunction *> aggregates;
/// Collected window function calls, unique by column name.
std::vector<const ASTFunction *> window_functions;
};
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child)
@ -28,8 +30,14 @@ public:
if (child->as<ASTSubquery>() || child->as<ASTSelectQuery>())
return false;
if (auto * func = node->as<ASTFunction>())
if (isAggregateFunction(func->name))
{
if (isAggregateFunction(func->name)
// FIXME temporary hack while we don't have grammar
|| func->name == "window")
{
return false;
}
}
return true;
}
@ -42,19 +50,32 @@ public:
private:
static void visit(const ASTFunction & node, const ASTPtr &, Data & data)
{
if (!isAggregateFunction(node.name))
return;
if (isAggregateFunction(node.name))
{
if (data.assert_no_aggregates)
throw Exception("Aggregate function " + node.getColumnName() + " is found " + String(data.assert_no_aggregates) + " in query",
ErrorCodes::ILLEGAL_AGGREGATION);
if (data.assert_no_aggregates)
throw Exception("Aggregate function " + node.getColumnName() + " is found " + String(data.assert_no_aggregates) + " in query",
ErrorCodes::ILLEGAL_AGGREGATION);
String column_name = node.getColumnName();
if (data.uniq_names.count(column_name))
return;
String column_name = node.getColumnName();
if (data.uniq_names.count(column_name))
return;
data.uniq_names.insert(column_name);
data.aggregates.push_back(&node);
}
else if (node.name == "window")
{
if (data.assert_no_windows)
throw Exception("Window function " + node.getColumnName() + " is found " + String(data.assert_no_windows) + " in query",
ErrorCodes::ILLEGAL_AGGREGATION);
data.uniq_names.insert(column_name);
data.aggregates.push_back(&node);
String column_name = node.getColumnName();
if (data.uniq_names.count(column_name))
return;
data.uniq_names.insert(column_name);
data.window_functions.push_back(&node);
}
}
static bool isAggregateFunction(const String & name)
@ -66,9 +87,15 @@ private:
using GetAggregatesVisitor = GetAggregatesMatcher::Visitor;
inline void assertNoWindows(const ASTPtr & ast, const char * description)
{
GetAggregatesVisitor::Data data{.assert_no_windows = description};
GetAggregatesVisitor(data).visit(ast);
}
/// Throws if `ast` contains an aggregate function. `description` names the
/// place being checked and becomes part of the error message.
inline void assertNoAggregates(const ASTPtr & ast, const char * description)
{
    /// Only the assertion text is set; the other fields keep their defaults.
    /// A single declaration of `data`, in the same form as assertNoWindows.
    GetAggregatesVisitor::Data data{.assert_no_aggregates = description};
    GetAggregatesVisitor(data).visit(ast);
}

View File

@ -380,11 +380,44 @@ std::vector<const ASTFunction *> getAggregates(ASTPtr & query, const ASTSelectQu
/// There can not be other aggregate functions within the aggregate functions.
for (const ASTFunction * node : data.aggregates)
{
for (auto & arg : node->arguments->children)
{
assertNoAggregates(arg, "inside another aggregate function");
assertNoWindows(arg, "inside another window function");
}
}
return data.aggregates;
}
/// Collects the window function calls of `query`, unique by column name.
/// Throws if a window function appears in WHERE or PREWHERE, or inside the
/// arguments of another window function.
std::vector<const ASTFunction *> getWindowFunctions(ASTPtr & query, const ASTSelectQuery & select_query)
{
    /// There can not be window functions inside the WHERE and PREWHERE.
    if (const ASTPtr where_ast = select_query.where())
    {
        assertNoWindows(where_ast, "in WHERE");
    }

    if (const ASTPtr prewhere_ast = select_query.prewhere())
    {
        assertNoWindows(prewhere_ast, "in PREWHERE");
    }

    GetAggregatesVisitor::Data collected;
    GetAggregatesVisitor(collected).visit(query);

    /// There can not be other window functions within the window functions.
    /// Aggregate functions in the arguments are not rejected here.
    for (const ASTFunction * window_node : collected.window_functions)
    {
        for (auto & argument_ast : window_node->arguments->children)
        {
            assertNoWindows(argument_ast, "inside another window function");
        }
    }

    fmt::print(stderr, "getWindowFunctions ({}) for \n{}\n at \n{}\n",
        collected.window_functions.size(), query->formatForErrorMessage(),
        StackTrace().toString());

    return collected.window_functions;
}
}
TreeRewriterResult::TreeRewriterResult(
@ -398,6 +431,8 @@ TreeRewriterResult::TreeRewriterResult(
{
collectSourceColumns(add_special);
is_remote_storage = storage && storage->isRemote();
fmt::print(stderr, "TreeRewriterResult created at \n{}\n", StackTrace().toString());
}
/// Add columns from storage to source_columns list. Deduplicate resulted list.
@ -668,6 +703,7 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect(
collectJoinedColumns(*result.analyzed_join, *select_query, tables_with_columns, result.aliases);
result.aggregates = getAggregates(query, *select_query);
result.window_functions = getWindowFunctions(query, *select_query);
result.collectUsedColumns(query, true);
result.ast_join = select_query->join();

View File

@ -35,6 +35,8 @@ struct TreeRewriterResult
Aliases aliases;
std::vector<const ASTFunction *> aggregates;
std::vector<const ASTFunction *> window_functions;
/// Which column is needed to be ARRAY-JOIN'ed to get the specified.
/// For example, for `SELECT s.v ... ARRAY JOIN a AS s` will get "s.v" -> "a.v".
NameToNameMap array_join_result_to_source;

View File

@ -157,4 +157,11 @@ void IAST::dumpTree(WriteBuffer & ostr, size_t indent) const
child->dumpTree(ostr, indent + 1);
}
std::string IAST::dumpTree(size_t indent) const
{
WriteBufferFromOwnString wb;
dumpTree(wb, indent);
return wb.str();
}
}

View File

@ -77,6 +77,7 @@ public:
virtual void updateTreeHashImpl(SipHash & hash_state) const;
void dumpTree(WriteBuffer & ostr, size_t indent = 0) const;
std::string dumpTree(size_t indent = 0) const;
/** Check the depth of the tree.
* If max_depth is specified and the depth is greater - throw an exception.

View File

@ -63,6 +63,9 @@ FilterTransform::FilterTransform(
auto & column = transformed_header.getByPosition(filter_column_position).column;
if (column)
constant_filter_description = ConstantFilterDescription(*column);
fmt::print(stderr, "FilterTransform created at \n{}\n",
StackTrace().toString());
}
IProcessor::Status FilterTransform::prepare()