Merge pull request #11950 from 4ertus2/joins

Extract JOIN in own plan step
This commit is contained in:
Nikolai Kochetov 2020-06-26 15:52:58 +03:00 committed by GitHub
commit 4f5be494c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 140 additions and 189 deletions

View File

@ -44,33 +44,4 @@ Block ExpressionBlockInputStream::readImpl()
return res; return res;
} }
/// Reads the next block and runs the expression over it, resuming a previously
/// unfinished execution when `not_processed` holds continuation state.
Block InflatingExpressionBlockInputStream::readImpl()
{
    /// First call only: stop right away if the expression reports an always-empty result.
    if (!initialized)
    {
        if (expression->resultIsAlwaysEmpty())
            return {};
        initialized = true;
    }

    const bool has_continuation = not_processed != nullptr;
    /// Continuation exists but carries no block: there's data inside expression.
    const bool data_inside_expression = has_continuation && not_processed->empty();

    /// A leftover block is pending: feed it to the expression again instead of reading the child.
    if (has_continuation && !data_inside_expression)
    {
        Block leftover = std::move(not_processed->block);
        expression->execute(leftover, not_processed, action_number);
        return leftover;
    }

    /// Otherwise drop the old continuation and pull a fresh block from the last child.
    not_processed.reset();
    Block fresh = children.back()->read();

    /// An empty block is still passed through when the expression holds data internally.
    if (fresh || data_inside_expression)
        expression->execute(fresh, not_processed, action_number);

    return fresh;
}
} }

View File

@ -323,8 +323,20 @@ void ExpressionAction::prepare(Block & sample_block, const Settings & settings,
} }
} }
void ExpressionAction::execute(Block & block, ExtraBlockPtr & not_processed) const
{
switch (type)
{
case JOIN:
join->joinBlock(block, not_processed);
break;
void ExpressionAction::execute(Block & block, bool dry_run, ExtraBlockPtr & not_processed) const default:
throw Exception("Unexpected expression call", ErrorCodes::LOGICAL_ERROR);
}
}
void ExpressionAction::execute(Block & block, bool dry_run) const
{ {
size_t input_rows_count = block.rows(); size_t input_rows_count = block.rows();
@ -362,10 +374,7 @@ void ExpressionAction::execute(Block & block, bool dry_run, ExtraBlockPtr & not_
} }
case JOIN: case JOIN:
{ throw Exception("Unexpected JOIN expression call", ErrorCodes::LOGICAL_ERROR);
join->joinBlock(block, not_processed);
break;
}
case PROJECT: case PROJECT:
{ {
@ -676,19 +685,13 @@ void ExpressionActions::execute(Block & block, bool dry_run) const
} }
} }
/// @warning It's a tricky method that allows to continue ONLY ONE action in reason of one-to-many ALL JOIN logic. void ExpressionActions::execute(Block & block, ExtraBlockPtr & not_processed) const
void ExpressionActions::execute(Block & block, ExtraBlockPtr & not_processed, size_t & start_action) const
{ {
size_t i = start_action; if (actions.size() != 1)
start_action = 0; throw Exception("Continuation over multiple expressions is not supported", ErrorCodes::LOGICAL_ERROR);
for (; i < actions.size(); ++i)
{
actions[i].execute(block, false, not_processed);
checkLimits(block);
if (not_processed) actions[0].execute(block, not_processed);
start_action = i; checkLimits(block);
}
} }
bool ExpressionActions::hasJoinOrArrayJoin() const bool ExpressionActions::hasJoinOrArrayJoin() const

View File

