diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.cpp b/dbms/src/Interpreters/ExpressionAnalyzer.cpp index 682f8a7a3ea..4fad806bef0 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.cpp +++ b/dbms/src/Interpreters/ExpressionAnalyzer.cpp @@ -2668,13 +2668,17 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty return true; } -bool ExpressionAnalyzer::appendPrewhere(ExpressionActionsChain & chain, bool only_types) +bool ExpressionAnalyzer::appendPrewhere(ExpressionActionsChain & chain, bool only_types, const ASTPtr & sampling_expression) { assertSelect(); if (!select_query->prewhere_expression) return false; + Names required_sample_columns; + if (sampling_expression) + required_sample_columns = ExpressionAnalyzer(sampling_expression, context, nullptr, source_columns).getRequiredSourceColumns(); + initChain(chain, source_columns); auto & step = chain.getLastStep(); getRootActions(select_query->prewhere_expression, only_types, false, step.actions); @@ -2682,6 +2686,15 @@ bool ExpressionAnalyzer::appendPrewhere(ExpressionActionsChain & chain, bool onl step.required_output.push_back(prewhere_column_name); step.can_remove_required_output.push_back(true); + /// Add required columns for sample expression to required output in order not to remove them after + /// prewhere execution because sampling is executed after prewhere. + /// TODO: add sampling execution to common chain. + for (const auto & column : required_sample_columns) + { + step.required_output.push_back(column); + step.can_remove_required_output.push_back(true); + } + { /// Remove unused source_columns from prewhere actions. auto tmp_actions = std::make_shared(source_columns, context); diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.h b/dbms/src/Interpreters/ExpressionAnalyzer.h index 5e01f049c5c..8b11a8225a2 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.h +++ b/dbms/src/Interpreters/ExpressionAnalyzer.h @@ -142,7 +142,8 @@ public: bool appendArrayJoin(ExpressionActionsChain & chain, bool only_types); bool appendJoin(ExpressionActionsChain & chain, bool only_types); /// remove_filter is set in ExpressionActionsChain::finalize(); - bool appendPrewhere(ExpressionActionsChain & chain, bool only_types); + /// sampling_expression is needed if sampling is used in order to not remove columns are used in it. + bool appendPrewhere(ExpressionActionsChain & chain, bool only_types, const ASTPtr & sampling_expression); bool appendWhere(ExpressionActionsChain & chain, bool only_types); bool appendGroupBy(ExpressionActionsChain & chain, bool only_types); void appendAggregateFunctionsArguments(ExpressionActionsChain & chain, bool only_types); diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.cpp b/dbms/src/Interpreters/InterpreterSelectQuery.cpp index 01d3c28bedf..f2a14921557 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectQuery.cpp @@ -47,6 +47,7 @@ #include #include #include +#include namespace DB @@ -63,6 +64,7 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; extern const int NOT_IMPLEMENTED; extern const int PARAMETER_OUT_OF_BOUND; + extern const int ARGUMENT_OUT_OF_BOUND; } InterpreterSelectQuery::InterpreterSelectQuery( @@ -279,7 +281,6 @@ BlockInputStreams InterpreterSelectQuery::executeWithMultipleStreams() return pipeline.streams; } - InterpreterSelectQuery::AnalysisResult InterpreterSelectQuery::analyzeExpressions(QueryProcessingStage::Enum from_stage, bool dry_run) { AnalysisResult res; @@ -305,7 +306,27 @@ InterpreterSelectQuery::AnalysisResult InterpreterSelectQuery::analyzeExpression chain.finalize(); if (has_prewhere) - res.prewhere_info->remove_prewhere_column = chain.steps.at(0).can_remove_required_output.at(0); + { + const ExpressionActionsChain::Step & step = chain.steps.at(0); + res.prewhere_info->remove_prewhere_column = step.can_remove_required_output.at(0); + + Names columns_to_remove_after_sampling; + for (size_t i = 1; i < step.required_output.size(); ++i) + { + if (step.can_remove_required_output[i]) + columns_to_remove_after_sampling.push_back(step.required_output[i]); + } + + if (!columns_to_remove_after_sampling.empty()) + { + auto columns = res.prewhere_info->prewhere_actions->getSampleBlock().getNamesAndTypesList(); + ExpressionActionsPtr actions = std::make_shared(columns, context); + for (const auto & column : columns_to_remove_after_sampling) + actions->add(ExpressionAction::removeColumn(column)); + + res.prewhere_info->after_sampling_actions = std::move(actions); + } + } if (has_where) res.remove_where_filter = chain.steps.at(where_step_num).can_remove_required_output.at(0); @@ -317,7 +338,8 @@ InterpreterSelectQuery::AnalysisResult InterpreterSelectQuery::analyzeExpression { ExpressionActionsChain chain(context); - if (query_analyzer->appendPrewhere(chain, !res.first_stage)) + ASTPtr sampling_expression = storage ? storage->getSamplingExpression() : nullptr; + if (query_analyzer->appendPrewhere(chain, !res.first_stage, sampling_expression)) { has_prewhere = true; diff --git a/dbms/src/Storages/IStorage.h b/dbms/src/Storages/IStorage.h index cbf69a18a77..b3c5075bc94 100644 --- a/dbms/src/Storages/IStorage.h +++ b/dbms/src/Storages/IStorage.h @@ -343,6 +343,9 @@ public: /// Returns data path if storage supports it, empty string otherwise. virtual String getDataPath() const { return {}; } + /// Returns sampling expression for storage or nullptr if there is no. + virtual ASTPtr getSamplingExpression() const { return nullptr; } + protected: using ITableDeclaration::ITableDeclaration; using std::enable_shared_from_this::shared_from_this; diff --git a/dbms/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp b/dbms/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp index b39f36807ad..19b3778fa31 100644 --- a/dbms/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp +++ b/dbms/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp @@ -593,6 +593,10 @@ BlockInputStreams MergeTreeDataSelectExecutor::readFromParts( stream = std::make_shared>( stream, std::make_shared(), used_sample_factor, "_sample_factor"); + if (query_info.prewhere_info && query_info.prewhere_info->after_sampling_actions) + for (auto & stream : res) + stream = std::make_shared(stream, query_info.prewhere_info->after_sampling_actions); + return res; } diff --git a/dbms/src/Storages/SelectQueryInfo.h b/dbms/src/Storages/SelectQueryInfo.h index a6e40e4c27d..d875e0cc7ee 100644 --- a/dbms/src/Storages/SelectQueryInfo.h +++ b/dbms/src/Storages/SelectQueryInfo.h @@ -21,10 +21,12 @@ using PreparedSets = std::unordered_map