Refactor ExpressionActions [Part 3]

This commit is contained in:
Nikolai Kochetov 2020-11-03 14:28:28 +03:00
parent d9d83d8db6
commit 07a7c46b89
54 changed files with 1570 additions and 1678 deletions

View File

@ -40,7 +40,7 @@ Block::Block(const ColumnsWithTypeAndName & data_) : data{data_}
void Block::initializeIndexByName()
{
for (size_t i = 0, size = data.size(); i < size; ++i)
index_by_name[data[i].name] = i;
index_by_name.emplace(data[i].name, i);
}
@ -295,6 +295,20 @@ std::string Block::dumpStructure() const
return out.str();
}
/// Renders the name -> position index as "name pos, name pos, ...".
/// Entries appear in the iteration order of index_by_name. Intended for debugging.
std::string Block::dumpIndex() const
{
    WriteBufferFromOwnString out;
    bool need_separator = false;
    for (const auto & entry : index_by_name)
    {
        /// Separator goes before every entry except the first one.
        if (need_separator)
            out << ", ";
        need_separator = true;

        out << entry.first << ' ' << entry.second;
    }
    return out.str();
}
Block Block::cloneEmpty() const
{

View File

@ -119,6 +119,9 @@ public:
/** List of names, types and lengths of columns. Designed for debugging. */
std::string dumpStructure() const;
/** List of column names and positions from index */
std::string dumpIndex() const;
/** Get the same block, but empty. */
Block cloneEmpty() const;

View File

@ -106,11 +106,11 @@ std::ostream & operator<<(std::ostream & stream, const Packet & what)
return stream;
}
std::ostream & operator<<(std::ostream & stream, const ExpressionAction & what)
{
stream << "ExpressionAction(" << what.toString() << ")";
return stream;
}
//std::ostream & operator<<(std::ostream & stream, const ExpressionAction & what)
//{
// stream << "ExpressionAction(" << what.toString() << ")";
// return stream;
//}
std::ostream & operator<<(std::ostream & stream, const ExpressionActions & what)
{

View File

@ -40,8 +40,8 @@ std::ostream & operator<<(std::ostream & stream, const IColumn & what);
struct Packet;
std::ostream & operator<<(std::ostream & stream, const Packet & what);
struct ExpressionAction;
std::ostream & operator<<(std::ostream & stream, const ExpressionAction & what);
//struct ExpressionAction;
//std::ostream & operator<<(std::ostream & stream, const ExpressionAction & what);
class ExpressionActions;
std::ostream & operator<<(std::ostream & stream, const ExpressionActions & what);

View File

@ -46,7 +46,7 @@ void CheckConstraintsBlockOutputStream::write(const Block & block)
auto * constraint_ptr = constraints.constraints[i]->as<ASTConstraintDeclaration>();
ColumnWithTypeAndName res_column = block_to_calculate.getByPosition(block_to_calculate.columns() - 1);
ColumnWithTypeAndName res_column = block_to_calculate.getByName(constraint_ptr->expr->getColumnName());
if (!isUInt8(res_column.type))
throw Exception(ErrorCodes::LOGICAL_ERROR, "Constraint {} does not return a value of type UInt8",

View File

@ -201,7 +201,7 @@ public:
{
/// Check that expression does not contain unusual actions that will break columns structure.
for (const auto & action : expression_actions->getActions())
if (action.type == ExpressionAction::Type::ARRAY_JOIN)
if (action.node->type == ActionsDAG::Type::ARRAY_JOIN)
throw Exception("Expression with arrayJoin or other unusual action cannot be captured", ErrorCodes::BAD_ARGUMENTS);
std::unordered_map<std::string, DataTypePtr> arguments_map;

View File

@ -350,7 +350,7 @@ SetPtr makeExplicitSet(
auto it = index.find(left_arg->getColumnName());
if (it == index.end())
throw Exception("Unknown identifier: '" + left_arg->getColumnName() + "'", ErrorCodes::UNKNOWN_IDENTIFIER);
const DataTypePtr & left_arg_type = it->second->result_type;
const DataTypePtr & left_arg_type = (*it)->result_type;
DataTypes set_element_types = {left_arg_type};
const auto * left_tuple_type = typeid_cast<const DataTypeTuple *>(left_arg_type.get());
@ -404,7 +404,7 @@ ActionsMatcher::Data::Data(
bool ActionsMatcher::Data::hasColumn(const String & column_name) const
{
return actions_stack.getLastActions().getIndex().count(column_name) != 0;
return actions_stack.getLastActions().getIndex().contains(column_name);
}
ScopeStack::ScopeStack(ActionsDAGPtr actions, const Context & context_)
@ -413,9 +413,9 @@ ScopeStack::ScopeStack(ActionsDAGPtr actions, const Context & context_)
auto & level = stack.emplace_back();
level.actions = std::move(actions);
for (const auto & [name, node] : level.actions->getIndex())
for (const auto & node : level.actions->getIndex())
if (node->type == ActionsDAG::Type::INPUT)
level.inputs.emplace(name);
level.inputs.emplace(node->result_name);
}
void ScopeStack::pushLevel(const NamesAndTypesList & input_columns)
@ -432,9 +432,9 @@ void ScopeStack::pushLevel(const NamesAndTypesList & input_columns)
const auto & index = level.actions->getIndex();
for (const auto & [name, node] : prev.actions->getIndex())
for (const auto & node : prev.actions->getIndex())
{
if (index.count(name) == 0)
if (!index.contains(node->result_name))
level.actions->addInput({node->column, node->result_type, node->result_name});
}
}
@ -451,7 +451,7 @@ size_t ScopeStack::getColumnLevel(const std::string & name)
const auto & index = stack[i].actions->getIndex();
auto it = index.find(name);
if (it != index.end() && it->second->type != ActionsDAG::Type::INPUT)
if (it != index.end() && (*it)->type != ActionsDAG::Type::INPUT)
return i;
}
@ -475,15 +475,15 @@ void ScopeStack::addAlias(const std::string & name, std::string alias)
stack[j].actions->addInput({node.column, node.result_type, node.result_name});
}
void ScopeStack::addArrayJoin(const std::string & source_name, std::string result_name, std::string unique_column_name)
void ScopeStack::addArrayJoin(const std::string & source_name, std::string result_name)
{
getColumnLevel(source_name);
if (stack.front().actions->getIndex().count(source_name) == 0)
if (!stack.front().actions->getIndex().contains(source_name))
throw Exception("Expression with arrayJoin cannot depend on lambda argument: " + source_name,
ErrorCodes::BAD_ARGUMENTS);
const auto & node = stack.front().actions->addArrayJoin(source_name, std::move(result_name), std::move(unique_column_name));
const auto & node = stack.front().actions->addArrayJoin(source_name, std::move(result_name));
for (size_t j = 1; j < stack.size(); ++j)
stack[j].actions->addInput({node.column, node.result_type, node.result_name});
@ -492,14 +492,13 @@ void ScopeStack::addArrayJoin(const std::string & source_name, std::string resul
void ScopeStack::addFunction(
const FunctionOverloadResolverPtr & function,
const Names & argument_names,
std::string result_name,
bool compile_expressions)
std::string result_name)
{
size_t level = 0;
for (const auto & argument : argument_names)
level = std::max(level, getColumnLevel(argument));
const auto & node = stack[level].actions->addFunction(function, argument_names, std::move(result_name), compile_expressions);
const auto & node = stack[level].actions->addFunction(function, argument_names, std::move(result_name), context);
for (size_t j = level + 1; j < stack.size(); ++j)
stack[j].actions->addInput({node.column, node.result_type, node.result_name});
@ -746,7 +745,7 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
auto it = index.find(child_column_name);
if (it != index.end())
{
argument_types.push_back(it->second->result_type);
argument_types.push_back((*it)->result_type);
argument_names.push_back(child_column_name);
}
else
@ -792,10 +791,12 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
data.actions_stack.pushLevel(lambda_arguments);
visit(lambda->arguments->children.at(1), data);
auto lambda_dag = data.actions_stack.popLevel();
auto lambda_actions = lambda_dag->buildExpressions(data.context);
String result_name = lambda->arguments->children.at(1)->getColumnName();
lambda_actions->finalize(Names(1, result_name));
lambda_dag->removeUnusedActions(Names(1, result_name));
auto lambda_actions = lambda_dag->buildExpressions();
DataTypePtr result_type = lambda_actions->getSampleBlock().getByName(result_name).type;
Names captured;
@ -853,7 +854,7 @@ void ActionsMatcher::visit(const ASTLiteral & literal, const ASTPtr & /* ast */,
auto it = index.find(default_name);
if (it != index.end())
existing_column = it->second;
existing_column = *it;
/*
* To approximate CSE, bind all identical literals to a single temporary
@ -964,7 +965,7 @@ SetPtr ActionsMatcher::makeSet(const ASTFunction & node, Data & data, bool no_su
{
const auto & last_actions = data.actions_stack.getLastActions();
const auto & index = last_actions.getIndex();
if (index.count(left_in_operand->getColumnName()) != 0)
if (index.contains(left_in_operand->getColumnName()))
/// An explicit enumeration of values in parentheses.
return makeExplicitSet(&node, last_actions, false, data.context, data.set_size_limit, data.prepared_sets);
else

View File

@ -12,7 +12,6 @@ namespace DB
class Context;
class ASTFunction;
struct ExpressionAction;
class ExpressionActions;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
@ -83,12 +82,11 @@ struct ScopeStack
void addColumn(ColumnWithTypeAndName column);
void addAlias(const std::string & name, std::string alias);
void addArrayJoin(const std::string & source_name, std::string result_name, std::string unique_column_name);
void addArrayJoin(const std::string & source_name, std::string result_name);
void addFunction(
const FunctionOverloadResolverPtr & function,
const Names & argument_names,
std::string result_name,
bool compile_expressions);
std::string result_name);
ActionsDAGPtr popLevel();
@ -147,15 +145,14 @@ public:
void addArrayJoin(const std::string & source_name, std::string result_name)
{
actions_stack.addArrayJoin(source_name, std::move(result_name), getUniqueName("_array_join_" + source_name));
actions_stack.addArrayJoin(source_name, std::move(result_name));
}
void addFunction(const FunctionOverloadResolverPtr & function,
const Names & argument_names,
std::string result_name)
{
actions_stack.addFunction(function, argument_names, std::move(result_name),
context.getSettingsRef().compile_expressions);
actions_stack.addFunction(function, argument_names, std::move(result_name));
}
ActionsDAGPtr getActions()

View File

@ -112,7 +112,22 @@ Block Aggregator::Params::getHeader(
{
Block res;
if (src_header)
if (intermediate_header)
{
res = intermediate_header.cloneEmpty();
if (final)
{
for (const auto & aggregate : aggregates)
{
auto & elem = res.getByName(aggregate.column_name);
elem.type = aggregate.function->getReturnType();
elem.column = elem.type->createColumn();
}
}
}
else
{
for (const auto & key : keys)
res.insert(src_header.safeGetByPosition(key).cloneEmpty());
@ -133,21 +148,6 @@ Block Aggregator::Params::getHeader(
res.insert({ type, aggregate.column_name });
}
}
else if (intermediate_header)
{
res = intermediate_header.cloneEmpty();
if (final)
{
for (const auto & aggregate : aggregates)
{
auto & elem = res.getByName(aggregate.column_name);
elem.type = aggregate.function->getReturnType();
elem.column = elem.type->createColumn();
}
}
}
return materializeBlock(res);
}

File diff suppressed because it is too large Load Diff

View File

@ -4,6 +4,7 @@
#include <Core/ColumnWithTypeAndName.h>
#include <Core/Names.h>
#include <Core/Settings.h>
#include <Core/ColumnNumbers.h>
#include <Common/SipHash.h>
#include <Common/UInt128.h>
#include <unordered_map>
@ -11,6 +12,11 @@
#include <Parsers/ASTTablesInSelectQuery.h>
#include <DataTypes/DataTypeArray.h>
#include <boost/multi_index_container.hpp>
#include <boost/multi_index/sequenced_index.hpp>
#include <boost/multi_index/hashed_index.hpp>
#include <boost/multi_index/identity.hpp>
#include <variant>
#if !defined(ARCADIA_BUILD)
@ -49,97 +55,12 @@ class CompiledExpressionCache;
class ArrayJoinAction;
using ArrayJoinActionPtr = std::shared_ptr<ArrayJoinAction>;
/** Action on the block.
*/
struct ExpressionAction
{
private:
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
public:
enum Type
{
ADD_COLUMN,
REMOVE_COLUMN,
COPY_COLUMN,
APPLY_FUNCTION,
/// Replaces the source column with array into column with elements.
/// Duplicates the values in the remaining columns by the number of elements in the arrays.
/// Source column is removed from block.
ARRAY_JOIN,
/// Reorder and rename the columns, delete the extra ones. The same column names are allowed in the result.
PROJECT,
/// Add columns with alias names. This columns are the same as non-aliased. PROJECT columns if you need to modify them.
ADD_ALIASES,
};
Type type{};
/// For ADD/REMOVE/ARRAY_JOIN/COPY_COLUMN.
std::string source_name;
std::string result_name;
DataTypePtr result_type;
/// If COPY_COLUMN can replace the result column.
bool can_replace = false;
/// For ADD_COLUMN.
ColumnPtr added_column;
/// For APPLY_FUNCTION.
/// OverloadResolver is used before action was added to ExpressionActions (when we don't know types of arguments).
FunctionOverloadResolverPtr function_builder;
/// Can be used after action was added to ExpressionActions if we want to get function signature or properties like monotonicity.
FunctionBasePtr function_base;
/// Prepared function which is used in function execution.
ExecutableFunctionPtr function;
Names argument_names;
bool is_function_compiled = false;
/// For JOIN
std::shared_ptr<const TableJoin> table_join;
JoinPtr join;
/// For PROJECT.
NamesWithAliases projection;
/// If result_name_ == "", as name "function_name(arguments separated by commas) is used".
static ExpressionAction applyFunction(
const FunctionOverloadResolverPtr & function_, const std::vector<std::string> & argument_names_, std::string result_name_ = "");
static ExpressionAction addColumn(const ColumnWithTypeAndName & added_column_);
static ExpressionAction removeColumn(const std::string & removed_name);
static ExpressionAction copyColumn(const std::string & from_name, const std::string & to_name, bool can_replace = false);
static ExpressionAction project(const NamesWithAliases & projected_columns_);
static ExpressionAction project(const Names & projected_columns_);
static ExpressionAction addAliases(const NamesWithAliases & aliased_columns_);
static ExpressionAction arrayJoin(std::string source_name, std::string result_name);
/// Which columns necessary to perform this action.
Names getNeededColumns() const;
std::string toString() const;
bool operator==(const ExpressionAction & other) const;
struct ActionHash
{
UInt128 operator()(const ExpressionAction & action) const;
};
private:
friend class ExpressionActions;
void prepare(Block & sample_block, const Settings & settings, NameSet & names_not_for_constant_folding);
void execute(Block & block, bool dry_run) const;
};
class ExpressionActions;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
class ActionsDAG;
using ActionsDAGPtr = std::shared_ptr<ActionsDAG>;
class ActionsDAG
{
public:
@ -160,16 +81,12 @@ public:
struct Node
{
std::vector<Node *> children;
/// This field is filled if current node is replaced by existing node with the same name.
Node * renaming_parent = nullptr;
Type type;
std::string result_name;
DataTypePtr result_type;
std::string unique_column_name_for_array_join;
FunctionOverloadResolverPtr function_builder;
/// Can be used after action was added to ExpressionActions if we want to get function signature or properties like monotonicity.
FunctionBasePtr function_base;
@ -185,12 +102,75 @@ public:
bool allow_constant_folding = true;
};
using Index = std::unordered_map<std::string_view, Node *>;
/// Index of DAG nodes by result name.
/// Keeps nodes in insertion order (list) and provides lookup by name (map).
/// Map keys are string_views; they do not own the name they refer to.
class Index
{
public:
    /// Returns a reference to the node pointer stored for `key`.
    /// If there is no entry for `key`, a new list element holding nullptr is appended
    /// and registered in the map under `key`.
    Node *& operator[](std::string_view key)
    {
        auto res = map.emplace(key, list.end());
        if (res.second)
            res.first->second = list.emplace(list.end(), nullptr);

        return *res.first->second;
    }

    /// Exchanges content with another index.
    void swap(Index & other)
    {
        list.swap(other.list);
        map.swap(other.map);
    }

    /// Number of nodes in the list (may exceed the number of distinct names, see insert()).
    auto size() const { return list.size(); }

    /// True if there is a map entry for `key`.
    bool contains(std::string_view key) const { return map.count(key) != 0; }

    std::list<Node *>::iterator begin() { return list.begin(); }
    std::list<Node *>::iterator end() { return list.end(); }
    std::list<Node *>::const_iterator begin() const { return list.begin(); }
    std::list<Node *>::const_iterator end() const { return list.end(); }

    /// Returns an iterator to the list element registered for `key`, or end() if there is none.
    std::list<Node *>::const_iterator find(std::string_view key) const
    {
        auto it = map.find(key);
        if (it == map.end())
            return list.end();

        return it->second;
    }

    /// Insert method doesn't check if the map already has a node with the same name.
    /// If a node with the same name exists, it is removed from the map, but not from the list.
    /// It is expected and used for project(), when result may have several columns with the same name.
    void insert(Node * node) { map[node->result_name] = list.emplace(list.end(), node); }

    /// Removes the entry registered under node->result_name from both the list and the map.
    /// Does nothing if there is no such entry.
    /// NOTE(review): the lookup is by name only; if several nodes share a name (see insert()),
    /// the entry removed is the one currently registered in the map — confirm this is intended.
    void remove(Node * node)
    {
        auto it = map.find(node->result_name);
        /// Nothing to remove. (The check used to be inverted, which skipped removal
        /// for existing entries and dereferenced map.end() for missing ones.)
        if (it == map.end())
            return;

        list.erase(it->second);
        map.erase(it);
    }

private:
    std::list<Node *> list;
    std::unordered_map<std::string_view, std::list<Node *>::iterator> map;
};
using Nodes = std::list<Node>;
private:
std::list<Node> nodes;
Nodes nodes;
Index index;
size_t max_temporary_columns = 0;
size_t max_temporary_non_const_columns = 0;
#if USE_EMBEDDED_COMPILER
std::shared_ptr<CompiledExpressionCache> compilation_cache;
#endif
bool project_input = false;
bool projected_output = false;
public:
ActionsDAG() = default;
ActionsDAG(const ActionsDAG &) = delete;
@ -198,58 +178,108 @@ public:
explicit ActionsDAG(const NamesAndTypesList & inputs);
explicit ActionsDAG(const ColumnsWithTypeAndName & inputs);
const Nodes & getNodes() const { return nodes; }
const Index & getIndex() const { return index; }
NamesAndTypesList getRequiredColumns() const;
ColumnsWithTypeAndName getResultColumns() const;
NamesAndTypesList getNamesAndTypesList() const;
Names getNames() const;
std::string dumpNames() const;
std::string dump() const;
std::string dumpDAG() const;
const Node & addInput(std::string name, DataTypePtr type);
const Node & addInput(ColumnWithTypeAndName column);
const Node & addColumn(ColumnWithTypeAndName column);
const Node & addAlias(const std::string & name, std::string alias, bool can_replace = false);
const Node & addArrayJoin(const std::string & source_name, std::string result_name, std::string unique_column_name);
const Node & addArrayJoin(const std::string & source_name, std::string result_name);
const Node & addFunction(
const FunctionOverloadResolverPtr & function,
const Names & argument_names,
std::string result_name,
bool compile_expressions);
const Context & context);
ExpressionActionsPtr buildExpressions(const Context & context);
/// Call addAlias several times.
void addAliases(const NamesWithAliases & aliases);
/// Adds alias actions and removes unused columns from index.
void project(const NamesWithAliases & projection);
/// Removes column from index.
void removeColumn(const std::string & column_name);
/// If column is not in index, try to find it in nodes and insert back into index.
bool tryRestoreColumn(const std::string & column_name);
void projectInput() { project_input = true; }
void removeUnusedActions(const Names & required_names);
ExpressionActionsPtr buildExpressions();
/// Splits actions into two parts. Returned half may be swapped with ARRAY JOIN.
/// Returns nullptr if no actions may be moved before ARRAY JOIN.
ActionsDAGPtr splitActionsBeforeArrayJoin(const NameSet & array_joined_columns);
bool hasArrayJoin() const;
bool empty() const;
bool projectedOutput() const { return projected_output; }
ActionsDAGPtr clone() const;
private:
Node & addNode(Node node, bool can_replace = false);
Node & getNode(const std::string & name);
/// Creates a new DAG without nodes or index entries that inherits the limits
/// on temporary columns (and the compilation cache, when the embedded compiler is enabled).
ActionsDAGPtr cloneEmpty() const
{
    auto empty_dag = std::make_shared<ActionsDAG>();

    empty_dag->max_temporary_columns = max_temporary_columns;
    empty_dag->max_temporary_non_const_columns = max_temporary_non_const_columns;

#if USE_EMBEDDED_COMPILER
    empty_dag->compilation_cache = compilation_cache;
#endif

    return empty_dag;
}
ExpressionActionsPtr linearizeActions() const;
void removeUnusedActions(const std::vector<Node *> & required_nodes);
void addAliases(const NamesWithAliases & aliases, std::vector<Node *> & result_nodes);
};
using ActionsDAGPtr = std::shared_ptr<ActionsDAG>;
/** Contains a sequence of actions on the block.
*/
class ExpressionActions
{
private:
public:
using Node = ActionsDAG::Node;
using Index = ActionsDAG::Index;
struct Argument
{
size_t pos;
bool remove;
};
using Arguments = std::vector<Argument>;
struct Action
{
Node * node;
ColumnNumbers arguments;
/// Columns which will be removed after actions is executed.
/// It is always a subset of arguments.
ColumnNumbers to_remove;
const Node * node;
Arguments arguments;
size_t result_position;
bool is_used_in_result;
std::string toString() const;
};
using Actions = std::vector<Action>;
private:
struct ExecutionContext
{
ColumnsWithTypeAndName & input_columns;
ColumnsWithTypeAndName & inputs;
ColumnsWithTypeAndName columns;
std::vector<ssize_t> inputs_pos;
size_t num_rows;
};
@ -258,19 +288,26 @@ private:
size_t num_columns;
NamesAndTypesList required_columns;
ColumnNumbers result_positions;
Block sample_block;
/// This flag means that all columns except input will be removed from block before execution.
bool project_input = false;
size_t max_temporary_non_const_columns = 0;
friend class ActionsDAG;
public:
~ExpressionActions();
ExpressionActions() = default;
ExpressionActions(const ExpressionActions &) = delete;
ExpressionActions & operator=(const ExpressionActions &) = delete;
ExpressionActions(const ExpressionActions & other) = default;
const Actions & getActions() const { return actions; }
/// Adds to the beginning the removal of all extra columns.
void prependProjectInput();
/// Splits actions into two parts. Returned half may be swapped with ARRAY JOIN.
/// Returns nullptr if no actions may be moved before ARRAY JOIN.
ExpressionActionsPtr splitActionsBeforeArrayJoin(const NameSet & array_joined_columns);
void projectInput() { project_input = true; }
/// - Adds actions to delete all but the specified columns.
/// - Removes unused input columns.
@ -286,6 +323,7 @@ public:
/// Execute the expression on the block. The block must contain all the columns returned by getRequiredColumns.
void execute(Block & block, bool dry_run = false) const;
void execute(Block & block, size_t & num_rows, bool dry_run = false) const;
bool hasArrayJoin() const;
@ -296,18 +334,13 @@ public:
static std::string getSmallestColumn(const NamesAndTypesList & columns);
const Settings & getSettings() const { return settings; }
/// Check if column is always zero. True if it's definite, false if we can't say for sure.
/// Call it only after subqueries for sets were executed.
bool checkColumnIsAlwaysFalse(const String & column_name) const;
private:
ExpressionActionsPtr clone() const;
Settings settings;
#if USE_EMBEDDED_COMPILER
std::shared_ptr<CompiledExpressionCache> compilation_cache;
#endif
private:
void checkLimits(ExecutionContext & execution_context) const;
@ -343,8 +376,8 @@ struct ExpressionActionsChain
/// If not empty, has the same size with required_output; is filled in finalize().
std::vector<bool> can_remove_required_output;
virtual const NamesAndTypesList & getRequiredColumns() const = 0;
virtual const ColumnsWithTypeAndName & getResultColumns() const = 0;
virtual NamesAndTypesList getRequiredColumns() const = 0;
virtual ColumnsWithTypeAndName getResultColumns() const = 0;
/// Remove unused result and update required columns
virtual void finalize(const Names & required_output_) = 0;
/// Add projections to expression
@ -354,43 +387,42 @@ struct ExpressionActionsChain
/// Only for ExpressionActionsStep
ActionsDAGPtr & actions();
const ActionsDAGPtr & actions() const;
ExpressionActionsPtr getExpression() const;
};
struct ExpressionActionsStep : public Step
{
ActionsDAGPtr actions_dag;
ExpressionActionsPtr actions;
ActionsDAGPtr actions;
explicit ExpressionActionsStep(ActionsDAGPtr actions_, Names required_output_ = Names())
: Step(std::move(required_output_))
, actions_dag(std::move(actions_))
, actions(std::move(actions_))
{
}
const NamesAndTypesList & getRequiredColumns() const override
NamesAndTypesList getRequiredColumns() const override
{
return actions->getRequiredColumnsWithTypes();
return actions->getRequiredColumns();
}
const ColumnsWithTypeAndName & getResultColumns() const override
ColumnsWithTypeAndName getResultColumns() const override
{
return actions->getSampleBlock().getColumnsWithTypeAndName();
return actions->getResultColumns();
}
void finalize(const Names & required_output_) override
{
actions->finalize(required_output_);
if (!actions->projectedOutput())
actions->removeUnusedActions(required_output_);
}
void prependProjectInput() const override
{
actions->prependProjectInput();
actions->projectInput();
}
std::string dump() const override
{
return actions->dumpActions();
return actions->dump();
}
};
@ -402,8 +434,8 @@ struct ExpressionActionsChain
ArrayJoinStep(ArrayJoinActionPtr array_join_, ColumnsWithTypeAndName required_columns_);
const NamesAndTypesList & getRequiredColumns() const override { return required_columns; }
const ColumnsWithTypeAndName & getResultColumns() const override { return result_columns; }
NamesAndTypesList getRequiredColumns() const override { return required_columns; }
ColumnsWithTypeAndName getResultColumns() const override { return result_columns; }
void finalize(const Names & required_output_) override;
void prependProjectInput() const override {} /// TODO: remove unused columns before ARRAY JOIN ?
std::string dump() const override { return "ARRAY JOIN"; }
@ -418,8 +450,8 @@ struct ExpressionActionsChain
ColumnsWithTypeAndName result_columns;
JoinStep(std::shared_ptr<TableJoin> analyzed_join_, JoinPtr join_, ColumnsWithTypeAndName required_columns_);
const NamesAndTypesList & getRequiredColumns() const override { return required_columns; }
const ColumnsWithTypeAndName & getResultColumns() const override { return result_columns; }
NamesAndTypesList getRequiredColumns() const override { return required_columns; }
ColumnsWithTypeAndName getResultColumns() const override { return result_columns; }
void finalize(const Names & required_output_) override;
void prependProjectInput() const override {} /// TODO: remove unused columns before JOIN ?
std::string dump() const override { return "JOIN"; }
@ -431,7 +463,7 @@ struct ExpressionActionsChain
const Context & context;
Steps steps;
void addStep();
void addStep(NameSet non_constant_inputs = {});
void finalize();
@ -440,7 +472,7 @@ struct ExpressionActionsChain
steps.clear();
}
ExpressionActionsPtr getLastActions(bool allow_empty = false)
ActionsDAGPtr getLastActions(bool allow_empty = false)
{
if (steps.empty())
{
@ -449,9 +481,7 @@ struct ExpressionActionsChain
throw Exception("Empty ExpressionActionsChain", ErrorCodes::LOGICAL_ERROR);
}
auto * step = typeid_cast<ExpressionActionsStep *>(steps.back().get());
step->actions = step->actions_dag->buildExpressions(context);
return step->actions;
return typeid_cast<ExpressionActionsStep *>(steps.back().get())->actions;
}
Step & getLastStep()

View File

@ -70,16 +70,16 @@ namespace
/// Check if there is an ignore function. It's used for disabling constant folding in query
/// predicates because some performance tests use ignore function as a non-optimize guard.
bool allowEarlyConstantFolding(const ExpressionActions & actions, const Settings & settings)
bool allowEarlyConstantFolding(const ActionsDAG & actions, const Settings & settings)
{
if (!settings.enable_early_constant_folding)
return false;
for (const auto & action : actions.getActions())
for (const auto & node : actions.getNodes())
{
if (action.type == action.APPLY_FUNCTION && action.function_base)
if (node.type == ActionsDAG::Type::FUNCTION && node.function_base)
{
auto name = action.function_base->getName();
auto name = node.function_base->getName();
if (name == "ignore")
return false;
}
@ -234,7 +234,7 @@ void ExpressionAnalyzer::analyzeAggregation()
if (it == index.end())
throw Exception("Unknown identifier (in GROUP BY): " + column_name, ErrorCodes::UNKNOWN_IDENTIFIER);
const auto & node = it->second;
const auto & node = *it;
/// Constant expressions have non-null column pointer at this stage.
if (node->column && isColumnConst(*node->column))
@ -382,7 +382,7 @@ void SelectQueryExpressionAnalyzer::makeSetsForIndex(const ASTPtr & node)
auto temp_actions = std::make_shared<ActionsDAG>(columns_after_join);
getRootActions(left_in_operand, true, temp_actions);
if (temp_actions->getIndex().count(left_in_operand->getColumnName()) != 0)
if (temp_actions->getIndex().contains(left_in_operand->getColumnName()))
makeExplicitSet(func, *temp_actions, true, context,
settings.size_limits_for_set, prepared_sets);
}
@ -434,7 +434,7 @@ bool ExpressionAnalyzer::makeAggregateDescriptions(ActionsDAGPtr & actions)
if (it == index.end())
throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER, "Unknown identifier (in aggregate function '{}'): {}", node->name, name);
types[i] = it->second->result_type;
types[i] = (*it)->result_type;
aggregate.argument_names[i] = name;
}
@ -481,7 +481,7 @@ ArrayJoinActionPtr ExpressionAnalyzer::addMultipleArrayJoinAction(ActionsDAGPtr
return std::make_shared<ArrayJoinAction>(result_columns, array_join_is_left, context);
}
ArrayJoinActionPtr SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & chain, ExpressionActionsPtr & before_array_join, bool only_types)
ArrayJoinActionPtr SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & chain, ActionsDAGPtr & before_array_join, bool only_types)
{
const auto * select_query = getSelectQuery();
@ -637,11 +637,11 @@ JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin(const ASTTablesInSelectQuer
return subquery_for_join.join;
}
ExpressionActionsPtr SelectQueryExpressionAnalyzer::appendPrewhere(
ActionsDAGPtr SelectQueryExpressionAnalyzer::appendPrewhere(
ExpressionActionsChain & chain, bool only_types, const Names & additional_required_columns)
{
const auto * select_query = getSelectQuery();
ExpressionActionsPtr prewhere_actions;
ActionsDAGPtr prewhere_actions;
if (!select_query->prewhere())
return prewhere_actions;
@ -652,7 +652,7 @@ ExpressionActionsPtr SelectQueryExpressionAnalyzer::appendPrewhere(
step.required_output.push_back(prewhere_column_name);
step.can_remove_required_output.push_back(true);
auto filter_type = step.actions()->getIndex().find(prewhere_column_name)->second->result_type;
auto filter_type = (*step.actions()->getIndex().find(prewhere_column_name))->result_type;
if (!filter_type->canBeUsedInBooleanContext())
throw Exception("Invalid type for filter in PREWHERE: " + filter_type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_COLUMN_FOR_FILTER);
@ -661,8 +661,8 @@ ExpressionActionsPtr SelectQueryExpressionAnalyzer::appendPrewhere(
/// Remove unused source_columns from prewhere actions.
auto tmp_actions_dag = std::make_shared<ActionsDAG>(sourceColumns());
getRootActions(select_query->prewhere(), only_types, tmp_actions_dag);
auto tmp_actions = tmp_actions_dag->buildExpressions(context);
tmp_actions->finalize({prewhere_column_name});
tmp_actions_dag->removeUnusedActions({prewhere_column_name});
auto tmp_actions = tmp_actions_dag->buildExpressions();
auto required_columns = tmp_actions->getRequiredColumns();
NameSet required_source_columns(required_columns.begin(), required_columns.end());
@ -686,7 +686,7 @@ ExpressionActionsPtr SelectQueryExpressionAnalyzer::appendPrewhere(
Names required_output(name_set.begin(), name_set.end());
prewhere_actions = chain.getLastActions();
prewhere_actions->finalize(required_output);
prewhere_actions->removeUnusedActions(required_output);
}
{
@ -697,11 +697,14 @@ ExpressionActionsPtr SelectQueryExpressionAnalyzer::appendPrewhere(
/// 2. Store side columns which were calculated during prewhere actions execution if they are used.
/// Example: select F(A) prewhere F(A) > 0. F(A) can be saved from prewhere step.
/// 3. Check if we can remove filter column at prewhere step. If we can, action will store single REMOVE_COLUMN.
ColumnsWithTypeAndName columns = prewhere_actions->getSampleBlock().getColumnsWithTypeAndName();
ColumnsWithTypeAndName columns = prewhere_actions->getResultColumns();
auto required_columns = prewhere_actions->getRequiredColumns();
NameSet prewhere_input_names(required_columns.begin(), required_columns.end());
NameSet prewhere_input_names;
NameSet unused_source_columns;
for (const auto & col : required_columns)
prewhere_input_names.insert(col.name);
for (const auto & column : sourceColumns())
{
if (prewhere_input_names.count(column.name) == 0)
@ -721,7 +724,7 @@ ExpressionActionsPtr SelectQueryExpressionAnalyzer::appendPrewhere(
return prewhere_actions;
}
void SelectQueryExpressionAnalyzer::appendPreliminaryFilter(ExpressionActionsChain & chain, ExpressionActionsPtr actions, String column_name)
void SelectQueryExpressionAnalyzer::appendPreliminaryFilter(ExpressionActionsChain & chain, ActionsDAGPtr actions, String column_name)
{
ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns());
@ -749,7 +752,7 @@ bool SelectQueryExpressionAnalyzer::appendWhere(ExpressionActionsChain & chain,
getRootActions(select_query->where(), only_types, step.actions());
auto filter_type = step.actions()->getIndex().find(where_column_name)->second->result_type;
auto filter_type = (*step.actions()->getIndex().find(where_column_name))->result_type;
if (!filter_type->canBeUsedInBooleanContext())
throw Exception("Invalid type for filter in WHERE: " + filter_type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_COLUMN_FOR_FILTER);
@ -780,7 +783,7 @@ bool SelectQueryExpressionAnalyzer::appendGroupBy(ExpressionActionsChain & chain
{
auto actions_dag = std::make_shared<ActionsDAG>(columns_after_join);
getRootActions(child, only_types, actions_dag);
group_by_elements_actions.emplace_back(actions_dag->buildExpressions(context));
group_by_elements_actions.emplace_back(actions_dag->buildExpressions());
}
}
@ -842,18 +845,24 @@ void SelectQueryExpressionAnalyzer::appendSelect(ExpressionActionsChain & chain,
step.required_output.push_back(child->getColumnName());
}
bool SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order,
ActionsDAGPtr SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order,
ManyExpressionActions & order_by_elements_actions)
{
const auto * select_query = getSelectQuery();
if (!select_query->orderBy())
return false;
{
auto actions = chain.getLastActions();
chain.addStep();
return actions;
}
ExpressionActionsChain::Step & step = chain.lastStep(aggregated_columns);
getRootActions(select_query->orderBy(), only_types, step.actions());
bool with_fill = false;
NameSet order_by_keys;
for (auto & child : select_query->orderBy()->children)
{
const auto * ast = child->as<ASTOrderByElement>();
@ -861,6 +870,9 @@ bool SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain
throw Exception("Bad order expression AST", ErrorCodes::UNKNOWN_TYPE_OF_AST_NODE);
ASTPtr order_expression = ast->children.at(0);
step.required_output.push_back(order_expression->getColumnName());
if (ast->with_fill)
with_fill = true;
}
if (optimize_read_in_order)
@ -869,10 +881,21 @@ bool SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain
{
auto actions_dag = std::make_shared<ActionsDAG>(columns_after_join);
getRootActions(child, only_types, actions_dag);
order_by_elements_actions.emplace_back(actions_dag->buildExpressions(context));
order_by_elements_actions.emplace_back(actions_dag->buildExpressions());
}
}
return true;
NameSet non_constant_inputs;
if (with_fill)
{
for (const auto & column : step.getResultColumns())
if (!order_by_keys.count(column.name))
non_constant_inputs.insert(column.name);
}
auto actions = chain.getLastActions();
chain.addStep(non_constant_inputs);
return actions;
}
bool SelectQueryExpressionAnalyzer::appendLimitBy(ExpressionActionsChain & chain, bool only_types)
@ -903,7 +926,7 @@ bool SelectQueryExpressionAnalyzer::appendLimitBy(ExpressionActionsChain & chain
return true;
}
ExpressionActionsPtr SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActionsChain & chain) const
ActionsDAGPtr SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActionsChain & chain) const
{
const auto * select_query = getSelectQuery();
@ -950,7 +973,7 @@ ExpressionActionsPtr SelectQueryExpressionAnalyzer::appendProjectResult(Expressi
}
auto actions = chain.getLastActions();
actions->add(ExpressionAction::project(result_columns));
actions->project(result_columns);
return actions;
}
@ -963,7 +986,7 @@ void ExpressionAnalyzer::appendExpression(ExpressionActionsChain & chain, const
}
ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool project_result)
ActionsDAGPtr ExpressionAnalyzer::getActionsDAG(bool add_aliases, bool project_result)
{
auto actions_dag = std::make_shared<ActionsDAG>(aggregated_columns);
NamesWithAliases result_columns;
@ -989,14 +1012,12 @@ ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool proje
getRootActions(ast, false, actions_dag);
}
auto actions = actions_dag->buildExpressions(context);
if (add_aliases)
{
if (project_result)
actions->add(ExpressionAction::project(result_columns));
actions_dag->project(result_columns);
else
actions->add(ExpressionAction::addAliases(result_columns));
actions_dag->addAliases(result_columns);
}
if (!(add_aliases && project_result))
@ -1006,9 +1027,13 @@ ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool proje
result_names.push_back(column_name_type.name);
}
actions->finalize(result_names);
actions_dag->removeUnusedActions(result_names);
return actions_dag;
}
return actions;
/// Convenience wrapper around getActionsDAG(): forwards both flags unchanged and returns
/// the result of calling buildExpressions() on the DAG it produces.
ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool project_result)
{
return getActionsDAG(add_aliases, project_result)->buildExpressions();
}
@ -1017,10 +1042,10 @@ ExpressionActionsPtr ExpressionAnalyzer::getConstActions()
auto actions = std::make_shared<ActionsDAG>(NamesAndTypesList());
getRootActions(query, true, actions, true);
return actions->buildExpressions(context);
return actions->buildExpressions();
}
ExpressionActionsPtr SelectQueryExpressionAnalyzer::simpleSelectActions()
ActionsDAGPtr SelectQueryExpressionAnalyzer::simpleSelectActions()
{
ExpressionActionsChain new_chain(context);
appendSelect(new_chain, false);
@ -1061,7 +1086,7 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
if (!finalized)
{
finalize(chain, context, where_step_num);
finalize(chain, where_step_num);
finalized = true;
}
@ -1107,7 +1132,7 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
Block before_prewhere_sample = source_header;
if (sanitizeBlock(before_prewhere_sample))
{
prewhere_info->prewhere_actions->execute(before_prewhere_sample);
prewhere_info->prewhere_actions->buildExpressions()->execute(before_prewhere_sample);
auto & column_elem = before_prewhere_sample.getByName(query.prewhere()->getColumnName());
/// If the filter column is a constant, record it.
if (column_elem.column)
@ -1140,7 +1165,7 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
before_where_sample = source_header;
if (sanitizeBlock(before_where_sample))
{
before_where->execute(before_where_sample);
before_where->buildExpressions()->execute(before_where_sample);
auto & column_elem = before_where_sample.getByName(query.where()->getColumnName());
/// If the filter column is a constant, record it.
if (column_elem.column)
@ -1188,10 +1213,12 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
/// If there is aggregation, we execute expressions in SELECT and ORDER BY on the initiating server, otherwise on the source servers.
query_analyzer.appendSelect(chain, only_types || (need_aggregate ? !second_stage : !first_stage));
selected_columns = chain.getLastStep().required_output;
has_order_by = query_analyzer.appendOrderBy(chain, only_types || (need_aggregate ? !second_stage : !first_stage),
optimize_read_in_order, order_by_elements_actions);
before_order_and_select = chain.getLastActions();
chain.addStep();
has_order_by = query.orderBy() != nullptr;
before_order_and_select = query_analyzer.appendOrderBy(
chain,
only_types || (need_aggregate ? !second_stage : !first_stage),
optimize_read_in_order,
order_by_elements_actions);
if (query_analyzer.appendLimitBy(chain, only_types || !second_stage))
{
@ -1210,28 +1237,35 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
checkActions();
}
void ExpressionAnalysisResult::finalize(const ExpressionActionsChain & chain, const Context & context_, size_t where_step_num)
void ExpressionAnalysisResult::finalize(const ExpressionActionsChain & chain, size_t where_step_num)
{
if (hasPrewhere())
{
const ExpressionActionsChain::Step & step = *chain.steps.at(0);
prewhere_info->remove_prewhere_column = step.can_remove_required_output.at(0);
Names columns_to_remove;
NameSet columns_to_remove;
for (size_t i = 1; i < step.required_output.size(); ++i)
{
if (step.can_remove_required_output[i])
columns_to_remove.push_back(step.required_output[i]);
columns_to_remove.insert(step.required_output[i]);
}
if (!columns_to_remove.empty())
{
auto columns = prewhere_info->prewhere_actions->getSampleBlock().getNamesAndTypesList();
ExpressionActionsPtr actions = std::make_shared<ExpressionActions>(columns, context_);
for (const auto & column : columns_to_remove)
actions->add(ExpressionAction::removeColumn(column));
auto columns = prewhere_info->prewhere_actions->getResultColumns();
prewhere_info->remove_columns_actions = std::move(actions);
auto remove_actions = std::make_shared<ActionsDAG>();
for (const auto & column : columns)
{
if (columns_to_remove.count(column.name))
{
remove_actions->addInput(column);
remove_actions->removeColumn(column.name);
}
}
prewhere_info->remove_columns_actions = std::move(remove_actions);
}
columns_to_remove_after_prewhere = std::move(columns_to_remove);
@ -1248,11 +1282,11 @@ void ExpressionAnalysisResult::finalize(const ExpressionActionsChain & chain, co
/// For each stage that is present (hasFilter(), hasWhere(), hasHaving()), asks that stage's
/// actions to project their input.
/// NOTE(review): every call appears here in two spellings back to back, prependProjectInput()
/// and projectInput(). Presumably these are the two sides of a rename and not two separate
/// calls (only the first of each pair would be guarded by the `if`) — confirm.
void ExpressionAnalysisResult::removeExtraColumns() const
{
/// Actions of filter_info.
if (hasFilter())
filter_info->actions->prependProjectInput();
filter_info->actions->projectInput();
/// Actions executed before WHERE.
if (hasWhere())
before_where->prependProjectInput();
before_where->projectInput();
/// Actions executed before HAVING.
if (hasHaving())
before_having->prependProjectInput();
before_having->projectInput();
}
void ExpressionAnalysisResult::checkActions() const
@ -1260,11 +1294,11 @@ void ExpressionAnalysisResult::checkActions() const
/// Check that PREWHERE doesn't contain unusual actions. Unusual actions are those that can change the number of rows.
if (hasPrewhere())
{
auto check_actions = [](const ExpressionActionsPtr & actions)
auto check_actions = [](const ActionsDAGPtr & actions)
{
if (actions)
for (const auto & action : actions->getActions())
if (action.type == ExpressionAction::Type::ARRAY_JOIN)
for (const auto & node : actions->getNodes())
if (node.type == ActionsDAG::Type::ARRAY_JOIN)
throw Exception("PREWHERE cannot contain ARRAY JOIN action", ErrorCodes::ILLEGAL_PREWHERE);
};

View File

@ -102,6 +102,7 @@ public:
/// If add_aliases, only the calculated values in the desired order and add aliases.
/// If also project_result, then only aliases remain in the output block.
/// Otherwise, only temporary columns will be deleted from the block.
ActionsDAGPtr getActionsDAG(bool add_aliases, bool project_result = true);
ExpressionActionsPtr getActions(bool add_aliases, bool project_result = true);
/// Actions that can be performed on an empty block: adding constants and applying functions that depend only on constants.
@ -182,22 +183,22 @@ struct ExpressionAnalysisResult
bool optimize_aggregation_in_order = false;
bool join_has_delayed_stream = false;
ExpressionActionsPtr before_array_join;
ActionsDAGPtr before_array_join;
ArrayJoinActionPtr array_join;
ExpressionActionsPtr before_join;
ActionsDAGPtr before_join;
JoinPtr join;
ExpressionActionsPtr before_where;
ExpressionActionsPtr before_aggregation;
ExpressionActionsPtr before_having;
ExpressionActionsPtr before_order_and_select;
ExpressionActionsPtr before_limit_by;
ExpressionActionsPtr final_projection;
ActionsDAGPtr before_where;
ActionsDAGPtr before_aggregation;
ActionsDAGPtr before_having;
ActionsDAGPtr before_order_and_select;
ActionsDAGPtr before_limit_by;
ActionsDAGPtr final_projection;
/// Columns from the SELECT list, before renaming them to aliases.
Names selected_columns;
/// Columns will be removed after prewhere actions execution.
Names columns_to_remove_after_prewhere;
NameSet columns_to_remove_after_prewhere;
PrewhereInfoPtr prewhere_info;
FilterInfoPtr filter_info;
@ -229,7 +230,7 @@ struct ExpressionAnalysisResult
void removeExtraColumns() const;
void checkActions() const;
void finalize(const ExpressionActionsChain & chain, const Context & context, size_t where_step_num);
void finalize(const ExpressionActionsChain & chain, size_t where_step_num);
};
/// SelectQuery specific ExpressionAnalyzer part.
@ -267,12 +268,12 @@ public:
/// Tables that will need to be sent to remote servers for distributed query processing.
const TemporaryTablesMapping & getExternalTables() const { return external_tables; }
ExpressionActionsPtr simpleSelectActions();
ActionsDAGPtr simpleSelectActions();
/// These appends are public only for tests
void appendSelect(ExpressionActionsChain & chain, bool only_types);
/// Deletes all columns except mentioned by SELECT, arranges the remaining columns and renames them to aliases.
ExpressionActionsPtr appendProjectResult(ExpressionActionsChain & chain) const;
ActionsDAGPtr appendProjectResult(ExpressionActionsChain & chain) const;
private:
StorageMetadataPtr metadata_snapshot;
@ -315,14 +316,14 @@ private:
*/
/// Before aggregation:
ArrayJoinActionPtr appendArrayJoin(ExpressionActionsChain & chain, ExpressionActionsPtr & before_array_join, bool only_types);
ArrayJoinActionPtr appendArrayJoin(ExpressionActionsChain & chain, ActionsDAGPtr & before_array_join, bool only_types);
bool appendJoinLeftKeys(ExpressionActionsChain & chain, bool only_types);
JoinPtr appendJoin(ExpressionActionsChain & chain);
/// Add preliminary rows filtration. Actions are created in other expression analyzer to prevent any possible alias injection.
void appendPreliminaryFilter(ExpressionActionsChain & chain, ExpressionActionsPtr actions, String column_name);
void appendPreliminaryFilter(ExpressionActionsChain & chain, ActionsDAGPtr actions, String column_name);
/// remove_filter is set in ExpressionActionsChain::finalize();
/// Columns in `additional_required_columns` will not be removed (they can be used for e.g. sampling or FINAL modifier).
ExpressionActionsPtr appendPrewhere(ExpressionActionsChain & chain, bool only_types, const Names & additional_required_columns);
ActionsDAGPtr appendPrewhere(ExpressionActionsChain & chain, bool only_types, const Names & additional_required_columns);
bool appendWhere(ExpressionActionsChain & chain, bool only_types);
bool appendGroupBy(ExpressionActionsChain & chain, bool only_types, bool optimize_aggregation_in_order, ManyExpressionActions &);
void appendAggregateFunctionsArguments(ExpressionActionsChain & chain, bool only_types);
@ -330,7 +331,7 @@ private:
/// After aggregation:
bool appendHaving(ExpressionActionsChain & chain, bool only_types);
/// appendSelect
bool appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order, ManyExpressionActions &);
ActionsDAGPtr appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order, ManyExpressionActions &);
bool appendLimitBy(ExpressionActionsChain & chain, bool only_types);
/// appendProjectResult
};

View File

@ -442,7 +442,7 @@ struct LLVMModuleState
};
LLVMFunction::LLVMFunction(const ExpressionActions::Actions & actions, const DB::Block & sample_block)
: name(actions.back().result_name)
: name(actions.back().node->result_name)
, module_state(std::make_unique<LLVMModuleState>())
{
LLVMContext context;
@ -452,21 +452,21 @@ LLVMFunction::LLVMFunction(const ExpressionActions::Actions & actions, const DB:
subexpressions[c.name] = subexpression(c.column, c.type);
for (const auto & action : actions)
{
const auto & names = action.argument_names;
const auto & types = action.function_base->getArgumentTypes();
const auto & children = action.node->children;
const auto & types = action.node->function_base->getArgumentTypes();
std::vector<CompilableExpression> args;
for (size_t i = 0; i < names.size(); ++i)
for (size_t i = 0; i < children.size(); ++i)
{
auto inserted = subexpressions.emplace(names[i], subexpression(arg_names.size()));
auto inserted = subexpressions.emplace(children[i]->result_name, subexpression(arg_names.size()));
if (inserted.second)
{
arg_names.push_back(names[i]);
arg_names.push_back(children[i]->result_name);
arg_types.push_back(types[i]);
}
args.push_back(inserted.first->second);
}
subexpressions[action.result_name] = subexpression(*action.function_base, std::move(args));
originals.push_back(action.function_base);
subexpressions[action.node->result_name] = subexpression(*action.node->function_base, std::move(args));
originals.push_back(action.node->function_base);
}
compileFunctionToLLVMByteCode(context, *this);
context.compileAllFunctionsToNativeCode();
@ -555,155 +555,155 @@ LLVMFunction::Monotonicity LLVMFunction::getMonotonicityForRange(const IDataType
}
/// Checks whether `function` is eligible for compilation.
/// The result type and every argument type must pass canBeNativeType();
/// if they all do, the answer is whatever the function's own isCompilable() reports.
static bool isCompilable(const IFunctionBase & function)
{
/// Reject early if the result type has no native representation.
if (!canBeNativeType(*function.getResultType()))
return false;
/// Same check for each argument type.
for (const auto & type : function.getArgumentTypes())
if (!canBeNativeType(*type))
return false;
return function.isCompilable();
}
//static bool isCompilable(const IFunctionBase & function)
//{
// if (!canBeNativeType(*function.getResultType()))
// return false;
// for (const auto & type : function.getArgumentTypes())
// if (!canBeNativeType(*type))
// return false;
// return function.isCompilable();
//}
/// For every APPLY_FUNCTION action, collects the set of "dependents" of its result column.
/// The actions are walked from the last one to the first. For each column name the function tracks
/// the uses seen so far: the index of a compilable APPLY_FUNCTION that takes the column as an argument,
/// or an empty optional for any other kind of use (see the comment on `current_dependents`).
/// `output_columns` are marked with an empty optional up front.
/// Returns a vector parallel to `actions`; entries for actions other than APPLY_FUNCTION stay empty.
static std::vector<std::unordered_set<std::optional<size_t>>> getActionsDependents(const ExpressionActions::Actions & actions, const Names & output_columns)
{
/// an empty optional is a poisoned value prohibiting the column's producer from being removed
/// (which it could be, if it was inlined into every dependent function).
std::unordered_map<std::string, std::unordered_set<std::optional<size_t>>> current_dependents;
for (const auto & name : output_columns)
current_dependents[name].emplace();
/// a snapshot of each compilable function's dependents at the time of its execution.
std::vector<std::unordered_set<std::optional<size_t>>> dependents(actions.size());
/// Reverse iteration: `i` runs from actions.size() - 1 down to 0.
for (size_t i = actions.size(); i--;)
{
switch (actions[i].type)
{
case ExpressionAction::REMOVE_COLUMN:
current_dependents.erase(actions[i].source_name);
/// poison every other column used after this point so that inlining chains do not cross it.
for (auto & dep : current_dependents)
dep.second.emplace();
break;
case ExpressionAction::PROJECT:
/// Forget everything tracked so far and poison only the columns named by the projection.
current_dependents.clear();
for (const auto & proj : actions[i].projection)
current_dependents[proj.first].emplace();
break;
case ExpressionAction::ADD_ALIASES:
/// Unlike PROJECT, other tracked columns are kept; the columns named by the projection are poisoned.
for (const auto & proj : actions[i].projection)
current_dependents[proj.first].emplace();
break;
case ExpressionAction::ADD_COLUMN:
case ExpressionAction::COPY_COLUMN:
case ExpressionAction::ARRAY_JOIN:
{
/// Every column needed by these actions is poisoned.
Names columns = actions[i].getNeededColumns();
for (const auto & column : columns)
current_dependents[column].emplace();
break;
}
case ExpressionAction::APPLY_FUNCTION:
{
/// Record who uses this function's result, as known at this point of the reverse walk.
dependents[i] = current_dependents[actions[i].result_name];
const bool compilable = isCompilable(*actions[i].function_base);
/// A compilable function registers itself (by index) as a dependent of its arguments;
/// a non-compilable one poisons them.
for (const auto & name : actions[i].argument_names)
{
if (compilable)
current_dependents[name].emplace(i);
else
current_dependents[name].emplace();
}
break;
}
}
}
return dependents;
}
/// Tries to replace chains of compilable APPLY_FUNCTION actions with a single LLVMFunction.
/// `actions` is modified in place: for an action that gets compiled, function_base, argument_names,
/// is_function_compiled and (at the end) function are overwritten. Actions that were inlined into
/// a later one are left in the list untouched.
/// `output_columns` is passed to getActionsDependents(), which marks those columns as poisoned.
/// `sample_block` is forwarded to the LLVMFunction constructor.
/// `compilation_cache` may be null; then a qualifying chain is compiled directly without a cache lookup.
/// A chain is compiled only if its hash has already been counted at least
/// `min_count_to_compile_expression` times; the counter is a function-local static.
void compileFunctions(
ExpressionActions::Actions & actions,
const Names & output_columns,
const Block & sample_block,
std::shared_ptr<CompiledExpressionCache> compilation_cache,
size_t min_count_to_compile_expression)
{
/// How many times each expression hash has been seen; accessed only under `mutex`.
static std::unordered_map<UInt128, UInt32, UInt128Hash> counter;
static std::mutex mutex;
/// Its constructor performs the LLVM native target setup; instantiated via the static object below.
struct LLVMTargetInitializer
{
LLVMTargetInitializer()
{
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
}
};
static LLVMTargetInitializer initializer;
auto dependents = getActionsDependents(actions, output_columns);
/// fused[i] accumulates the actions that would be compiled together into the function at index i:
/// first whatever earlier actions appended to it, then actions[i] itself.
std::vector<ExpressionActions::Actions> fused(actions.size());
for (size_t i = 0; i < actions.size(); ++i)
{
if (actions[i].type != ExpressionAction::APPLY_FUNCTION || !isCompilable(*actions[i].function_base))
continue;
fused[i].push_back(actions[i]);
/// `find({})` looks for the empty optional, i.e. the poisoned marker.
if (dependents[i].find({}) != dependents[i].end())
{
/// the result of compiling one function in isolation is pretty much the same as its `execute` method.
if (fused[i].size() == 1)
continue;
auto hash_key = ExpressionActions::ActionsHash{}(fused[i]);
{
std::lock_guard lock(mutex);
/// Count this occurrence; skip compilation until the threshold has been reached.
if (counter[hash_key]++ < min_count_to_compile_expression)
continue;
}
FunctionBasePtr fn;
if (compilation_cache)
{
/// The lambda is handed to getOrSet() together with the hash key; it builds the compiled function
/// and adds the elapsed time to CompileExpressionsMicroseconds.
std::tie(fn, std::ignore) = compilation_cache->getOrSet(hash_key, [&inlined_func=std::as_const(fused[i]), &sample_block] ()
{
Stopwatch watch;
FunctionBasePtr result_fn;
result_fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(inlined_func, sample_block));
ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
return result_fn;
});
}
else
{
Stopwatch watch;
fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(fused[i], sample_block));
ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
}
/// Swap the original function for the compiled one and take the argument list from the LLVMFunction.
actions[i].function_base = fn;
actions[i].argument_names = typeid_cast<const LLVMFunction *>(typeid_cast<const FunctionBaseAdaptor *>(fn.get())->getImpl())->getArgumentNames();
actions[i].is_function_compiled = true;
continue;
}
/// TODO: determine whether it's profitable to inline the function if there's more than one dependent.
/// No poisoned entry was found above, so every `dep` holds an index and `*dep` is valid.
for (const auto & dep : dependents[i])
fused[*dep].insert(fused[*dep].end(), fused[i].begin(), fused[i].end());
}
/// Call prepare() on every action that was marked as compiled above.
for (auto & action : actions)
{
if (action.type == ExpressionAction::APPLY_FUNCTION && action.is_function_compiled)
action.function = action.function_base->prepare({}); /// Arguments are not used for LLVMFunction.
}
}
//static std::vector<std::unordered_set<std::optional<size_t>>> getActionsDependents(const ExpressionActions::Actions & actions, const Names & output_columns)
//{
// /// an empty optional is a poisoned value prohibiting the column's producer from being removed
// /// (which it could be, if it was inlined into every dependent function).
// std::unordered_map<std::string, std::unordered_set<std::optional<size_t>>> current_dependents;
// for (const auto & name : output_columns)
// current_dependents[name].emplace();
// /// a snapshot of each compilable function's dependents at the time of its execution.
// std::vector<std::unordered_set<std::optional<size_t>>> dependents(actions.size());
// for (size_t i = actions.size(); i--;)
// {
// switch (actions[i].type)
// {
// case ExpressionAction::REMOVE_COLUMN:
// current_dependents.erase(actions[i].source_name);
// /// poison every other column used after this point so that inlining chains do not cross it.
// for (auto & dep : current_dependents)
// dep.second.emplace();
// break;
//
// case ExpressionAction::PROJECT:
// current_dependents.clear();
// for (const auto & proj : actions[i].projection)
// current_dependents[proj.first].emplace();
// break;
//
// case ExpressionAction::ADD_ALIASES:
// for (const auto & proj : actions[i].projection)
// current_dependents[proj.first].emplace();
// break;
//
// case ExpressionAction::ADD_COLUMN:
// case ExpressionAction::COPY_COLUMN:
// case ExpressionAction::ARRAY_JOIN:
// {
// Names columns = actions[i].getNeededColumns();
// for (const auto & column : columns)
// current_dependents[column].emplace();
// break;
// }
//
// case ExpressionAction::APPLY_FUNCTION:
// {
// dependents[i] = current_dependents[actions[i].result_name];
// const bool compilable = isCompilable(*actions[i].function_base);
// for (const auto & name : actions[i].argument_names)
// {
// if (compilable)
// current_dependents[name].emplace(i);
// else
// current_dependents[name].emplace();
// }
// break;
// }
// }
// }
// return dependents;
//}
//
//void compileFunctions(
// ExpressionActions::Actions & actions,
// const Names & output_columns,
// const Block & sample_block,
// std::shared_ptr<CompiledExpressionCache> compilation_cache,
// size_t min_count_to_compile_expression)
//{
// static std::unordered_map<UInt128, UInt32, UInt128Hash> counter;
// static std::mutex mutex;
//
// struct LLVMTargetInitializer
// {
// LLVMTargetInitializer()
// {
// llvm::InitializeNativeTarget();
// llvm::InitializeNativeTargetAsmPrinter();
// llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
// }
// };
//
// static LLVMTargetInitializer initializer;
//
// auto dependents = getActionsDependents(actions, output_columns);
// std::vector<ExpressionActions::Actions> fused(actions.size());
// for (size_t i = 0; i < actions.size(); ++i)
// {
// if (actions[i].type != ExpressionAction::APPLY_FUNCTION || !isCompilable(*actions[i].function_base))
// continue;
//
// fused[i].push_back(actions[i]);
// if (dependents[i].find({}) != dependents[i].end())
// {
// /// the result of compiling one function in isolation is pretty much the same as its `execute` method.
// if (fused[i].size() == 1)
// continue;
//
// auto hash_key = ExpressionActions::ActionsHash{}(fused[i]);
// {
// std::lock_guard lock(mutex);
// if (counter[hash_key]++ < min_count_to_compile_expression)
// continue;
// }
//
// FunctionBasePtr fn;
// if (compilation_cache)
// {
// std::tie(fn, std::ignore) = compilation_cache->getOrSet(hash_key, [&inlined_func=std::as_const(fused[i]), &sample_block] ()
// {
// Stopwatch watch;
// FunctionBasePtr result_fn;
// result_fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(inlined_func, sample_block));
// ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
// return result_fn;
// });
// }
// else
// {
// Stopwatch watch;
// fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(fused[i], sample_block));
// ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
// }
//
// actions[i].function_base = fn;
// actions[i].argument_names = typeid_cast<const LLVMFunction *>(typeid_cast<const FunctionBaseAdaptor *>(fn.get())->getImpl())->getArgumentNames();
// actions[i].is_function_compiled = true;
//
// continue;
// }
//
// /// TODO: determine whether it's profitable to inline the function if there's more than one dependent.
// for (const auto & dep : dependents[i])
// fused[*dep].insert(fused[*dep].end(), fused[i].begin(), fused[i].end());
// }
//
// for (auto & action : actions)
// {
// if (action.type == ExpressionAction::APPLY_FUNCTION && action.is_function_compiled)
// action.function = action.function_base->prepare({}); /// Arguments are not used for LLVMFunction.
// }
//}
}

View File

@ -74,7 +74,7 @@ public:
/// For each APPLY_FUNCTION action, try to compile the function to native code; if the only uses of a compilable
/// function's result are as arguments to other compilable functions, inline it and leave the now-redundant action as-is.
void compileFunctions(ExpressionActions::Actions & actions, const Names & output_columns, const Block & sample_block, std::shared_ptr<CompiledExpressionCache> compilation_cache, size_t min_count_to_compile_expression);
// void compileFunctions(ExpressionActions::Actions & actions, const Names & output_columns, const Block & sample_block, std::shared_ptr<CompiledExpressionCache> compilation_cache, size_t min_count_to_compile_expression);
}