@ -139,13 +139,8 @@ private:
void executeOnTotals(Block & block) const; void executeOnTotals(Block & block) const;
/// Executes the action on the block (modifies it). The block could be split in case of JOIN. Then a not_processed block is created. /// Executes the action on the block (modifies it). The block could be split in case of JOIN. Then a not_processed block is created.
void execute(Block & block, bool dry_run, ExtraBlockPtr & not_processed) const; void execute(Block & block, ExtraBlockPtr & not_processed) const;
void execute(Block & block, bool dry_run) const;
void execute(Block & block, bool dry_run) const
{
ExtraBlockPtr extra;
execute(block, dry_run, extra);
}
}; };
@ -211,8 +206,8 @@ public:
/// Execute the expression on the block. The block must contain all the columns returned by getRequiredColumns. /// Execute the expression on the block. The block must contain all the columns returned by getRequiredColumns.
void execute(Block & block, bool dry_run = false) const; void execute(Block & block, bool dry_run = false) const;
/// Execute the expression on the block with continuation. /// Execute the expression on the block with continuation. This method is only supported for a single JOIN.
void execute(Block & block, ExtraBlockPtr & not_processed, size_t & start_action) const; void execute(Block & block, ExtraBlockPtr & not_processed) const;
bool hasJoinOrArrayJoin() const; bool hasJoinOrArrayJoin() const;
@ -325,10 +320,14 @@ struct ExpressionActionsChain
steps.clear(); steps.clear();
} }
ExpressionActionsPtr getLastActions() ExpressionActionsPtr getLastActions(bool allow_empty = false)
{ {
if (steps.empty()) if (steps.empty())
{
if (allow_empty)
return {};
throw Exception("Empty ExpressionActionsChain", ErrorCodes::LOGICAL_ERROR); throw Exception("Empty ExpressionActionsChain", ErrorCodes::LOGICAL_ERROR);
}
return steps.back().actions; return steps.back().actions;
} }
@ -341,6 +340,13 @@ struct ExpressionActionsChain
return steps.back(); return steps.back();
} }
/// Returns the last step of the chain. When the chain has no steps yet,
/// an initial step built over `columns` is appended first and returned.
Step & lastStep(const NamesAndTypesList & columns)
{
    if (!steps.empty())
        return steps.back();

    steps.emplace_back(std::make_shared<ExpressionActions>(columns, context));
    return steps.back();
}
std::string dumpChain(); std::string dumpChain();
}; };

View File

