Merge pull request #60313 from ClickHouse/analyzer-refactor-constant-name

Analyzer: Refactor execution name for ConstantNode
This commit is contained in:
Dmitry Novik 2024-03-08 12:08:05 +01:00 committed by GitHub
commit 526af77f4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 215 additions and 65 deletions

View File

@ -1,5 +1,7 @@
#include <Analyzer/ConstantNode.h>
#include <Analyzer/FunctionNode.h>
#include <Common/assert_cast.h>
#include <Common/FieldVisitorToString.h>
#include <Common/SipHash.h>
@ -38,6 +40,70 @@ ConstantNode::ConstantNode(Field value_)
: ConstantNode(value_, applyVisitor(FieldToDataType(), value_))
{}
bool ConstantNode::requiresCastCall() const
{
const auto & constant_value_literal = constant_value->getValue();
bool need_to_add_cast_function = false;
auto constant_value_literal_type = constant_value_literal.getType();
WhichDataType constant_value_type(constant_value->getType());
switch (constant_value_literal_type)
{
case Field::Types::String:
{
need_to_add_cast_function = !constant_value_type.isString();
break;
}
case Field::Types::UInt64:
case Field::Types::Int64:
case Field::Types::Float64:
{
WhichDataType constant_value_field_type(applyVisitor(FieldToDataType(), constant_value_literal));
need_to_add_cast_function = constant_value_field_type.idx != constant_value_type.idx;
break;
}
case Field::Types::Int128:
case Field::Types::UInt128:
case Field::Types::Int256:
case Field::Types::UInt256:
case Field::Types::Decimal32:
case Field::Types::Decimal64:
case Field::Types::Decimal128:
case Field::Types::Decimal256:
case Field::Types::AggregateFunctionState:
case Field::Types::Array:
case Field::Types::Tuple:
case Field::Types::Map:
case Field::Types::UUID:
case Field::Types::Bool:
case Field::Types::Object:
case Field::Types::IPv4:
case Field::Types::IPv6:
case Field::Types::Null:
case Field::Types::CustomType:
{
need_to_add_cast_function = true;
break;
}
}
// Add cast if constant was created as a result of constant folding.
// Constant folding may lead to type transformation and literal on shard
// may have a different type.
return need_to_add_cast_function || source_expression != nullptr;
}
bool ConstantNode::receivedFromInitiatorServer() const
{
if (!hasSourceExpression())
return false;
auto * cast_function = getSourceExpression()->as<FunctionNode>();
if (!cast_function || cast_function->getFunctionName() != "_CAST")
return false;
return true;
}
void ConstantNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
{
buffer << std::string(indent, ' ') << "CONSTANT id: " << format_state.getNodeId(this);
@ -89,54 +155,7 @@ ASTPtr ConstantNode::toASTImpl(const ConvertToASTOptions & options) const
if (!options.add_cast_for_constants)
return constant_value_ast;
bool need_to_add_cast_function = false;
auto constant_value_literal_type = constant_value_literal.getType();
WhichDataType constant_value_type(constant_value->getType());
switch (constant_value_literal_type)
{
case Field::Types::String:
{
need_to_add_cast_function = !constant_value_type.isString();
break;
}
case Field::Types::UInt64:
case Field::Types::Int64:
case Field::Types::Float64:
{
WhichDataType constant_value_field_type(applyVisitor(FieldToDataType(), constant_value_literal));
need_to_add_cast_function = constant_value_field_type.idx != constant_value_type.idx;
break;
}
case Field::Types::Int128:
case Field::Types::UInt128:
case Field::Types::Int256:
case Field::Types::UInt256:
case Field::Types::Decimal32:
case Field::Types::Decimal64:
case Field::Types::Decimal128:
case Field::Types::Decimal256:
case Field::Types::AggregateFunctionState:
case Field::Types::Array:
case Field::Types::Tuple:
case Field::Types::Map:
case Field::Types::UUID:
case Field::Types::Bool:
case Field::Types::Object:
case Field::Types::IPv4:
case Field::Types::IPv6:
case Field::Types::Null:
case Field::Types::CustomType:
{
need_to_add_cast_function = true;
break;
}
}
// Add cast if constant was created as a result of constant folding.
// Constant folding may lead to type transformation and literal on shard
// may have a different type.
if (need_to_add_cast_function || source_expression != nullptr)
if (requiresCastCall())
{
auto constant_type_name_ast = std::make_shared<ASTLiteral>(constant_value->getType()->getName());
return makeASTFunction("_CAST", std::move(constant_value_ast), std::move(constant_type_name_ast));

View File

@ -75,6 +75,12 @@ public:
return constant_value->getType();
}
/// Check if conversion to AST requires wrapping with _CAST function.
bool requiresCastCall() const;
/// Check if constant is a result of _CAST function constant folding.
bool receivedFromInitiatorServer() const;
void setMaskId(size_t id)
{
mask_id = id;

View File

@ -1362,7 +1362,7 @@ ActionsDAGPtr ActionsDAG::makeConvertingActions(
size_t num_result_columns = result.size();
if (mode == MatchColumnsMode::Position && num_input_columns != num_result_columns)
throw Exception(ErrorCodes::NUMBER_OF_COLUMNS_DOESNT_MATCH, "Number of columns doesn't match");
throw Exception(ErrorCodes::NUMBER_OF_COLUMNS_DOESNT_MATCH, "Number of columns doesn't match (source: {} and result: {})", num_input_columns, num_result_columns);
if (add_casted_columns && mode != MatchColumnsMode::Name)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Converting with add_casted_columns supported only for MatchColumnsMode::Name");

View File

@ -142,8 +142,9 @@ Block getHeaderForProcessingStage(
if (context->getSettingsRef().allow_experimental_analyzer)
{
auto storage = std::make_shared<StorageDummy>(
storage_snapshot->storage.getStorageID(), storage_snapshot->metadata->getColumns(), storage_snapshot);
auto storage = std::make_shared<StorageDummy>(storage_snapshot->storage.getStorageID(),
storage_snapshot->getAllColumnsDescription(),
storage_snapshot);
InterpreterSelectQueryAnalyzer interpreter(query, context, storage, SelectQueryOptions(processed_stage).analyze());
result = interpreter.getSampleBlock();
}

View File

@ -1181,7 +1181,7 @@ PlannerContextPtr buildPlannerContext(const QueryTreeNodePtr & query_tree_node,
if (select_query_options.is_subquery)
updateContextForSubqueryExecution(mutable_context);
return std::make_shared<PlannerContext>(mutable_context, std::move(global_planner_context));
return std::make_shared<PlannerContext>(mutable_context, std::move(global_planner_context), select_query_options);
}
Planner::Planner(const QueryTreeNodePtr & query_tree_,

View File

@ -44,6 +44,27 @@ namespace ErrorCodes
namespace
{
/* Calculates Action node name for ConstantNode.
*
* If converting to AST will add a '_CAST' function call,
* the result action name will also include it.
*/
String calculateActionNodeNameWithCastIfNeeded(const ConstantNode & constant_node)
{
WriteBufferFromOwnString buffer;
if (constant_node.requiresCastCall())
buffer << "_CAST(";
buffer << calculateConstantActionNodeName(constant_node.getValue(), constant_node.getResultType());
if (constant_node.requiresCastCall())
{
buffer << ", '" << constant_node.getResultType()->getName() << "'_String)";
}
return buffer.str();
}
class ActionNodeNameHelper
{
public:
@ -88,7 +109,49 @@ public:
case QueryTreeNodeType::CONSTANT:
{
const auto & constant_node = node->as<ConstantNode &>();
result = calculateConstantActionNodeName(constant_node.getValue(), constant_node.getResultType());
/* To ensure that headers match during distributed query we need to simulate action node naming on
* secondary servers. If we don't do that headers will mismatch due to constant folding.
*
* +--------+
* -----------------| Server |----------------
* / +--------+ \
* / \
* v v
* +-----------+ +-----------+
* | Initiator | ------ | Secondary |------
* +-----------+ / +-----------+ \
* | / \
* | / \
* v / \
* +---------------+ v v
* | Wrap in _CAST | +----------------------------+ +----------------------+
* | if needed | | Constant folded from _CAST | | Constant folded from |
* +---------------+ +----------------------------+ | another expression |
* | +----------------------+
* v |
* +----------------------------+ v
* | Name ConstantNode the same | +--------------------------+
* | as on initiator server | | Generate action name for |
* | (wrap in _CAST if needed) | | original expression |
* +----------------------------+ +--------------------------+
*/
if (planner_context.isASTLevelOptimizationAllowed())
{
result = calculateActionNodeNameWithCastIfNeeded(constant_node);
}
else
{
// Need to check if constant folded from QueryNode until https://github.com/ClickHouse/ClickHouse/issues/60847 is fixed.
if (constant_node.hasSourceExpression() && constant_node.getSourceExpression()->getNodeType() != QueryTreeNodeType::QUERY)
{
if (constant_node.receivedFromInitiatorServer())
result = calculateActionNodeNameWithCastIfNeeded(constant_node);
else
result = calculateActionNodeName(constant_node.getSourceExpression());
}
else
result = calculateConstantActionNodeName(constant_node.getValue(), constant_node.getResultType());
}
break;
}
case QueryTreeNodeType::FUNCTION:
@ -530,7 +593,52 @@ PlannerActionsVisitorImpl::NodeNameAndNodeMinLevel PlannerActionsVisitorImpl::vi
const auto & constant_literal = constant_node.getValue();
const auto & constant_type = constant_node.getResultType();
auto constant_node_name = calculateConstantActionNodeName(constant_literal, constant_type);
auto constant_node_name = [&]()
{
/* To ensure that headers match during distributed query we need to simulate action node naming on
* secondary servers. If we don't do that headers will mismatch due to constant folding.
*
* +--------+
* -----------------| Server |----------------
* / +--------+ \
* / \
* v v
* +-----------+ +-----------+
* | Initiator | ------ | Secondary |------
* +-----------+ / +-----------+ \
* | / \
* | / \
* v / \
* +---------------+ v v
* | Wrap in _CAST | +----------------------------+ +----------------------+
* | if needed | | Constant folded from _CAST | | Constant folded from |
* +---------------+ +----------------------------+ | another expression |
* | +----------------------+
* v |
* +----------------------------+ v
* | Name ConstantNode the same | +--------------------------+
* | as on initiator server | | Generate action name for |
* | (wrap in _CAST if needed) | | original expression |
* +----------------------------+ +--------------------------+
*/
if (planner_context->isASTLevelOptimizationAllowed())
{
return calculateActionNodeNameWithCastIfNeeded(constant_node);
}
else
{
// Need to check if constant folded from QueryNode until https://github.com/ClickHouse/ClickHouse/issues/60847 is fixed.
if (constant_node.hasSourceExpression() && constant_node.getSourceExpression()->getNodeType() != QueryTreeNodeType::QUERY)
{
if (constant_node.receivedFromInitiatorServer())
return calculateActionNodeNameWithCastIfNeeded(constant_node);
else
return action_node_name_helper.calculateActionNodeName(constant_node.getSourceExpression());
}
else
return calculateConstantActionNodeName(constant_literal, constant_type);
}
}();
ColumnWithTypeAndName column;
column.name = constant_node_name;

View File

@ -3,6 +3,7 @@
#include <Analyzer/TableNode.h>
#include <Analyzer/ColumnNode.h>
#include <Analyzer/ConstantNode.h>
#include <Interpreters/Context.h>
namespace DB
{
@ -41,9 +42,10 @@ bool GlobalPlannerContext::hasColumnIdentifier(const ColumnIdentifier & column_i
return column_identifiers.contains(column_identifier);
}
PlannerContext::PlannerContext(ContextMutablePtr query_context_, GlobalPlannerContextPtr global_planner_context_)
PlannerContext::PlannerContext(ContextMutablePtr query_context_, GlobalPlannerContextPtr global_planner_context_, const SelectQueryOptions & select_query_options_)
: query_context(std::move(query_context_))
, global_planner_context(std::move(global_planner_context_))
, is_ast_level_optimization_allowed(!(query_context->getClientInfo().query_kind == ClientInfo::QueryKind::SECONDARY_QUERY || select_query_options_.ignore_ast_optimizations))
{}
TableExpressionData & PlannerContext::getOrCreateTableExpressionData(const QueryTreeNodePtr & table_expression_node)

View File

@ -10,6 +10,7 @@
#include <Analyzer/IQueryTreeNode.h>
#include <Planner/TableExpressionData.h>
#include <Interpreters/SelectQueryOptions.h>
namespace DB
{
@ -78,7 +79,7 @@ class PlannerContext
{
public:
/// Create planner context with query context and global planner context
PlannerContext(ContextMutablePtr query_context_, GlobalPlannerContextPtr global_planner_context_);
PlannerContext(ContextMutablePtr query_context_, GlobalPlannerContextPtr global_planner_context_, const SelectQueryOptions & select_query_options_);
/// Get planner context query context
ContextPtr getQueryContext() const
@ -165,6 +166,12 @@ public:
static SetKey createSetKey(const DataTypePtr & left_operand_type, const QueryTreeNodePtr & set_source_node);
PreparedSets & getPreparedSets() { return prepared_sets; }
/// Returns false if any of following conditions met:
/// 1. Query is executed on a follower node.
/// 2. ignore_ast_optimizations is set.
bool isASTLevelOptimizationAllowed() const { return is_ast_level_optimization_allowed; }
private:
/// Query context
ContextMutablePtr query_context;
@ -172,6 +179,8 @@ private:
/// Global planner context
GlobalPlannerContextPtr global_planner_context;
bool is_ast_level_optimization_allowed;
/// Column node to column identifier
std::unordered_map<QueryTreeNodePtr, ColumnIdentifier> column_node_to_column_identifier;

View File

@ -296,7 +296,6 @@ VirtualColumnsDescription StorageDistributed::createVirtuals()
StorageInMemoryMetadata metadata;
auto desc = MergeTreeData::createVirtuals(metadata);
desc.addEphemeral("_table", std::make_shared<DataTypeLowCardinality>(std::make_shared<DataTypeString>()), "Name of a table");
desc.addEphemeral("_shard_num", std::make_shared<DataTypeUInt32>(), "Deprecated. Use function shardNum instead");
return desc;

View File

@ -1047,7 +1047,7 @@ QueryPipelineBuilderPtr ReadFromMerge::createSources(
Block pipe_header = builder->getHeader();
if (has_database_virtual_column && !pipe_header.has("_database"))
if (has_database_virtual_column && common_header.has("_database") && !pipe_header.has("_database"))
{
ColumnWithTypeAndName column;
column.name = "_database";
@ -1062,7 +1062,7 @@ QueryPipelineBuilderPtr ReadFromMerge::createSources(
{ return std::make_shared<ExpressionTransform>(stream_header, adding_column_actions); });
}
if (has_table_virtual_column && !pipe_header.has("_table"))
if (has_table_virtual_column && common_header.has("_table") && !pipe_header.has("_table"))
{
ColumnWithTypeAndName column;
column.name = "_table";

View File

@ -69,6 +69,14 @@ std::shared_ptr<StorageSnapshot> StorageSnapshot::clone(DataPtr data_) const
return res;
}
ColumnsDescription StorageSnapshot::getAllColumnsDescription() const
{
auto get_column_options = GetColumnsOptions(GetColumnsOptions::All).withExtendedObjects().withVirtuals();
auto column_names_and_types = getColumns(get_column_options);
return ColumnsDescription{column_names_and_types};
}
NamesAndTypesList StorageSnapshot::getColumns(const GetColumnsOptions & options) const
{
auto all_columns = getMetadataForQuery()->getColumns().get(options);

View File

@ -55,6 +55,9 @@ struct StorageSnapshot
std::shared_ptr<StorageSnapshot> clone(DataPtr data_) const;
/// Get columns description
ColumnsDescription getAllColumnsDescription() const;
/// Get all available columns with types according to options.
NamesAndTypesList getColumns(const GetColumnsOptions & options) const;

View File

@ -2,4 +2,3 @@ test_build_sets_from_multiple_threads/test.py::test_set
test_concurrent_backups_s3/test.py::test_concurrent_backups
test_distributed_type_object/test.py::test_distributed_type_object
test_merge_table_over_distributed/test.py::test_global_in
test_merge_table_over_distributed/test.py::test_select_table_name_from_merge_over_distributed

View File

@ -36,7 +36,7 @@ Header: avgWeighted(x, y) Nullable(Float64)
Header: x Nullable(Nothing)
y UInt8
Expression (Projection)
Header: NULL_Nullable(Nothing) Nullable(Nothing)
Header: _CAST(NULL_Nullable(Nothing), \'Nullable(Nothing)\'_String) Nullable(Nothing)
1_UInt8 UInt8
Expression (Change column names to column identifiers)
Header: __table5.dummy UInt8

View File

@ -60,7 +60,6 @@ DESCRIBE remote(default, currentDatabase(), t_describe_options) FORMAT PrettyCom
│ _part_offset │ UInt64 │ │ │ Number of row in the part │ │ │ 1 │
│ _row_exists │ UInt8 │ │ │ Persisted mask created by lightweight delete that show whether row exists or is deleted │ │ │ 1 │
│ _block_number │ UInt64 │ │ │ Persisted original number of block that was assigned at insert │ Delta, LZ4 │ │ 1 │
│ _table │ LowCardinality(String) │ │ │ Name of a table │ │ │ 1 │
│ _shard_num │ UInt32 │ │ │ Deprecated. Use function shardNum instead │ │ │ 1 │
└────────────────┴───────────────────────────┴──────────────┴────────────────────┴─────────────────────────────────────────────────────────────────────────────────────────┴──────────────────┴────────────────┴────────────┘
SET describe_compact_output = 0, describe_include_virtual_columns = 1, describe_include_subcolumns = 1;
@ -94,7 +93,6 @@ DESCRIBE remote(default, currentDatabase(), t_describe_options) FORMAT PrettyCom
│ _part_offset │ UInt64 │ │ │ Number of row in the part │ │ │ 0 │ 1 │
│ _row_exists │ UInt8 │ │ │ Persisted mask created by lightweight delete that show whether row exists or is deleted │ │ │ 0 │ 1 │
│ _block_number │ UInt64 │ │ │ Persisted original number of block that was assigned at insert │ Delta, LZ4 │ │ 0 │ 1 │
│ _table │ LowCardinality(String) │ │ │ Name of a table │ │ │ 0 │ 1 │
│ _shard_num │ UInt32 │ │ │ Deprecated. Use function shardNum instead │ │ │ 0 │ 1 │
│ arr.size0 │ UInt64 │ │ │ │ │ │ 1 │ 0 │
│ t.a │ String │ │ │ │ ZSTD(1) │ │ 1 │ 0 │
@ -160,7 +158,6 @@ DESCRIBE remote(default, currentDatabase(), t_describe_options) FORMAT PrettyCom
│ _part_offset │ UInt64 │ 1 │
│ _row_exists │ UInt8 │ 1 │
│ _block_number │ UInt64 │ 1 │
│ _table │ LowCardinality(String) │ 1 │
│ _shard_num │ UInt32 │ 1 │
└────────────────┴───────────────────────────┴────────────┘
SET describe_compact_output = 1, describe_include_virtual_columns = 1, describe_include_subcolumns = 1;
@ -194,7 +191,6 @@ DESCRIBE remote(default, currentDatabase(), t_describe_options) FORMAT PrettyCom
│ _part_offset │ UInt64 │ 0 │ 1 │
│ _row_exists │ UInt8 │ 0 │ 1 │
│ _block_number │ UInt64 │ 0 │ 1 │
│ _table │ LowCardinality(String) │ 0 │ 1 │
│ _shard_num │ UInt32 │ 0 │ 1 │
│ arr.size0 │ UInt64 │ 1 │ 0 │
│ t.a │ String │ 1 │ 0 │