View File

@ -988,6 +988,10 @@ void HashJoin::joinBlockImpl(
const auto & right_key = required_right_keys.getByPosition(i);
const auto & left_name = required_right_keys_sources[i];
/// asof column is already in block.
if (is_asof_join && right_key.name == key_names_right.back())
continue;
const auto & col = block.getByName(left_name);
bool is_nullable = nullable_right_side || right_key.type->isNullable();
block.insert(correctNullability({col.column, col.type, right_key.name}, is_nullable));
@ -1007,6 +1011,10 @@ void HashJoin::joinBlockImpl(
const auto & right_key = required_right_keys.getByPosition(i);
const auto & left_name = required_right_keys_sources[i];
/// asof column is already in block.
if (is_asof_join && right_key.name == key_names_right.back())
continue;
const auto & col = block.getByName(left_name);
bool is_nullable = nullable_right_side || right_key.type->isNullable();

View File

@ -98,7 +98,7 @@ namespace ErrorCodes
/// Assumes `storage` is set and the table filter (row-level security) is not empty.
String InterpreterSelectQuery::generateFilterActions(
ExpressionActionsPtr & actions, const ASTPtr & row_policy_filter, const Names & prerequisite_columns) const
ActionsDAGPtr & actions, const ASTPtr & row_policy_filter, const Names & prerequisite_columns) const
{
const auto & db_name = table_id.getDatabaseName();
const auto & table_name = table_id.getTableName();
@ -393,7 +393,7 @@ InterpreterSelectQuery::InterpreterSelectQuery(
filter_info = std::make_shared<FilterInfo>();
filter_info->column_name = generateFilterActions(filter_info->actions, row_policy_filter, required_columns);
source_header = metadata_snapshot->getSampleBlockForColumns(
filter_info->actions->getRequiredColumns(), storage->getVirtuals(), storage->getStorageID());
filter_info->actions->getRequiredColumns().getNames(), storage->getVirtuals(), storage->getStorageID());
}
}
@ -520,7 +520,7 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
if (analysis_result.prewhere_info)
{
analysis_result.prewhere_info->prewhere_actions->execute(header);
analysis_result.prewhere_info->prewhere_actions->buildExpressions()->execute(header);
header = materializeBlock(header);
if (analysis_result.prewhere_info->remove_prewhere_column)
header.erase(analysis_result.prewhere_info->prewhere_column_name);
@ -531,9 +531,9 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
if (options.to_stage == QueryProcessingStage::Enum::WithMergeableState)
{
if (!analysis_result.need_aggregate)
return analysis_result.before_order_and_select->getSampleBlock();
return analysis_result.before_order_and_select->getResultColumns();
auto header = analysis_result.before_aggregation->getSampleBlock();
Block header = analysis_result.before_aggregation->getResultColumns();
Block res;
@ -557,10 +557,10 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
if (options.to_stage == QueryProcessingStage::Enum::WithMergeableStateAfterAggregation)
{
return analysis_result.before_order_and_select->getSampleBlock();
return analysis_result.before_order_and_select->getResultColumns();
}
return analysis_result.final_projection->getSampleBlock();
return analysis_result.final_projection->getResultColumns();
}
static Field getWithFillFieldValue(const ASTPtr & node, const Context & context)
@ -1108,7 +1108,7 @@ static StreamLocalLimits getLimitsForStorage(const Settings & settings, const Se
void InterpreterSelectQuery::executeFetchColumns(
QueryProcessingStage::Enum processing_stage, QueryPlan & query_plan,
const PrewhereInfoPtr & prewhere_info, const Names & columns_to_remove_after_prewhere)
const PrewhereInfoPtr & prewhere_info, const NameSet & columns_to_remove_after_prewhere)
{
auto & query = getSelectQuery();
const Settings & settings = context->getSettingsRef();
@ -1156,7 +1156,7 @@ void InterpreterSelectQuery::executeFetchColumns(
auto column = ColumnAggregateFunction::create(func);
column->insertFrom(place);
auto header = analysis_result.before_aggregation->getSampleBlock();
Block header = analysis_result.before_aggregation->getResultColumns();
size_t arguments_size = desc.argument_names.size();
DataTypes argument_types(arguments_size);
for (size_t j = 0; j < arguments_size; ++j)
@ -1176,7 +1176,7 @@ void InterpreterSelectQuery::executeFetchColumns(
}
/// Actions to calculate ALIAS if required.
ExpressionActionsPtr alias_actions;
ActionsDAGPtr alias_actions;
if (storage)
{
@ -1185,14 +1185,14 @@ void InterpreterSelectQuery::executeFetchColumns(
if (row_policy_filter)
{
auto initial_required_columns = required_columns;
ExpressionActionsPtr actions;
ActionsDAGPtr actions;
generateFilterActions(actions, row_policy_filter, initial_required_columns);
auto required_columns_from_filter = actions->getRequiredColumns();
for (const auto & column : required_columns_from_filter)
{
if (required_columns.end() == std::find(required_columns.begin(), required_columns.end(), column))
required_columns.push_back(column);
if (required_columns.end() == std::find(required_columns.begin(), required_columns.end(), column.name))
required_columns.push_back(column.name);
}
}
@ -1224,7 +1224,7 @@ void InterpreterSelectQuery::executeFetchColumns(
if (prewhere_info)
{
/// Get some columns directly from PREWHERE expression actions
auto prewhere_required_columns = prewhere_info->prewhere_actions->getRequiredColumns();
auto prewhere_required_columns = prewhere_info->prewhere_actions->getRequiredColumns().getNames();
required_columns_from_prewhere.insert(prewhere_required_columns.begin(), prewhere_required_columns.end());
}
@ -1270,7 +1270,7 @@ void InterpreterSelectQuery::executeFetchColumns(
if (prewhere_info)
{
NameSet columns_to_remove(columns_to_remove_after_prewhere.begin(), columns_to_remove_after_prewhere.end());
Block prewhere_actions_result = prewhere_info->prewhere_actions->getSampleBlock();
Block prewhere_actions_result = prewhere_info->prewhere_actions->getResultColumns();
/// Populate required columns with the columns, added by PREWHERE actions and not removed afterwards.
/// XXX: looks hacky that we already know which columns after PREWHERE we won't need for sure.
@ -1291,10 +1291,10 @@ void InterpreterSelectQuery::executeFetchColumns(
}
auto syntax_result = TreeRewriter(*context).analyze(required_columns_all_expr, required_columns_after_prewhere, storage, metadata_snapshot);
alias_actions = ExpressionAnalyzer(required_columns_all_expr, syntax_result, *context).getActions(true);
alias_actions = ExpressionAnalyzer(required_columns_all_expr, syntax_result, *context).getActionsDAG(true);
/// The set of required columns could be added as a result of adding an action to calculate ALIAS.
required_columns = alias_actions->getRequiredColumns();
required_columns = alias_actions->getRequiredColumns().getNames();
/// Do not remove prewhere filter if it is a column which is used as alias.
if (prewhere_info && prewhere_info->remove_prewhere_column)
@ -1311,27 +1311,21 @@ void InterpreterSelectQuery::executeFetchColumns(
if (prewhere_info)
{
/// Don't remove columns which are needed to be aliased.
auto new_actions = std::make_shared<ExpressionActions>(prewhere_info->prewhere_actions->getRequiredColumnsWithTypes(), *context);
for (const auto & action : prewhere_info->prewhere_actions->getActions())
{
if (action.type != ExpressionAction::REMOVE_COLUMN
|| required_columns.end() == std::find(required_columns.begin(), required_columns.end(), action.source_name))
new_actions->add(action);
}
prewhere_info->prewhere_actions = std::move(new_actions);
for (const auto & name : required_columns)
prewhere_info->prewhere_actions->tryRestoreColumn(name);
auto analyzed_result
= TreeRewriter(*context).analyze(required_columns_from_prewhere_expr, metadata_snapshot->getColumns().getAllPhysical());
prewhere_info->alias_actions
= ExpressionAnalyzer(required_columns_from_prewhere_expr, analyzed_result, *context).getActions(true, false);
= ExpressionAnalyzer(required_columns_from_prewhere_expr, analyzed_result, *context).getActionsDAG(true, false);
/// Add (physical?) columns required by alias actions.
auto required_columns_from_alias = prewhere_info->alias_actions->getRequiredColumns();
Block prewhere_actions_result = prewhere_info->prewhere_actions->getSampleBlock();
Block prewhere_actions_result = prewhere_info->prewhere_actions->getResultColumns();
for (auto & column : required_columns_from_alias)
if (!prewhere_actions_result.has(column))
if (required_columns.end() == std::find(required_columns.begin(), required_columns.end(), column))
required_columns.push_back(column);
if (!prewhere_actions_result.has(column.name))
if (required_columns.end() == std::find(required_columns.begin(), required_columns.end(), column.name))
required_columns.push_back(column.name);
/// Add physical columns required by prewhere actions.
for (const auto & column : required_columns_from_prewhere)
@ -1488,7 +1482,7 @@ void InterpreterSelectQuery::executeFetchColumns(
}
void InterpreterSelectQuery::executeWhere(QueryPlan & query_plan, const ExpressionActionsPtr & expression, bool remove_filter)
void InterpreterSelectQuery::executeWhere(QueryPlan & query_plan, const ActionsDAGPtr & expression, bool remove_filter)
{
auto where_step = std::make_unique<FilterStep>(
query_plan.getCurrentDataStream(),
@ -1501,7 +1495,7 @@ void InterpreterSelectQuery::executeWhere(QueryPlan & query_plan, const Expressi
}
void InterpreterSelectQuery::executeAggregation(QueryPlan & query_plan, const ExpressionActionsPtr & expression, bool overflow_row, bool final, InputOrderInfoPtr group_by_info)
void InterpreterSelectQuery::executeAggregation(QueryPlan & query_plan, const ActionsDAGPtr & expression, bool overflow_row, bool final, InputOrderInfoPtr group_by_info)
{
auto expression_before_aggregation = std::make_unique<ExpressionStep>(query_plan.getCurrentDataStream(), expression);
expression_before_aggregation->setStepDescription("Before GROUP BY");
@ -1598,7 +1592,7 @@ void InterpreterSelectQuery::executeMergeAggregated(QueryPlan & query_plan, bool
}
void InterpreterSelectQuery::executeHaving(QueryPlan & query_plan, const ExpressionActionsPtr & expression)
void InterpreterSelectQuery::executeHaving(QueryPlan & query_plan, const ActionsDAGPtr & expression)
{
auto having_step = std::make_unique<FilterStep>(
query_plan.getCurrentDataStream(),
@ -1609,7 +1603,7 @@ void InterpreterSelectQuery::executeHaving(QueryPlan & query_plan, const Express
}
void InterpreterSelectQuery::executeTotalsAndHaving(QueryPlan & query_plan, bool has_having, const ExpressionActionsPtr & expression, bool overflow_row, bool final)
void InterpreterSelectQuery::executeTotalsAndHaving(QueryPlan & query_plan, bool has_having, const ActionsDAGPtr & expression, bool overflow_row, bool final)
{
const Settings & settings = context->getSettingsRef();
@ -1651,7 +1645,7 @@ void InterpreterSelectQuery::executeRollupOrCube(QueryPlan & query_plan, Modific
}
void InterpreterSelectQuery::executeExpression(QueryPlan & query_plan, const ExpressionActionsPtr & expression, const std::string & description)
void InterpreterSelectQuery::executeExpression(QueryPlan & query_plan, const ActionsDAGPtr & expression, const std::string & description)
{
auto expression_step = std::make_unique<ExpressionStep>(query_plan.getCurrentDataStream(), expression);
@ -1742,7 +1736,7 @@ void InterpreterSelectQuery::executeMergeSorted(QueryPlan & query_plan, const So
}
void InterpreterSelectQuery::executeProjection(QueryPlan & query_plan, const ExpressionActionsPtr & expression)
void InterpreterSelectQuery::executeProjection(QueryPlan & query_plan, const ActionsDAGPtr & expression)
{
auto projection_step = std::make_unique<ExpressionStep>(query_plan.getCurrentDataStream(), expression);
projection_step->setStepDescription("Projection");

View File

@ -117,14 +117,14 @@ private:
QueryProcessingStage::Enum processing_stage,
QueryPlan & query_plan,
const PrewhereInfoPtr & prewhere_info,
const Names & columns_to_remove_after_prewhere);
const NameSet & columns_to_remove_after_prewhere);
void executeWhere(QueryPlan & query_plan, const ExpressionActionsPtr & expression, bool remove_filter);
void executeAggregation(QueryPlan & query_plan, const ExpressionActionsPtr & expression, bool overflow_row, bool final, InputOrderInfoPtr group_by_info);
void executeWhere(QueryPlan & query_plan, const ActionsDAGPtr & expression, bool remove_filter);
void executeAggregation(QueryPlan & query_plan, const ActionsDAGPtr & expression, bool overflow_row, bool final, InputOrderInfoPtr group_by_info);
void executeMergeAggregated(QueryPlan & query_plan, bool overflow_row, bool final);
void executeTotalsAndHaving(QueryPlan & query_plan, bool has_having, const ExpressionActionsPtr & expression, bool overflow_row, bool final);
void executeHaving(QueryPlan & query_plan, const ExpressionActionsPtr & expression);
static void executeExpression(QueryPlan & query_plan, const ExpressionActionsPtr & expression, const std::string & description);
void executeTotalsAndHaving(QueryPlan & query_plan, bool has_having, const ActionsDAGPtr & expression, bool overflow_row, bool final);
void executeHaving(QueryPlan & query_plan, const ActionsDAGPtr & expression);
static void executeExpression(QueryPlan & query_plan, const ActionsDAGPtr & expression, const std::string & description);
void executeOrder(QueryPlan & query_plan, InputOrderInfoPtr sorting_info);
void executeOrderOptimized(QueryPlan & query_plan, InputOrderInfoPtr sorting_info, UInt64 limit, SortDescription & output_order_descr);
void executeWithFill(QueryPlan & query_plan);
@ -133,14 +133,14 @@ private:
void executeLimitBy(QueryPlan & query_plan);
void executeLimit(QueryPlan & query_plan);
void executeOffset(QueryPlan & query_plan);
static void executeProjection(QueryPlan & query_plan, const ExpressionActionsPtr & expression);
static void executeProjection(QueryPlan & query_plan, const ActionsDAGPtr & expression);
void executeDistinct(QueryPlan & query_plan, bool before_order, Names columns, bool pre_distinct);
void executeExtremes(QueryPlan & query_plan);
void executeSubqueriesInSetsAndJoins(QueryPlan & query_plan, std::unordered_map<String, SubqueryForSet> & subqueries_for_sets);
void executeMergeSorted(QueryPlan & query_plan, const SortDescription & sort_description, UInt64 limit, const std::string & description);
String generateFilterActions(
ExpressionActionsPtr & actions, const ASTPtr & row_policy_filter, const Names & prerequisite_columns = {}) const;
ActionsDAGPtr & actions, const ASTPtr & row_policy_filter, const Names & prerequisite_columns = {}) const;
enum class Modificator
{

View File

@ -644,6 +644,10 @@ ASTPtr MutationsInterpreter::prepareInterpreterSelectQuery(std::vector<Stage> &
for (const auto & column_name : prepared_stages[0].output_columns)
select->select()->children.push_back(std::make_shared<ASTIdentifier>(column_name));
/// Don't let select list be empty.
if (select->select()->children.empty())
select->select()->children.push_back(std::make_shared<ASTLiteral>(Field(0)));
if (!prepared_stages[0].filters.empty())
{
ASTPtr where_expression;
@ -676,12 +680,12 @@ QueryPipelinePtr MutationsInterpreter::addStreamsForLaterStages(const std::vecto
if (i < stage.filter_column_names.size())
{
/// Execute DELETEs.
plan.addStep(std::make_unique<FilterStep>(plan.getCurrentDataStream(), step->getExpression(), stage.filter_column_names[i], false));
plan.addStep(std::make_unique<FilterStep>(plan.getCurrentDataStream(), step->actions(), stage.filter_column_names[i], false));
}
else
{
/// Execute UPDATE or final projection.
plan.addStep(std::make_unique<ExpressionStep>(plan.getCurrentDataStream(), step->getExpression()));
plan.addStep(std::make_unique<ExpressionStep>(plan.getCurrentDataStream(), step->actions()));
}
}

View File

@ -43,7 +43,7 @@ Block getHeaderForProcessingStage(
Block header = metadata_snapshot->getSampleBlockForColumns(column_names, storage.getVirtuals(), storage.getStorageID());
if (query_info.prewhere_info)
{
query_info.prewhere_info->prewhere_actions->execute(header);
query_info.prewhere_info->prewhere_actions->buildExpressions()->execute(header);
if (query_info.prewhere_info->remove_prewhere_column)
header.erase(query_info.prewhere_info->prewhere_column_name);
}

View File

@ -9,18 +9,18 @@
namespace DB
{
/// Build the traits of an expression step from the actions it will run.
/// The three properties that depend on the actions (distinct columns, sorting, number of rows)
/// are all set to the negation of `hasArrayJoin()`; `returns_single_stream` is always false and
/// `preserves_number_of_streams` is always true.
/// NOTE(review): this span comes from a rendered diff without +/- markers, so every changed line
/// appears twice — once spelled with `expression` and once with `actions`. Presumably the
/// `actions` / `ActionsDAGPtr` lines are the post-change version; confirm against the repository.
static ITransformingStep::Traits getTraits(const ExpressionActionsPtr & expression)
static ITransformingStep::Traits getTraits(const ActionsDAGPtr & actions)
{
return ITransformingStep::Traits
{
/// First group: stream-level properties.
{
.preserves_distinct_columns = !expression->hasArrayJoin(),
.preserves_distinct_columns = !actions->hasArrayJoin(),
.returns_single_stream = false,
.preserves_number_of_streams = true,
.preserves_sorting = !expression->hasArrayJoin(),
.preserves_sorting = !actions->hasArrayJoin(),
},
/// Second group: row-count property.
{
.preserves_number_of_rows = !expression->hasArrayJoin(),
.preserves_number_of_rows = !actions->hasArrayJoin(),
}
};
}
@ -41,12 +41,12 @@ static ITransformingStep::Traits getJoinTraits()
};
}
ExpressionStep::ExpressionStep(const DataStream & input_stream_, ExpressionActionsPtr expression_)
ExpressionStep::ExpressionStep(const DataStream & input_stream_, ActionsDAGPtr actions_)
: ITransformingStep(
input_stream_,
Transform::transformHeader(input_stream_.header, expression_),
getTraits(expression_))
, expression(std::move(expression_))
Transform::transformHeader(input_stream_.header, actions_->buildExpressions()),
getTraits(actions_))
, actions(std::move(actions_))
{
/// Some columns may be removed by expression.
updateDistinctColumns(output_stream->header, output_stream->distinct_columns);
@ -55,7 +55,7 @@ ExpressionStep::ExpressionStep(const DataStream & input_stream_, ExpressionActio
void ExpressionStep::updateInputStream(DataStream input_stream, bool keep_header)
{
Block out_header = keep_header ? std::move(output_stream->header)
: Transform::transformHeader(input_stream.header, expression);
: Transform::transformHeader(input_stream.header, actions->buildExpressions());
output_stream = createOutputStream(
input_stream,
std::move(out_header),
@ -67,6 +67,7 @@ void ExpressionStep::updateInputStream(DataStream input_stream, bool keep_header
void ExpressionStep::transformPipeline(QueryPipeline & pipeline)
{
auto expression = actions->buildExpressions();
pipeline.addSimpleTransform([&](const Block & header)
{
return std::make_shared<Transform>(header, expression);
@ -82,11 +83,12 @@ void ExpressionStep::transformPipeline(QueryPipeline & pipeline)
}
}
static void doDescribeActions(const ExpressionActionsPtr & expression, IQueryPlanStep::FormatSettings & settings)
void ExpressionStep::describeActions(FormatSettings & settings) const
{
String prefix(settings.offset, ' ');
bool first = true;
auto expression = actions->buildExpressions();
for (const auto & action : expression->getActions())
{
settings.out << prefix << (first ? "Actions: "
@ -96,11 +98,6 @@ static void doDescribeActions(const ExpressionActionsPtr & expression, IQueryPla
}
}
/// Write a description of this step's actions into `settings` by forwarding the stored
/// `expression` to the file-local helper `doDescribeActions`.
void ExpressionStep::describeActions(FormatSettings & settings) const
{
doDescribeActions(expression, settings);
}
JoinStep::JoinStep(const DataStream & input_stream_, JoinPtr join_)
: ITransformingStep(
input_stream_,

View File

@ -4,8 +4,8 @@
namespace DB
{
class ExpressionActions;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
class ActionsDAG;
using ActionsDAGPtr = std::shared_ptr<ActionsDAG>;
class IJoin;
using JoinPtr = std::shared_ptr<IJoin>;
@ -19,7 +19,7 @@ class ExpressionStep : public ITransformingStep
public:
using Transform = ExpressionTransform;
explicit ExpressionStep(const DataStream & input_stream_, ExpressionActionsPtr expression_);
explicit ExpressionStep(const DataStream & input_stream_, ActionsDAGPtr actions_);
String getName() const override { return "Expression"; }
void transformPipeline(QueryPipeline & pipeline) override;
@ -28,10 +28,10 @@ public:
void describeActions(FormatSettings & settings) const override;
const ExpressionActionsPtr & getExpression() const { return expression; }
const ActionsDAGPtr & getExpression() const { return actions; }
private:
ExpressionActionsPtr expression;
ActionsDAGPtr actions;
};
/// TODO: add separate step for join.

View File

@ -28,7 +28,7 @@ static ITransformingStep::Traits getTraits()
}
FillingStep::FillingStep(const DataStream & input_stream_, SortDescription sort_description_)
: ITransformingStep(input_stream_, input_stream_.header, getTraits())
: ITransformingStep(input_stream_, FillingTransform::transformHeader(input_stream_.header, sort_description_), getTraits())
, sort_description(std::move(sort_description_))
{
if (!input_stream_.has_single_port)

View File

@ -8,7 +8,7 @@
namespace DB
{
static ITransformingStep::Traits getTraits(const ExpressionActionsPtr & expression)
static ITransformingStep::Traits getTraits(const ActionsDAGPtr & expression)
{
return ITransformingStep::Traits
{
@ -26,14 +26,14 @@ static ITransformingStep::Traits getTraits(const ExpressionActionsPtr & expressi
FilterStep::FilterStep(
const DataStream & input_stream_,
ExpressionActionsPtr expression_,
ActionsDAGPtr actions_,
String filter_column_name_,
bool remove_filter_column_)
: ITransformingStep(
input_stream_,
FilterTransform::transformHeader(input_stream_.header, expression_, filter_column_name_, remove_filter_column_),
getTraits(expression_))
, expression(std::move(expression_))
FilterTransform::transformHeader(input_stream_.header, actions_->buildExpressions(), filter_column_name_, remove_filter_column_),
getTraits(actions_))
, actions(std::move(actions_))
, filter_column_name(std::move(filter_column_name_))
, remove_filter_column(remove_filter_column_)
{
@ -45,7 +45,7 @@ void FilterStep::updateInputStream(DataStream input_stream, bool keep_header)
{
Block out_header = std::move(output_stream->header);
if (keep_header)
out_header = FilterTransform::transformHeader(input_stream.header, expression, filter_column_name, remove_filter_column);
out_header = FilterTransform::transformHeader(input_stream.header, actions->buildExpressions(), filter_column_name, remove_filter_column);
output_stream = createOutputStream(
input_stream,
@ -58,6 +58,7 @@ void FilterStep::updateInputStream(DataStream input_stream, bool keep_header)
void FilterStep::transformPipeline(QueryPipeline & pipeline)
{
auto expression = actions->buildExpressions();
pipeline.addSimpleTransform([&](const Block & header, QueryPipeline::StreamType stream_type)
{
bool on_totals = stream_type == QueryPipeline::StreamType::Totals;
@ -79,6 +80,7 @@ void FilterStep::describeActions(FormatSettings & settings) const
settings.out << prefix << "Filter column: " << filter_column_name << '\n';
bool first = true;
auto expression = actions->buildExpressions();
for (const auto & action : expression->getActions())
{
settings.out << prefix << (first ? "Actions: "

View File

@ -4,8 +4,8 @@
namespace DB
{
class ExpressionActions;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
class ActionsDAG;
using ActionsDAGPtr = std::shared_ptr<ActionsDAG>;
/// Implements WHERE, HAVING operations. See FilterTransform.
class FilterStep : public ITransformingStep
@ -13,7 +13,7 @@ class FilterStep : public ITransformingStep
public:
FilterStep(
const DataStream & input_stream_,
ExpressionActionsPtr expression_,
ActionsDAGPtr actions_,
String filter_column_name_,
bool remove_filter_column_);
@ -24,12 +24,12 @@ public:
void describeActions(FormatSettings & settings) const override;
const ExpressionActionsPtr & getExpression() const { return expression; }
const ActionsDAGPtr & getExpression() const { return actions; }
const String & getFilterColumnName() const { return filter_column_name; }
bool removesFilterColumn() const { return remove_filter_column; }
private:
ExpressionActionsPtr expression;
ActionsDAGPtr actions;
String filter_column_name;
bool remove_filter_column;
};

View File

@ -438,7 +438,7 @@ static void tryLiftUpArrayJoin(QueryPlan::Node * parent_node, QueryPlan::Node *
return;
/// All actions were moved before ARRAY JOIN. Swap Expression and ArrayJoin.
if (expression->getActions().empty())
if (expression->empty())
{
auto expected_header = parent->getOutputStream().header;

View File

@ -38,17 +38,19 @@ ReadFromStorageStep::ReadFromStorageStep(
{
if (query_info.prewhere_info->alias_actions)
{
auto alias_actions = query_info.prewhere_info->alias_actions->buildExpressions();
pipe.addSimpleTransform([&](const Block & header)
{
return std::make_shared<ExpressionTransform>(header, query_info.prewhere_info->alias_actions);
return std::make_shared<ExpressionTransform>(header, alias_actions);
});
}
auto prewhere_actions = query_info.prewhere_info->prewhere_actions->buildExpressions();
pipe.addSimpleTransform([&](const Block & header)
{
return std::make_shared<FilterTransform>(
header,
query_info.prewhere_info->prewhere_actions,
prewhere_actions,
query_info.prewhere_info->prewhere_column_name,
query_info.prewhere_info->remove_prewhere_column);
});
@ -59,10 +61,10 @@ ReadFromStorageStep::ReadFromStorageStep(
// This leads to mismatched header in distributed table
if (query_info.prewhere_info->remove_columns_actions)
{
auto remove_actions = query_info.prewhere_info->remove_columns_actions->buildExpressions();
pipe.addSimpleTransform([&](const Block & header)
{
return std::make_shared<ExpressionTransform>(
header, query_info.prewhere_info->remove_columns_actions);
return std::make_shared<ExpressionTransform>(header, remove_actions);
});
}
}

View File

@ -27,17 +27,17 @@ static ITransformingStep::Traits getTraits(bool has_filter)
TotalsHavingStep::TotalsHavingStep(
const DataStream & input_stream_,
bool overflow_row_,
const ExpressionActionsPtr & expression_,
const ActionsDAGPtr & actions_,
const std::string & filter_column_,
TotalsMode totals_mode_,
double auto_include_threshold_,
bool final_)
: ITransformingStep(
input_stream_,
TotalsHavingTransform::transformHeader(input_stream_.header, expression_, final_),
TotalsHavingTransform::transformHeader(input_stream_.header, (actions_ ? actions_->buildExpressions() : nullptr), final_),
getTraits(!filter_column_.empty()))
, overflow_row(overflow_row_)
, expression(expression_)
, actions(actions_)
, filter_column_name(filter_column_)
, totals_mode(totals_mode_)
, auto_include_threshold(auto_include_threshold_)
@ -48,7 +48,7 @@ TotalsHavingStep::TotalsHavingStep(
void TotalsHavingStep::transformPipeline(QueryPipeline & pipeline)
{
auto totals_having = std::make_shared<TotalsHavingTransform>(
pipeline.getHeader(), overflow_row, expression,
pipeline.getHeader(), overflow_row, (actions ? actions->buildExpressions() : nullptr),
filter_column_name, totals_mode, auto_include_threshold, final);
pipeline.addTotalsHavingTransform(std::move(totals_having));
@ -78,6 +78,7 @@ void TotalsHavingStep::describeActions(FormatSettings & settings) const
settings.out << prefix << "Mode: " << totalsModeToString(totals_mode, auto_include_threshold) << '\n';
bool first = true;
auto expression = actions->buildExpressions();
for (const auto & action : expression->getActions())
{
settings.out << prefix << (first ? "Actions: "

View File

@ -4,8 +4,8 @@
namespace DB
{
class ExpressionActions;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
class ActionsDAG;
using ActionsDAGPtr = std::shared_ptr<ActionsDAG>;
enum class TotalsMode;
@ -16,7 +16,7 @@ public:
TotalsHavingStep(
const DataStream & input_stream_,
bool overflow_row_,
const ExpressionActionsPtr & expression_,
const ActionsDAGPtr & actions_,
const std::string & filter_column_,
TotalsMode totals_mode_,
double auto_include_threshold_,
@ -30,7 +30,7 @@ public:
private:
bool overflow_row;
ExpressionActionsPtr expression;
ActionsDAGPtr actions;
String filter_column_name;
TotalsMode totals_mode;
double auto_include_threshold;

View File

@ -1,13 +1,12 @@
#include <Processors/Transforms/ExpressionTransform.h>
#include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/ExpressionActions.h>
namespace DB
{
/// Compute the output header of the transform: `expression` is executed on the by-value copy of
/// the input header and the modified block is returned.
/// NOTE(review): the trailing `true` argument is presumably a dry-run flag — confirm against
/// the declaration of ExpressionActions::execute.
/// NOTE(review): this span comes from a rendered diff without +/- markers; the two-argument
/// `execute(header, true)` call and the `num_rows` + three-argument `execute` pair are presumably
/// the before/after forms of one statement, not two executions meant to run in sequence.
Block ExpressionTransform::transformHeader(Block header, const ExpressionActionsPtr & expression)
{
expression->execute(header, true);
/// Row count of the header, handed to execute() alongside the block.
size_t num_rows = header.rows();
expression->execute(header, num_rows, true);
return header;
}
@ -20,11 +19,11 @@ ExpressionTransform::ExpressionTransform(const Block & header_, ExpressionAction
/// Evaluate the expression over one chunk: the chunk's columns are detached into a block shaped
/// like the input port header, `expression` is executed on that block, and the block's columns
/// are stored back into the chunk together with a row count.
/// NOTE(review): this span comes from a rendered diff without +/- markers. It shows two ways of
/// obtaining `num_rows` (from the chunk before execution, and from `block.rows()` after it) and
/// two `execute` overloads; presumably one pair is the removed code and the other its
/// replacement — confirm against the repository.
void ExpressionTransform::transform(Chunk & chunk)
{
size_t num_rows = chunk.getNumRows();
/// Rebuild a Block from the header structure and the chunk's detached columns.
auto block = getInputPort().getHeader().cloneWithColumns(chunk.detachColumns());
expression->execute(block);
expression->execute(block, num_rows);
auto num_rows = block.rows();
/// Hand the resulting columns back to the chunk.
chunk.setColumns(block.getColumns(), num_rows);
}

View File

@ -10,10 +10,23 @@ namespace ErrorCodes
extern const int INVALID_WITH_FILL_EXPRESSION;
}
/// Compute the header produced by the filling transform from its input header.
/// Column names listed in `sort_description` are collected first. Every other header column that
/// currently holds a constant column gets it replaced with a column newly created from its data
/// type, because such columns may not be constant anymore. Columns named in the sort description,
/// columns without data, and non-constant columns are left as they are.
Block FillingTransform::transformHeader(Block header, const SortDescription & sort_description)
{
    NameSet names_from_sort_description;
    for (const auto & description : sort_description)
        names_from_sort_description.insert(description.column_name);

    for (auto & elem : header)
    {
        /// Nothing to reset when the header entry carries no column.
        if (!elem.column)
            continue;

        /// Only constant columns are affected.
        if (!isColumnConst(*elem.column))
            continue;

        /// Columns from the sorting key keep their current column.
        if (names_from_sort_description.count(elem.name))
            continue;

        elem.column = elem.type->createColumn();
    }

    return header;
}
FillingTransform::FillingTransform(
const Block & header_, const SortDescription & sort_description_)
: ISimpleTransform(header_, header_, true)
: ISimpleTransform(header_, transformHeader(header_, sort_description_), true)
, sort_description(sort_description_)
, filling_row(sort_description_)
, next_row(sort_description_)

View File

@ -19,6 +19,8 @@ public:
Status prepare() override;
static Block transformHeader(Block header, const SortDescription & sort_description);
protected:
void transform(Chunk & Chunk) override;

View File

@ -33,7 +33,8 @@ Block FilterTransform::transformHeader(
const String & filter_column_name,
bool remove_filter_column)
{
expression->execute(header);
size_t num_rows = header.rows();
expression->execute(header, num_rows);
if (remove_filter_column)
header.erase(filter_column_name);
@ -96,16 +97,15 @@ void FilterTransform::removeFilterIfNeed(Chunk & chunk) const
void FilterTransform::transform(Chunk & chunk)
{
size_t num_rows_before_filtration;
size_t num_rows_before_filtration = chunk.getNumRows();
auto columns = chunk.detachColumns();
{
Block block = getInputPort().getHeader().cloneWithColumns(columns);
columns.clear();
expression->execute(block);
expression->execute(block, num_rows_before_filtration);
num_rows_before_filtration = block.rows();
columns = block.getColumns();
}

View File

@ -32,8 +32,10 @@ Block TotalsHavingTransform::transformHeader(Block block, const ExpressionAction
if (final)
finalizeBlock(block);
size_t num_rows = block.rows();
if (expression)
expression->execute(block);
expression->execute(block, num_rows);
return block;
}
@ -64,7 +66,8 @@ TotalsHavingTransform::TotalsHavingTransform(
if (expression)
{
auto totals_header = finalized_header;
expression->execute(totals_header);
size_t num_rows = totals_header.rows();
expression->execute(totals_header, num_rows);
outputs.emplace_back(totals_header, this);
}
else
@ -155,8 +158,9 @@ void TotalsHavingTransform::transform(Chunk & chunk)
{
/// Compute the expression in HAVING.
const auto & cur_header = final ? finalized_header : getInputPort().getHeader();
size_t num_rows = finalized.getNumRows();
auto finalized_block = cur_header.cloneWithColumns(finalized.detachColumns());
expression->execute(finalized_block);
expression->execute(finalized_block, num_rows);
auto columns = finalized_block.getColumns();
ColumnPtr filter_column_ptr = columns[filter_column_pos];
@ -165,7 +169,6 @@ void TotalsHavingTransform::transform(Chunk & chunk)
if (const_filter_description.always_true)
{
addToTotals(chunk, nullptr);
auto num_rows = columns.front()->size();
chunk.setColumns(std::move(columns), num_rows);
return;
}
@ -198,7 +201,7 @@ void TotalsHavingTransform::transform(Chunk & chunk)
}
}
auto num_rows = columns.front()->size();
num_rows = columns.front()->size();
chunk.setColumns(std::move(columns), num_rows);
}
@ -255,10 +258,11 @@ void TotalsHavingTransform::prepareTotals()
if (expression)
{
size_t num_rows = totals.getNumRows();
auto block = finalized_header.cloneWithColumns(totals.detachColumns());
expression->execute(block);
expression->execute(block, num_rows);
/// Note: after expression totals may have several rows if `arrayJoin` was used in expression.
totals = Chunk(block.getColumns(), block.rows());
totals = Chunk(block.getColumns(), num_rows);
}
}

View File

@ -502,7 +502,7 @@ Block validateColumnsDefaultsAndGetSampleBlock(ASTPtr default_expr_list, const N
auto syntax_analyzer_result = TreeRewriter(context).analyze(default_expr_list, all_columns);
const auto actions = ExpressionAnalyzer(default_expr_list, syntax_analyzer_result, context).getActions(true);
for (const auto & action : actions->getActions())
if (action.type == ExpressionAction::Type::ARRAY_JOIN)
if (action.node->type == ActionsDAG::Type::ARRAY_JOIN)
throw Exception("Unsupported default value that requires ARRAY JOIN action", ErrorCodes::THERE_IS_NO_DEFAULT_VALUE);
return actions->getSampleBlock();

View File

@ -31,7 +31,7 @@ IndexDescription::IndexDescription(const IndexDescription & other)
, granularity(other.granularity)
{
if (other.expression)
expression = std::make_shared<ExpressionActions>(*other.expression);
expression = other.expression->clone();
}
@ -54,7 +54,7 @@ IndexDescription & IndexDescription::operator=(const IndexDescription & other)
type = other.type;
if (other.expression)
expression = std::make_shared<ExpressionActions>(*other.expression);
expression = other.expression->clone();
else
expression.reset();

View File

@ -25,7 +25,7 @@ KeyDescription::KeyDescription(const KeyDescription & other)
, additional_column(other.additional_column)
{
if (other.expression)
expression = std::make_shared<ExpressionActions>(*other.expression);
expression = other.expression->clone();
}
KeyDescription & KeyDescription::operator=(const KeyDescription & other)
@ -45,7 +45,7 @@ KeyDescription & KeyDescription::operator=(const KeyDescription & other)
if (other.expression)
expression = std::make_shared<ExpressionActions>(*other.expression);
expression = other.expression->clone();
else
expression.reset();

View File

@ -569,7 +569,7 @@ bool KeyCondition::canConstantBeWrappedByMonotonicFunctions(
return false;
bool found_transformation = false;
for (const ExpressionAction & action : key_expr->getActions())
for (const auto & action : key_expr->getActions())
{
/** The key functional expression constraint may be inferred from a plain column in the expression.
* For example, if the key contains `toStartOfHour(Timestamp)` and query contains `WHERE Timestamp >= now()`,
@ -581,25 +581,25 @@ bool KeyCondition::canConstantBeWrappedByMonotonicFunctions(
* Instead, we can qualify only functions that do not transform the range (for example rounding),
* which while not strictly monotonic, are monotonic everywhere on the input range.
*/
const auto & argument_names = action.argument_names;
if (action.type == ExpressionAction::Type::APPLY_FUNCTION
&& argument_names.size() == 1
&& argument_names[0] == expr_name)
const auto & children = action.node->children;
if (action.node->type == ActionsDAG::Type::FUNCTION
&& children.size() == 1
&& children[0]->result_name == expr_name)
{
if (!action.function_base->hasInformationAboutMonotonicity())
if (!action.node->function_base->hasInformationAboutMonotonicity())
return false;
/// Range is irrelevant in this case.
IFunction::Monotonicity monotonicity = action.function_base->getMonotonicityForRange(*out_type, Field(), Field());
IFunction::Monotonicity monotonicity = action.node->function_base->getMonotonicityForRange(*out_type, Field(), Field());
if (!monotonicity.is_always_monotonic)
return false;
/// Apply the next transformation step.
std::tie(out_value, out_type) = applyFunctionForFieldOfUnknownType(
action.function_builder,
action.node->function_builder,
out_type, out_value);
expr_name = action.result_name;
expr_name = action.node->result_name;
/// Transformation results in a key expression, accept.
auto it = key_columns.find(expr_name);

View File

@ -40,6 +40,13 @@ MergeTreeBaseSelectProcessor::MergeTreeBaseSelectProcessor(
, use_uncompressed_cache(use_uncompressed_cache_)
, virt_column_names(virt_column_names_)
{
if (prewhere_info)
{
if (prewhere_info->alias_actions)
prewhere_alias_actions = prewhere_info->alias_actions->buildExpressions();
prewhere_actions = prewhere_info->prewhere_actions->buildExpressions();
}
header_without_virtual_columns = getPort().getHeader();
for (auto it = virt_column_names.rbegin(); it != virt_column_names.rend(); ++it)
@ -74,23 +81,39 @@ void MergeTreeBaseSelectProcessor::initializeRangeReaders(MergeTreeReadTask & cu
{
if (reader->getColumns().empty())
{
current_task.range_reader = MergeTreeRangeReader(pre_reader.get(), nullptr, prewhere_info, true);
current_task.range_reader = MergeTreeRangeReader(
pre_reader.get(), nullptr,
prewhere_alias_actions,
prewhere_actions,
prewhere_info->prewhere_column_name,
prewhere_info->remove_prewhere_column,
prewhere_info->need_filter,
true);
}
else
{
MergeTreeRangeReader * pre_reader_ptr = nullptr;
if (pre_reader != nullptr)
{
current_task.pre_range_reader = MergeTreeRangeReader(pre_reader.get(), nullptr, prewhere_info, false);
current_task.pre_range_reader = MergeTreeRangeReader(
pre_reader.get(), nullptr,
prewhere_alias_actions,
prewhere_actions,
prewhere_info->prewhere_column_name,
prewhere_info->remove_prewhere_column,
prewhere_info->need_filter,
false);
pre_reader_ptr = &current_task.pre_range_reader;
}
current_task.range_reader = MergeTreeRangeReader(reader.get(), pre_reader_ptr, nullptr, true);
current_task.range_reader = MergeTreeRangeReader(
reader.get(), pre_reader_ptr, nullptr, nullptr, {}, false, false, true);
}
}
else
{
current_task.range_reader = MergeTreeRangeReader(reader.get(), nullptr, nullptr, true);
current_task.range_reader = MergeTreeRangeReader(
reader.get(), nullptr, nullptr, nullptr, {}, false, false, true);
}
}
@ -314,9 +337,9 @@ void MergeTreeBaseSelectProcessor::executePrewhereActions(Block & block, const P
if (prewhere_info)
{
if (prewhere_info->alias_actions)
prewhere_info->alias_actions->execute(block);
prewhere_info->alias_actions->buildExpressions()->execute(block);
prewhere_info->prewhere_actions->execute(block);
prewhere_info->prewhere_actions->buildExpressions()->execute(block);
auto & prewhere_column = block.getByName(prewhere_info->prewhere_column_name);
if (!prewhere_column.type->canBeUsedInBooleanContext())

View File

@ -58,6 +58,8 @@ protected:
StorageMetadataPtr metadata_snapshot;
PrewhereInfoPtr prewhere_info;
ExpressionActionsPtr prewhere_alias_actions;
ExpressionActionsPtr prewhere_actions;
UInt64 max_block_size_rows;
UInt64 preferred_block_size_bytes;

View File

@ -260,9 +260,9 @@ MergeTreeReadTaskColumns getReadTaskColumns(
if (prewhere_info)
{
if (prewhere_info->alias_actions)
pre_column_names = prewhere_info->alias_actions->getRequiredColumns();
pre_column_names = prewhere_info->alias_actions->getRequiredColumns().getNames();
else
pre_column_names = prewhere_info->prewhere_actions->getRequiredColumns();
pre_column_names = prewhere_info->prewhere_actions->getRequiredColumns().getNames();
if (pre_column_names.empty())
pre_column_names.push_back(column_names[0]);

View File

@ -256,14 +256,14 @@ StoragePolicyPtr MergeTreeData::getStoragePolicy() const
static void checkKeyExpression(const ExpressionActions & expr, const Block & sample_block, const String & key_name, bool allow_nullable_key)
{
for (const ExpressionAction & action : expr.getActions())
for (const auto & action : expr.getActions())
{
if (action.type == ExpressionAction::ARRAY_JOIN)
if (action.node->type == ActionsDAG::Type::ARRAY_JOIN)
throw Exception(key_name + " key cannot contain array joins", ErrorCodes::ILLEGAL_COLUMN);
if (action.type == ExpressionAction::APPLY_FUNCTION)
if (action.node->type == ActionsDAG::Type::FUNCTION)
{
IFunctionBase & func = *action.function_base;
IFunctionBase & func = *action.node->function_base;
if (!func.isDeterministic())
throw Exception(key_name + " key cannot contain non-deterministic functions, "
"but contains function " + func.getName(),
@ -437,7 +437,7 @@ void MergeTreeData::checkPartitionKeyAndInitMinMax(const KeyDescription & new_pa
/// Add all columns used in the partition key to the min-max index.
const NamesAndTypesList & minmax_idx_columns_with_types = new_partition_key.expression->getRequiredColumnsWithTypes();
minmax_idx_expr = std::make_shared<ExpressionActions>(minmax_idx_columns_with_types, global_context);
minmax_idx_expr = std::make_shared<ActionsDAG>(minmax_idx_columns_with_types)->buildExpressions();
for (const NameAndTypePair & column : minmax_idx_columns_with_types)
{
minmax_idx_columns.emplace_back(column.name);
@ -1401,10 +1401,10 @@ void MergeTreeData::checkAlterIsPossible(const AlterCommands & commands, const S
{
/// Forbid altering columns inside partition key expressions because it can change partition ID format.
auto partition_key_expr = old_metadata.getPartitionKey().expression;
for (const ExpressionAction & action : partition_key_expr->getActions())
for (const auto & action : partition_key_expr->getActions())
{
auto action_columns = action.getNeededColumns();
columns_alter_type_forbidden.insert(action_columns.begin(), action_columns.end());
for (const auto * child : action.node->children)
columns_alter_type_forbidden.insert(child->result_name);
}
/// But allow to alter columns without expressions under certain condition.
@ -1421,10 +1421,10 @@ void MergeTreeData::checkAlterIsPossible(const AlterCommands & commands, const S
if (old_metadata.hasSortingKey())
{
auto sorting_key_expr = old_metadata.getSortingKey().expression;
for (const ExpressionAction & action : sorting_key_expr->getActions())
for (const auto & action : sorting_key_expr->getActions())
{
auto action_columns = action.getNeededColumns();
columns_alter_type_forbidden.insert(action_columns.begin(), action_columns.end());
for (const auto * child : action.node->children)
columns_alter_type_forbidden.insert(child->result_name);
}
for (const String & col : sorting_key_expr->getRequiredColumns())
columns_alter_type_metadata_only.insert(col);

View File

@ -706,7 +706,7 @@ Pipe MergeTreeDataSelectExecutor::readFromParts(
/// Projection that is needed to drop the columns which have appeared after the execution
/// of some extra expressions, and to allow the same expressions to be executed later.
/// NOTE: It may lead to double computation of expressions.
ExpressionActionsPtr result_projection;
ActionsDAGPtr result_projection;
if (select.final())
{
@ -784,9 +784,10 @@ Pipe MergeTreeDataSelectExecutor::readFromParts(
if (result_projection)
{
res.addSimpleTransform([&result_projection](const Block & header)
auto result_projection_actions = result_projection->buildExpressions();
res.addSimpleTransform([&result_projection_actions](const Block & header)
{
return std::make_shared<ExpressionTransform>(header, result_projection);
return std::make_shared<ExpressionTransform>(header, result_projection_actions);
});
}
@ -802,9 +803,10 @@ Pipe MergeTreeDataSelectExecutor::readFromParts(
if (query_info.prewhere_info && query_info.prewhere_info->remove_columns_actions)
{
res.addSimpleTransform([&query_info](const Block & header)
auto remove_actions = query_info.prewhere_info->remove_columns_actions->buildExpressions();
res.addSimpleTransform([&remove_actions](const Block & header)
{
return std::make_shared<ExpressionTransform>(header, query_info.prewhere_info->remove_columns_actions);
return std::make_shared<ExpressionTransform>(header, remove_actions);
});
}
@ -956,11 +958,12 @@ Pipe MergeTreeDataSelectExecutor::spreadMarkRangesAmongStreams(
}
}
static ExpressionActionsPtr createProjection(const Pipe & pipe, const MergeTreeData & data)
static ActionsDAGPtr createProjection(const Pipe & pipe)
{
const auto & header = pipe.getHeader();
auto projection = std::make_shared<ExpressionActions>(header.getNamesAndTypesList(), data.global_context);
projection->add(ExpressionAction::project(header.getNames()));
auto projection = std::make_shared<ActionsDAG>(header.getNamesAndTypesList());
projection->removeUnusedActions(header.getNames());
projection->projectInput();
return projection;
}
@ -976,7 +979,7 @@ Pipe MergeTreeDataSelectExecutor::spreadMarkRangesAmongStreamsWithOrder(
const Names & virt_columns,
const Settings & settings,
const MergeTreeReaderSettings & reader_settings,
ExpressionActionsPtr & out_projection) const
ActionsDAGPtr & out_projection) const
{
size_t sum_marks = 0;
const InputOrderInfoPtr & input_order_info = query_info.input_order_info;
@ -1182,7 +1185,7 @@ Pipe MergeTreeDataSelectExecutor::spreadMarkRangesAmongStreamsWithOrder(
input_order_info->direction, 1);
/// Drop temporary columns, added by 'sorting_key_prefix_expr'
out_projection = createProjection(pipe, data);
out_projection = createProjection(pipe);
pipe.addSimpleTransform([sorting_key_prefix_expr](const Block & header)
{
return std::make_shared<ExpressionTransform>(header, sorting_key_prefix_expr);
@ -1210,7 +1213,7 @@ Pipe MergeTreeDataSelectExecutor::spreadMarkRangesAmongStreamsFinal(
const Names & virt_columns,
const Settings & settings,
const MergeTreeReaderSettings & reader_settings,
ExpressionActionsPtr & out_projection) const
ActionsDAGPtr & out_projection) const
{
const auto data_settings = data.getSettings();
size_t sum_marks = 0;
@ -1259,7 +1262,7 @@ Pipe MergeTreeDataSelectExecutor::spreadMarkRangesAmongStreamsFinal(
/// Drop temporary columns, added by 'sorting_key_expr'
if (!out_projection)
out_projection = createProjection(pipe, data);
out_projection = createProjection(pipe);
pipe.addSimpleTransform([&metadata_snapshot](const Block & header)
{

View File

@ -73,7 +73,7 @@ private:
const Names & virt_columns,
const Settings & settings,
const MergeTreeReaderSettings & reader_settings,
ExpressionActionsPtr & out_projection) const;
ActionsDAGPtr & out_projection) const;
Pipe spreadMarkRangesAmongStreamsFinal(
RangesInDataParts && parts,
@ -86,7 +86,7 @@ private:
const Names & virt_columns,
const Settings & settings,
const MergeTreeReaderSettings & reader_settings,
ExpressionActionsPtr & out_projection) const;
ActionsDAGPtr & out_projection) const;
/// Get the approximate number of rows falling under the index (a lower-bound estimate that counts only full marks).
size_t getApproximateTotalRowsToRead(

View File

@ -489,11 +489,20 @@ size_t MergeTreeRangeReader::ReadResult::countBytesInResultFilter(const IColumn:
MergeTreeRangeReader::MergeTreeRangeReader(
IMergeTreeReader * merge_tree_reader_,
MergeTreeRangeReader * prev_reader_,
const PrewhereInfoPtr & prewhere_,
ExpressionActionsPtr prewhere_alias_actions_,
ExpressionActionsPtr prewhere_actions_,
String prewhere_column_name_,
bool remove_prewhere_column_,
bool prewhere_need_filter_,
bool last_reader_in_chain_)
: merge_tree_reader(merge_tree_reader_)
, index_granularity(&(merge_tree_reader->data_part->index_granularity)), prev_reader(prev_reader_)
, prewhere(prewhere_), last_reader_in_chain(last_reader_in_chain_), is_initialized(true)
, prewhere_alias_actions(std::move(prewhere_alias_actions_))
, prewhere_actions(std::move(prewhere_actions_))
, prewhere_column_name(std::move(prewhere_column_name_))
, remove_prewhere_column(remove_prewhere_column_)
, prewhere_need_filter(prewhere_need_filter_)
, last_reader_in_chain(last_reader_in_chain_), is_initialized(true)
{
if (prev_reader)
sample_block = prev_reader->getSampleBlock();
@ -501,16 +510,16 @@ MergeTreeRangeReader::MergeTreeRangeReader(
for (const auto & name_and_type : merge_tree_reader->getColumns())
sample_block.insert({name_and_type.type->createColumn(), name_and_type.type, name_and_type.name});
if (prewhere)
if (prewhere_actions)
{
if (prewhere->alias_actions)
prewhere->alias_actions->execute(sample_block, true);
if (prewhere_alias_actions)
prewhere_alias_actions->execute(sample_block, true);
if (prewhere->prewhere_actions)
prewhere->prewhere_actions->execute(sample_block, true);
if (prewhere_actions)
prewhere_actions->execute(sample_block, true);
if (prewhere->remove_prewhere_column)
sample_block.erase(prewhere->prewhere_column_name);
if (remove_prewhere_column)
sample_block.erase(prewhere_column_name);
}
}
@ -794,7 +803,7 @@ Columns MergeTreeRangeReader::continueReadingChain(ReadResult & result, size_t &
void MergeTreeRangeReader::executePrewhereActionsAndFilterColumns(ReadResult & result)
{
if (!prewhere)
if (!prewhere_actions)
return;
const auto & header = merge_tree_reader->getColumns();
@ -825,14 +834,14 @@ void MergeTreeRangeReader::executePrewhereActionsAndFilterColumns(ReadResult & r
for (auto name_and_type = header.begin(); pos < num_columns; ++pos, ++name_and_type)
block.insert({result.columns[pos], name_and_type->type, name_and_type->name});
if (prewhere->alias_actions)
prewhere->alias_actions->execute(block);
if (prewhere_alias_actions)
prewhere_alias_actions->execute(block);
/// Columns might be projected out. We need to store them here so that default columns can be evaluated later.
result.block_before_prewhere = block;
prewhere->prewhere_actions->execute(block);
prewhere_actions->execute(block);
prewhere_column_pos = block.getPositionByName(prewhere->prewhere_column_name);
prewhere_column_pos = block.getPositionByName(prewhere_column_name);
result.columns.clear();
result.columns.reserve(block.columns());
@ -860,7 +869,7 @@ void MergeTreeRangeReader::executePrewhereActionsAndFilterColumns(ReadResult & r
if (result.totalRowsPerGranule() == 0)
result.setFilterConstFalse();
/// If we need to filter in PREWHERE
else if (prewhere->need_filter || result.need_filter)
else if (prewhere_need_filter || result.need_filter)
{
/// If there is a filter and without optimized
if (result.getFilter() && last_reader_in_chain)
@ -901,11 +910,11 @@ void MergeTreeRangeReader::executePrewhereActionsAndFilterColumns(ReadResult & r
/// Check if the PREWHERE column is needed
if (!result.columns.empty())
{
if (prewhere->remove_prewhere_column)
if (remove_prewhere_column)
result.columns.erase(result.columns.begin() + prewhere_column_pos);
else
result.columns[prewhere_column_pos] =
getSampleBlock().getByName(prewhere->prewhere_column_name).type->
getSampleBlock().getByName(prewhere_column_name).type->
createColumnConst(result.num_rows, 1u)->convertToFullColumnIfConst();
}
}
@ -913,7 +922,7 @@ void MergeTreeRangeReader::executePrewhereActionsAndFilterColumns(ReadResult & r
else
{
result.columns[prewhere_column_pos] = result.getFilterHolder()->convertToFullColumnIfConst();
if (getSampleBlock().getByName(prewhere->prewhere_column_name).type->isNullable())
if (getSampleBlock().getByName(prewhere_column_name).type->isNullable())
result.columns[prewhere_column_pos] = makeNullable(std::move(result.columns[prewhere_column_pos]));
result.clearFilter(); // Acting as a flag to not filter in PREWHERE
}

View File

@ -24,7 +24,11 @@ public:
MergeTreeRangeReader(
IMergeTreeReader * merge_tree_reader_,
MergeTreeRangeReader * prev_reader_,
const PrewhereInfoPtr & prewhere_,
ExpressionActionsPtr prewhere_alias_actions_,
ExpressionActionsPtr prewhere_actions_,
String prewhere_column_name_,
bool remove_prewhere_column_,
bool prewhere_need_filter_,
bool last_reader_in_chain_);
MergeTreeRangeReader() = default;
@ -217,7 +221,12 @@ private:
IMergeTreeReader * merge_tree_reader = nullptr;
const MergeTreeIndexGranularity * index_granularity = nullptr;
MergeTreeRangeReader * prev_reader = nullptr; /// If not nullptr, read from prev_reader first.
PrewhereInfoPtr prewhere;
ExpressionActionsPtr prewhere_alias_actions;
ExpressionActionsPtr prewhere_actions;
String prewhere_column_name;
bool remove_prewhere_column;
bool prewhere_need_filter;
Stream stream;

View File

@ -72,7 +72,7 @@ InputOrderInfoPtr ReadInOrderOptimizer::getInputOrder(const StoragePtr & storage
bool found_function = false;
for (const auto & action : elements_actions[i]->getActions())
{
if (action.type != ExpressionAction::APPLY_FUNCTION)
if (action.node->type != ActionsDAG::Type::FUNCTION)
continue;
if (found_function)
@ -83,13 +83,13 @@ InputOrderInfoPtr ReadInOrderOptimizer::getInputOrder(const StoragePtr & storage
else
found_function = true;
if (action.argument_names.size() != 1 || action.argument_names.at(0) != sorting_key_columns[i])
if (action.node->children.size() != 1 || action.node->children.at(0)->result_name != sorting_key_columns[i])
{
current_direction = 0;
break;
}
const auto & func = *action.function_base;
const auto & func = *action.node->function_base;
if (!func.hasInformationAboutMonotonicity())
{
current_direction = 0;

View File

@ -12,27 +12,30 @@ namespace DB
class ExpressionActions;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
class ActionsDAG;
using ActionsDAGPtr = std::shared_ptr<ActionsDAG>;
struct PrewhereInfo
{
/// Actions which are executed in order to alias the columns that are used by the prewhere actions.
ExpressionActionsPtr alias_actions;
ActionsDAGPtr alias_actions;
/// Actions which are executed on block in order to get filter column for prewhere step.
ExpressionActionsPtr prewhere_actions;
ActionsDAGPtr prewhere_actions;
/// Actions which are executed after reading from storage in order to remove unused columns.
ExpressionActionsPtr remove_columns_actions;
ActionsDAGPtr remove_columns_actions;
String prewhere_column_name;
bool remove_prewhere_column = false;
bool need_filter = false;
PrewhereInfo() = default;
explicit PrewhereInfo(ExpressionActionsPtr prewhere_actions_, String prewhere_column_name_)
explicit PrewhereInfo(ActionsDAGPtr prewhere_actions_, String prewhere_column_name_)
: prewhere_actions(std::move(prewhere_actions_)), prewhere_column_name(std::move(prewhere_column_name_)) {}
};
/// Helper struct to store all the information about the filter expression.
struct FilterInfo
{
ExpressionActionsPtr actions;
ActionsDAGPtr actions;
String column_name;
bool do_remove_column = false;
};

View File

@ -278,7 +278,7 @@ Pipe StorageBuffer::read(
pipe_from_buffers.addSimpleTransform([&](const Block & header)
{
return std::make_shared<FilterTransform>(
header, query_info.prewhere_info->prewhere_actions,
header, query_info.prewhere_info->prewhere_actions->buildExpressions(),
query_info.prewhere_info->prewhere_column_name, query_info.prewhere_info->remove_prewhere_column);
});
@ -286,7 +286,7 @@ Pipe StorageBuffer::read(
{
pipe_from_buffers.addSimpleTransform([&](const Block & header)
{
return std::make_shared<ExpressionTransform>(header, query_info.prewhere_info->alias_actions);
return std::make_shared<ExpressionTransform>(header, query_info.prewhere_info->alias_actions->buildExpressions());
});
}
}

View File

@ -201,9 +201,9 @@ bool isExpressionActionsDeterministics(const ExpressionActionsPtr & actions)
{
for (const auto & action : actions->getActions())
{
if (action.type != ExpressionAction::APPLY_FUNCTION)
if (action.node->type != ActionsDAG::Type::FUNCTION)
continue;
if (!action.function_base->isDeterministic())
if (!action.node->function_base->isDeterministic())
return false;
}
return true;

View File

@ -30,7 +30,7 @@ TTLAggregateDescription::TTLAggregateDescription(const TTLAggregateDescription &
, expression_result_column_name(other.expression_result_column_name)
{
if (other.expression)
expression = std::make_shared<ExpressionActions>(*other.expression);
expression = other.expression->clone();
}
TTLAggregateDescription & TTLAggregateDescription::operator=(const TTLAggregateDescription & other)
@ -41,7 +41,7 @@ TTLAggregateDescription & TTLAggregateDescription::operator=(const TTLAggregateD
column_name = other.column_name;
expression_result_column_name = other.expression_result_column_name;
if (other.expression)
expression = std::make_shared<ExpressionActions>(*other.expression);
expression = other.expression->clone();
else
expression.reset();
return *this;
@ -54,9 +54,9 @@ void checkTTLExpression(const ExpressionActionsPtr & ttl_expression, const Strin
{
for (const auto & action : ttl_expression->getActions())
{
if (action.type == ExpressionAction::APPLY_FUNCTION)
if (action.node->type == ActionsDAG::Type::FUNCTION)
{
IFunctionBase & func = *action.function_base;
IFunctionBase & func = *action.node->function_base;
if (!func.isDeterministic())
throw Exception(
"TTL expression cannot contain non-deterministic functions, "
@ -92,10 +92,10 @@ TTLDescription::TTLDescription(const TTLDescription & other)
, recompression_codec(other.recompression_codec)
{
if (other.expression)
expression = std::make_shared<ExpressionActions>(*other.expression);
expression = other.expression->clone();
if (other.where_expression)
where_expression = std::make_shared<ExpressionActions>(*other.where_expression);
where_expression = other.where_expression->clone();
}
TTLDescription & TTLDescription::operator=(const TTLDescription & other)
@ -110,13 +110,13 @@ TTLDescription & TTLDescription::operator=(const TTLDescription & other)
expression_ast.reset();
if (other.expression)
expression = std::make_shared<ExpressionActions>(*other.expression);
expression = other.expression->clone();
else
expression.reset();
result_column = other.result_column;
if (other.where_expression)
where_expression = std::make_shared<ExpressionActions>(*other.where_expression);
where_expression = other.where_expression->clone();
else
where_expression.reset();

View File

@ -1,9 +1,9 @@
1 (NULL,'') a
1 (NULL,'') \N
1 (NULL,'') b
\N (123,'Hello') a
\N (123,'Hello') \N
\N (123,'Hello') b
1 (NULL,'') \N
3 (456,NULL) a
3 (456,NULL) \N
3 (456,NULL) b
3 (456,NULL) \N
\N (123,'Hello') a
\N (123,'Hello') b
\N (123,'Hello') \N

View File

@ -1 +1 @@
SELECT x, y, arrayJoin(['a', NULL, 'b']) AS z FROM system.one ARRAY JOIN [1, NULL, 3] AS x, [(NULL, ''), (123, 'Hello'), (456, NULL)] AS y;
SELECT x, y, arrayJoin(['a', NULL, 'b']) AS z FROM system.one ARRAY JOIN [1, NULL, 3] AS x, [(NULL, ''), (123, 'Hello'), (456, NULL)] AS y order by x, y, z;

View File

@ -1,7 +1,7 @@
Expression (Projection)
Header: x UInt8
Expression (Before ORDER BY and SELECT)
Header: _dummy UInt8
Header: dummy UInt8
1 UInt8
ReadFromStorage (Read from SystemOne)
Header: dummy UInt8