@ -453,14 +453,6 @@ const ASTSelectQuery * SelectQueryExpressionAnalyzer::getAggregatingQuery() cons
return getSelectQuery(); return getSelectQuery();
} }
/// Makes sure the chain has at least one step: an empty chain gets an initial
/// step built over `columns`; a non-empty chain is left as is.
void ExpressionAnalyzer::initChain(ExpressionActionsChain & chain, const NamesAndTypesList & columns) const
{
    auto & chain_steps = chain.steps;
    if (!chain_steps.empty())
        return;

    chain_steps.emplace_back(std::make_shared<ExpressionActions>(columns, context));
}
/// "Big" ARRAY JOIN. /// "Big" ARRAY JOIN.
void ExpressionAnalyzer::addMultipleArrayJoinAction(ExpressionActionsPtr & actions, bool array_join_is_left) const void ExpressionAnalyzer::addMultipleArrayJoinAction(ExpressionActionsPtr & actions, bool array_join_is_left) const
{ {
@ -487,8 +479,7 @@ bool SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & cha
if (!array_join_expression_list) if (!array_join_expression_list)
return false; return false;
initChain(chain, sourceColumns()); ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns());
ExpressionActionsChain::Step & step = chain.steps.back();
getRootActions(array_join_expression_list, only_types, step.actions); getRootActions(array_join_expression_list, only_types, step.actions);
@ -502,18 +493,20 @@ void ExpressionAnalyzer::addJoinAction(ExpressionActionsPtr & actions, JoinPtr j
actions->add(ExpressionAction::ordinaryJoin(syntax->analyzed_join, join)); actions->add(ExpressionAction::ordinaryJoin(syntax->analyzed_join, join));
} }
bool SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_types) bool SelectQueryExpressionAnalyzer::appendJoinLeftKeys(ExpressionActionsChain & chain, bool only_types)
{ {
const ASTTablesInSelectQueryElement * ast_join = getSelectQuery()->join(); ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns());
if (!ast_join)
return false;
JoinPtr table_join = makeTableJoin(*ast_join);
initChain(chain, sourceColumns());
ExpressionActionsChain::Step & step = chain.steps.back();
getRootActions(analyzedJoin().leftKeysList(), only_types, step.actions); getRootActions(analyzedJoin().leftKeysList(), only_types, step.actions);
return true;
}
bool SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain)
{
JoinPtr table_join = makeTableJoin(*syntax->ast_join);
ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns());
addJoinAction(step.actions, table_join); addJoinAction(step.actions, table_join);
return true; return true;
} }
@ -637,8 +630,7 @@ bool SelectQueryExpressionAnalyzer::appendPrewhere(
if (!select_query->prewhere()) if (!select_query->prewhere())
return false; return false;
initChain(chain, sourceColumns()); auto & step = chain.lastStep(sourceColumns());
auto & step = chain.getLastStep();
getRootActions(select_query->prewhere(), only_types, step.actions); getRootActions(select_query->prewhere(), only_types, step.actions);
String prewhere_column_name = select_query->prewhere()->getColumnName(); String prewhere_column_name = select_query->prewhere()->getColumnName();
step.required_output.push_back(prewhere_column_name); step.required_output.push_back(prewhere_column_name);
@ -705,8 +697,7 @@ bool SelectQueryExpressionAnalyzer::appendPrewhere(
void SelectQueryExpressionAnalyzer::appendPreliminaryFilter(ExpressionActionsChain & chain, ExpressionActionsPtr actions, String column_name) void SelectQueryExpressionAnalyzer::appendPreliminaryFilter(ExpressionActionsChain & chain, ExpressionActionsPtr actions, String column_name)
{ {
initChain(chain, sourceColumns()); ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns());
ExpressionActionsChain::Step & step = chain.steps.back();
// FIXME: assert(filter_info); // FIXME: assert(filter_info);
step.actions = std::move(actions); step.actions = std::move(actions);
@ -723,8 +714,7 @@ bool SelectQueryExpressionAnalyzer::appendWhere(ExpressionActionsChain & chain,
if (!select_query->where()) if (!select_query->where())
return false; return false;
initChain(chain, sourceColumns()); ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns());
ExpressionActionsChain::Step & step = chain.steps.back();
step.required_output.push_back(select_query->where()->getColumnName()); step.required_output.push_back(select_query->where()->getColumnName());
step.can_remove_required_output = {true}; step.can_remove_required_output = {true};
@ -742,8 +732,7 @@ bool SelectQueryExpressionAnalyzer::appendGroupBy(ExpressionActionsChain & chain
if (!select_query->groupBy()) if (!select_query->groupBy())
return false; return false;
initChain(chain, sourceColumns()); ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns());
ExpressionActionsChain::Step & step = chain.steps.back();
ASTs asts = select_query->groupBy()->children; ASTs asts = select_query->groupBy()->children;
for (const auto & ast : asts) for (const auto & ast : asts)
@ -769,8 +758,7 @@ void SelectQueryExpressionAnalyzer::appendAggregateFunctionsArguments(Expression
{ {
const auto * select_query = getAggregatingQuery(); const auto * select_query = getAggregatingQuery();
initChain(chain, sourceColumns()); ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns());
ExpressionActionsChain::Step & step = chain.steps.back();
for (const auto & desc : aggregate_descriptions) for (const auto & desc : aggregate_descriptions)
for (const auto & name : desc.argument_names) for (const auto & name : desc.argument_names)
@ -801,8 +789,7 @@ bool SelectQueryExpressionAnalyzer::appendHaving(ExpressionActionsChain & chain,
if (!select_query->having()) if (!select_query->having())
return false; return false;
initChain(chain, aggregated_columns); ExpressionActionsChain::Step & step = chain.lastStep(aggregated_columns);
ExpressionActionsChain::Step & step = chain.steps.back();
step.required_output.push_back(select_query->having()->getColumnName()); step.required_output.push_back(select_query->having()->getColumnName());
getRootActions(select_query->having(), only_types, step.actions); getRootActions(select_query->having(), only_types, step.actions);
@ -814,8 +801,7 @@ void SelectQueryExpressionAnalyzer::appendSelect(ExpressionActionsChain & chain,
{ {
const auto * select_query = getSelectQuery(); const auto * select_query = getSelectQuery();
initChain(chain, aggregated_columns); ExpressionActionsChain::Step & step = chain.lastStep(aggregated_columns);
ExpressionActionsChain::Step & step = chain.steps.back();
getRootActions(select_query->select(), only_types, step.actions); getRootActions(select_query->select(), only_types, step.actions);
@ -831,8 +817,7 @@ bool SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain
if (!select_query->orderBy()) if (!select_query->orderBy())
return false; return false;
initChain(chain, aggregated_columns); ExpressionActionsChain::Step & step = chain.lastStep(aggregated_columns);
ExpressionActionsChain::Step & step = chain.steps.back();
getRootActions(select_query->orderBy(), only_types, step.actions); getRootActions(select_query->orderBy(), only_types, step.actions);
@ -864,8 +849,7 @@ bool SelectQueryExpressionAnalyzer::appendLimitBy(ExpressionActionsChain & chain
if (!select_query->limitBy()) if (!select_query->limitBy())
return false; return false;
initChain(chain, aggregated_columns); ExpressionActionsChain::Step & step = chain.lastStep(aggregated_columns);
ExpressionActionsChain::Step & step = chain.steps.back();
getRootActions(select_query->limitBy(), only_types, step.actions); getRootActions(select_query->limitBy(), only_types, step.actions);
@ -890,8 +874,7 @@ void SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActionsChain &
{ {
const auto * select_query = getSelectQuery(); const auto * select_query = getSelectQuery();
initChain(chain, aggregated_columns); ExpressionActionsChain::Step & step = chain.lastStep(aggregated_columns);
ExpressionActionsChain::Step & step = chain.steps.back();
NamesWithAliases result_columns; NamesWithAliases result_columns;
@ -939,8 +922,7 @@ void SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActionsChain &
void ExpressionAnalyzer::appendExpression(ExpressionActionsChain & chain, const ASTPtr & expr, bool only_types) void ExpressionAnalyzer::appendExpression(ExpressionActionsChain & chain, const ASTPtr & expr, bool only_types)
{ {
initChain(chain, sourceColumns()); ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns());
ExpressionActionsChain::Step & step = chain.steps.back();
getRootActions(expr, only_types, step.actions); getRootActions(expr, only_types, step.actions);
step.required_output.push_back(expr->getColumnName()); step.required_output.push_back(expr->getColumnName());
} }
@ -1101,10 +1083,18 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
query_analyzer.appendArrayJoin(chain, only_types || !first_stage); query_analyzer.appendArrayJoin(chain, only_types || !first_stage);
if (query_analyzer.appendJoin(chain, only_types || !first_stage)) if (query_analyzer.hasTableJoin())
{ {
before_join = chain.getLastActions(); query_analyzer.appendJoinLeftKeys(chain, only_types || !first_stage);
if (!hasJoin())
before_join = chain.getLastActions(true);
if (before_join)
chain.addStep();
query_analyzer.appendJoin(chain);
join = chain.getLastActions();
if (!join)
throw Exception("No expected JOIN", ErrorCodes::LOGICAL_ERROR); throw Exception("No expected JOIN", ErrorCodes::LOGICAL_ERROR);
chain.addStep(); chain.addStep();
} }
@ -1153,11 +1143,11 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
} }
bool join_allow_read_in_order = true; bool join_allow_read_in_order = true;
if (before_join) if (hasJoin())
{ {
/// You may find it strange but we support read_in_order for HashJoin and do not support for MergeJoin. /// You may find it strange but we support read_in_order for HashJoin and do not support for MergeJoin.
auto join = before_join->getTableJoinAlgo(); auto join_algo = join->getTableJoinAlgo();
join_allow_read_in_order = typeid_cast<HashJoin *>(join.get()) && !join->hasStreamWithNonJoinedRows(); join_allow_read_in_order = typeid_cast<HashJoin *>(join_algo.get()) && !join_algo->hasStreamWithNonJoinedRows();
} }
optimize_read_in_order = optimize_read_in_order =

View File

@ -153,9 +153,6 @@ protected:
void analyzeAggregation(); void analyzeAggregation();
bool makeAggregateDescriptions(ExpressionActionsPtr & actions); bool makeAggregateDescriptions(ExpressionActionsPtr & actions);
/// columns - the columns that are present before the transformations begin.
void initChain(ExpressionActionsChain & chain, const NamesAndTypesList & columns) const;
const ASTSelectQuery * getSelectQuery() const; const ASTSelectQuery * getSelectQuery() const;
bool isRemoteStorage() const; bool isRemoteStorage() const;
@ -178,7 +175,8 @@ struct ExpressionAnalysisResult
bool optimize_read_in_order = false; bool optimize_read_in_order = false;
bool optimize_aggregation_in_order = false; bool optimize_aggregation_in_order = false;
ExpressionActionsPtr before_join; /// including JOIN ExpressionActionsPtr before_join;
ExpressionActionsPtr join;
ExpressionActionsPtr before_where; ExpressionActionsPtr before_where;
ExpressionActionsPtr before_aggregation; ExpressionActionsPtr before_aggregation;
ExpressionActionsPtr before_having; ExpressionActionsPtr before_having;
@ -214,7 +212,7 @@ struct ExpressionAnalysisResult
/// Filter for row-level security. /// Filter for row-level security.
bool hasFilter() const { return filter_info.get(); } bool hasFilter() const { return filter_info.get(); }
bool hasJoin() const { return before_join.get(); } bool hasJoin() const { return join.get(); }
bool hasPrewhere() const { return prewhere_info.get(); } bool hasPrewhere() const { return prewhere_info.get(); }
bool hasWhere() const { return before_where.get(); } bool hasWhere() const { return before_where.get(); }
bool hasHaving() const { return before_having.get(); } bool hasHaving() const { return before_having.get(); }
@ -249,6 +247,7 @@ public:
/// Does the expression have aggregate functions or a GROUP BY or HAVING section. /// Does the expression have aggregate functions or a GROUP BY or HAVING section.
bool hasAggregation() const { return has_aggregation; } bool hasAggregation() const { return has_aggregation; }
bool hasGlobalSubqueries() { return has_global_subqueries; } bool hasGlobalSubqueries() { return has_global_subqueries; }
bool hasTableJoin() const { return syntax->ast_join; }
const NamesAndTypesList & aggregationKeys() const { return aggregation_keys; } const NamesAndTypesList & aggregationKeys() const { return aggregation_keys; }
const AggregateDescriptions & aggregates() const { return aggregate_descriptions; } const AggregateDescriptions & aggregates() const { return aggregate_descriptions; }
@ -307,7 +306,8 @@ private:
/// Before aggregation: /// Before aggregation:
bool appendArrayJoin(ExpressionActionsChain & chain, bool only_types); bool appendArrayJoin(ExpressionActionsChain & chain, bool only_types);
bool appendJoin(ExpressionActionsChain & chain, bool only_types); bool appendJoinLeftKeys(ExpressionActionsChain & chain, bool only_types);
bool appendJoin(ExpressionActionsChain & chain);
/// Add preliminary rows filtration. Actions are created in other expression analyzer to prevent any possible alias injection. /// Add preliminary rows filtration. Actions are created in other expression analyzer to prevent any possible alias injection.
void appendPreliminaryFilter(ExpressionActionsChain & chain, ExpressionActionsPtr actions, String column_name); void appendPreliminaryFilter(ExpressionActionsChain & chain, ExpressionActionsPtr actions, String column_name);
/// remove_filter is set in ExpressionActionsChain::finalize(); /// remove_filter is set in ExpressionActionsChain::finalize();

View File

@ -32,6 +32,7 @@
#include <Processors/Pipe.h> #include <Processors/Pipe.h>
#include <Processors/Sources/SourceFromInputStream.h> #include <Processors/Sources/SourceFromInputStream.h>
#include <Processors/Transforms/ExpressionTransform.h> #include <Processors/Transforms/ExpressionTransform.h>
#include <Processors/Transforms/InflatingExpressionTransform.h>
#include <Processors/Transforms/AggregatingTransform.h> #include <Processors/Transforms/AggregatingTransform.h>
#include <Processors/QueryPlan/ReadFromStorageStep.h> #include <Processors/QueryPlan/ReadFromStorageStep.h>
#include <Processors/QueryPlan/ExpressionStep.h> #include <Processors/QueryPlan/ExpressionStep.h>
@ -858,52 +859,38 @@ void InterpreterSelectQuery::executeImpl(QueryPlan & query_plan, const BlockInpu
query_plan.addStep(std::move(row_level_security_step)); query_plan.addStep(std::move(row_level_security_step));
} }
if (expressions.before_join)
{
QueryPlanStepPtr before_join_step = std::make_unique<ExpressionStep>(
query_plan.getCurrentDataStream(),
expressions.before_join);
before_join_step->setStepDescription("Before JOIN");
query_plan.addStep(std::move(before_join_step));
}
if (expressions.hasJoin()) if (expressions.hasJoin())
{ {
Block join_result_sample; Block join_result_sample;
JoinPtr join = expressions.before_join->getTableJoinAlgo(); JoinPtr join = expressions.join->getTableJoinAlgo();
join_result_sample = ExpressionTransform::transformHeader(query_plan.getCurrentDataStream().header, expressions.before_join); join_result_sample = InflatingExpressionTransform::transformHeader(
query_plan.getCurrentDataStream().header, expressions.join);
bool inflating_join = false; QueryPlanStepPtr join_step = std::make_unique<InflatingExpressionStep>(
if (join) query_plan.getCurrentDataStream(),
expressions.join);
join_step->setStepDescription("JOIN");
query_plan.addStep(std::move(join_step));
if (auto stream = join->createStreamWithNonJoinedRows(join_result_sample, settings.max_block_size))
{ {
inflating_join = true; auto source = std::make_shared<SourceFromInputStream>(std::move(stream));
if (auto * hash_join = typeid_cast<HashJoin *>(join.get())) auto add_non_joined_rows_step = std::make_unique<AddingDelayedStreamStep>(
inflating_join = isCross(hash_join->getKind()); query_plan.getCurrentDataStream(), std::move(source));
}
QueryPlanStepPtr before_join_step; add_non_joined_rows_step->setStepDescription("Add non-joined rows after JOIN");
if (inflating_join) query_plan.addStep(std::move(add_non_joined_rows_step));
{
before_join_step = std::make_unique<InflatingExpressionStep>(
query_plan.getCurrentDataStream(),
expressions.before_join,
true);
}
else
{
before_join_step = std::make_unique<ExpressionStep>(
query_plan.getCurrentDataStream(),
expressions.before_join,
true);
}
before_join_step->setStepDescription("JOIN");
query_plan.addStep(std::move(before_join_step));
if (join)
{
if (auto stream = join->createStreamWithNonJoinedRows(join_result_sample, settings.max_block_size))
{
auto source = std::make_shared<SourceFromInputStream>(std::move(stream));
auto add_non_joined_rows_step = std::make_unique<AddingDelayedStreamStep>(
query_plan.getCurrentDataStream(), std::move(source));
add_non_joined_rows_step->setStepDescription("Add non-joined rows after JOIN");
query_plan.addStep(std::move(add_non_joined_rows_step));
}
} }
} }

View File

@ -1016,6 +1016,7 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyzeSelect(
result.aggregates = getAggregates(query, *select_query); result.aggregates = getAggregates(query, *select_query);
result.collectUsedColumns(query, true); result.collectUsedColumns(query, true);
result.ast_join = select_query->join();
if (result.optimize_trivial_count) if (result.optimize_trivial_count)
result.optimize_trivial_count = settings.optimize_trivial_count_query && result.optimize_trivial_count = settings.optimize_trivial_count_query &&

View File

@ -11,6 +11,7 @@ namespace DB
{ {
class ASTFunction; class ASTFunction;
struct ASTTablesInSelectQueryElement;
class TableJoin; class TableJoin;
class Context; class Context;
struct Settings; struct Settings;
@ -24,6 +25,7 @@ struct SyntaxAnalyzerResult
ConstStoragePtr storage; ConstStoragePtr storage;
StorageMetadataPtr metadata_snapshot; StorageMetadataPtr metadata_snapshot;
std::shared_ptr<TableJoin> analyzed_join; std::shared_ptr<TableJoin> analyzed_join;
const ASTTablesInSelectQueryElement * ast_join = nullptr;
NamesAndTypesList source_columns; NamesAndTypesList source_columns;
NameSet source_columns_set; /// Set of names of source_columns. NameSet source_columns_set; /// Set of names of source_columns.

View File

@ -28,13 +28,12 @@ static void filterDistinctColumns(const Block & res_header, NameSet & distinct_c
distinct_columns.swap(new_distinct_columns); distinct_columns.swap(new_distinct_columns);
} }
ExpressionStep::ExpressionStep(const DataStream & input_stream_, ExpressionActionsPtr expression_, bool default_totals_) ExpressionStep::ExpressionStep(const DataStream & input_stream_, ExpressionActionsPtr expression_)
: ITransformingStep( : ITransformingStep(
input_stream_, input_stream_,
ExpressionTransform::transformHeader(input_stream_.header, expression_), Transform::transformHeader(input_stream_.header, expression_),
getTraits(expression_)) getTraits(expression_))
, expression(std::move(expression_)) , expression(std::move(expression_))
, default_totals(default_totals_)
{ {
/// Some columns may be removed by expression. /// Some columns may be removed by expression.
/// TODO: also check aliases, functions and some types of join /// TODO: also check aliases, functions and some types of join
@ -44,28 +43,19 @@ ExpressionStep::ExpressionStep(const DataStream & input_stream_, ExpressionActio
void ExpressionStep::transformPipeline(QueryPipeline & pipeline) void ExpressionStep::transformPipeline(QueryPipeline & pipeline)
{ {
/// In case joined subquery has totals, and we don't, add default chunk to totals.
bool add_default_totals = false;
if (default_totals && !pipeline.hasTotals())
{
pipeline.addDefaultTotals();
add_default_totals = true;
}
pipeline.addSimpleTransform([&](const Block & header, QueryPipeline::StreamType stream_type) pipeline.addSimpleTransform([&](const Block & header, QueryPipeline::StreamType stream_type)
{ {
bool on_totals = stream_type == QueryPipeline::StreamType::Totals; bool on_totals = stream_type == QueryPipeline::StreamType::Totals;
return std::make_shared<ExpressionTransform>(header, expression, on_totals, add_default_totals); return std::make_shared<Transform>(header, expression, on_totals);
}); });
} }
InflatingExpressionStep::InflatingExpressionStep(const DataStream & input_stream_, ExpressionActionsPtr expression_, bool default_totals_) InflatingExpressionStep::InflatingExpressionStep(const DataStream & input_stream_, ExpressionActionsPtr expression_)
: ITransformingStep( : ITransformingStep(
input_stream_, input_stream_,
ExpressionTransform::transformHeader(input_stream_.header, expression_), Transform::transformHeader(input_stream_.header, expression_),
getTraits(expression_)) getTraits(expression_))
, expression(std::move(expression_)) , expression(std::move(expression_))
, default_totals(default_totals_)
{ {
filterDistinctColumns(output_stream->header, output_stream->distinct_columns); filterDistinctColumns(output_stream->header, output_stream->distinct_columns);
filterDistinctColumns(output_stream->header, output_stream->local_distinct_columns); filterDistinctColumns(output_stream->header, output_stream->local_distinct_columns);
@ -75,7 +65,7 @@ void InflatingExpressionStep::transformPipeline(QueryPipeline & pipeline)
{ {
/// In case joined subquery has totals, and we don't, add default chunk to totals. /// In case joined subquery has totals, and we don't, add default chunk to totals.
bool add_default_totals = false; bool add_default_totals = false;
if (default_totals && !pipeline.hasTotals()) if (!pipeline.hasTotals())
{ {
pipeline.addDefaultTotals(); pipeline.addDefaultTotals();
add_default_totals = true; add_default_totals = true;
@ -84,7 +74,7 @@ void InflatingExpressionStep::transformPipeline(QueryPipeline & pipeline)
pipeline.addSimpleTransform([&](const Block & header, QueryPipeline::StreamType stream_type) pipeline.addSimpleTransform([&](const Block & header, QueryPipeline::StreamType stream_type)
{ {
bool on_totals = stream_type == QueryPipeline::StreamType::Totals; bool on_totals = stream_type == QueryPipeline::StreamType::Totals;
return std::make_shared<InflatingExpressionTransform>(header, expression, on_totals, add_default_totals); return std::make_shared<Transform>(header, expression, on_totals, add_default_totals);
}); });
} }

View File

@ -7,31 +7,36 @@ namespace DB
class ExpressionActions; class ExpressionActions;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>; using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
class ExpressionTransform;
class InflatingExpressionTransform;
class ExpressionStep : public ITransformingStep class ExpressionStep : public ITransformingStep
{ {
public: public:
explicit ExpressionStep(const DataStream & input_stream_, ExpressionActionsPtr expression_, bool default_totals_ = false); using Transform = ExpressionTransform;
explicit ExpressionStep(const DataStream & input_stream_, ExpressionActionsPtr expression_);
String getName() const override { return "Expression"; } String getName() const override { return "Expression"; }
void transformPipeline(QueryPipeline & pipeline) override; void transformPipeline(QueryPipeline & pipeline) override;
private: private:
ExpressionActionsPtr expression; ExpressionActionsPtr expression;
bool default_totals; /// See ExpressionTransform
}; };
/// TODO: add separate step for join. /// TODO: add separate step for join.
class InflatingExpressionStep : public ITransformingStep class InflatingExpressionStep : public ITransformingStep
{ {
public: public:
explicit InflatingExpressionStep(const DataStream & input_stream_, ExpressionActionsPtr expression_, bool default_totals_ = false); using Transform = InflatingExpressionTransform;
explicit InflatingExpressionStep(const DataStream & input_stream_, ExpressionActionsPtr expression_);
String getName() const override { return "Expression"; } String getName() const override { return "Expression"; }
void transformPipeline(QueryPipeline & pipeline) override; void transformPipeline(QueryPipeline & pipeline) override;
private: private:
ExpressionActionsPtr expression; ExpressionActionsPtr expression;
bool default_totals; /// See ExpressionTransform
}; };
} }

View File

@ -12,11 +12,10 @@ Block ExpressionTransform::transformHeader(Block header, const ExpressionActions
} }
ExpressionTransform::ExpressionTransform(const Block & header_, ExpressionActionsPtr expression_, bool on_totals_, bool default_totals_) ExpressionTransform::ExpressionTransform(const Block & header_, ExpressionActionsPtr expression_, bool on_totals_)
: ISimpleTransform(header_, transformHeader(header_, expression_), on_totals_) : ISimpleTransform(header_, transformHeader(header_, expression_), on_totals_)
, expression(std::move(expression_)) , expression(std::move(expression_))
, on_totals(on_totals_) , on_totals(on_totals_)
, default_totals(default_totals_)
{ {
} }
@ -37,14 +36,7 @@ void ExpressionTransform::transform(Chunk & chunk)
auto block = getInputPort().getHeader().cloneWithColumns(chunk.detachColumns()); auto block = getInputPort().getHeader().cloneWithColumns(chunk.detachColumns());
if (on_totals) if (on_totals)
{
/// Drop totals if both out stream and joined stream doesn't have ones.
/// See comment in ExpressionTransform.h
if (default_totals && !expression->hasTotalsInJoin())
return;
expression->executeOnTotals(block); expression->executeOnTotals(block);
}
else else
expression->execute(block); expression->execute(block);

View File

@ -13,8 +13,7 @@ public:
ExpressionTransform( ExpressionTransform(
const Block & header_, const Block & header_,
ExpressionActionsPtr expression_, ExpressionActionsPtr expression_,
bool on_totals_ = false, bool on_totals_ = false);
bool default_totals_ = false);
String getName() const override { return "ExpressionTransform"; } String getName() const override { return "ExpressionTransform"; }
@ -26,10 +25,6 @@ protected:
private: private:
ExpressionActionsPtr expression; ExpressionActionsPtr expression;
bool on_totals; bool on_totals;
/// This flag means that we have manually added totals to our pipeline.
/// It may happen if the joined subquery has totals, but the out stream doesn't.
/// We need to join default values with subquery totals if we have them, or return an empty chunk if we haven't.
bool default_totals;
bool initialized = false; bool initialized = false;
}; };

View File

@ -5,9 +5,10 @@
namespace DB namespace DB
{ {
static Block transformHeader(Block header, const ExpressionActionsPtr & expression) Block InflatingExpressionTransform::transformHeader(Block header, const ExpressionActionsPtr & expression)
{ {
expression->execute(header, true); ExtraBlockPtr tmp;
expression->execute(header, tmp);
return header; return header;
} }
@ -38,8 +39,12 @@ void InflatingExpressionTransform::transform(Chunk & chunk)
{ {
/// We have to make chunk empty before return /// We have to make chunk empty before return
block = getInputPort().getHeader().cloneWithColumns(chunk.detachColumns()); block = getInputPort().getHeader().cloneWithColumns(chunk.detachColumns());
/// Drop totals if neither the out stream nor the joined stream has them.
/// See the comment on default_totals in InflatingExpressionTransform.h
if (default_totals && !expression->hasTotalsInJoin()) if (default_totals && !expression->hasTotalsInJoin())
return; return;
expression->executeOnTotals(block); expression->executeOnTotals(block);
} }
else else
@ -59,7 +64,7 @@ Block InflatingExpressionTransform::readExecute(Chunk & chunk)
res = getInputPort().getHeader().cloneWithColumns(chunk.detachColumns()); res = getInputPort().getHeader().cloneWithColumns(chunk.detachColumns());
if (res) if (res)
expression->execute(res, not_processed, action_number); expression->execute(res, not_processed);
} }
else if (not_processed->empty()) /// There's not processed data inside expression. else if (not_processed->empty()) /// There's not processed data inside expression.
{ {
@ -67,12 +72,12 @@ Block InflatingExpressionTransform::readExecute(Chunk & chunk)
res = getInputPort().getHeader().cloneWithColumns(chunk.detachColumns()); res = getInputPort().getHeader().cloneWithColumns(chunk.detachColumns());
not_processed.reset(); not_processed.reset();
expression->execute(res, not_processed, action_number); expression->execute(res, not_processed);
} }
else else
{ {
res = std::move(not_processed->block); res = std::move(not_processed->block);
expression->execute(res, not_processed, action_number); expression->execute(res, not_processed);
} }
return res; return res;
} }

View File

@ -16,6 +16,8 @@ public:
String getName() const override { return "InflatingExpressionTransform"; } String getName() const override { return "InflatingExpressionTransform"; }
static Block transformHeader(Block header, const ExpressionActionsPtr & expression);
protected: protected:
void transform(Chunk & chunk) override; void transform(Chunk & chunk) override;
bool needInputData() const override { return !not_processed; } bool needInputData() const override { return !not_processed; }
@ -23,11 +25,13 @@ protected:
private: private:
ExpressionActionsPtr expression; ExpressionActionsPtr expression;
bool on_totals; bool on_totals;
/// This flag means that we have manually added totals to our pipeline.
/// It may happen if the joined subquery has totals, but the out stream doesn't.
/// We need to join default values with subquery totals if we have them, or return an empty chunk if we haven't.
bool default_totals; bool default_totals;
bool initialized = false; bool initialized = false;
ExtraBlockPtr not_processed; ExtraBlockPtr not_processed;
size_t action_number = 0;
Block readExecute(Chunk & chunk); Block readExecute(Chunk & chunk);
}; };