This commit is contained in:
Dmitry Novik 2024-09-19 16:03:11 +00:00 committed by GitHub
commit 7ccbff7d17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 105 additions and 81 deletions

View File

@ -2,6 +2,7 @@
#include <Analyzer/FunctionNode.h>
#include <Columns/ColumnNullable.h>
#include <Common/assert_cast.h>
#include <Common/FieldVisitorToString.h>
#include <Common/SipHash.h>
@ -21,32 +22,44 @@
namespace DB
{
ConstantNode::ConstantNode(ConstantValuePtr constant_value_, QueryTreeNodePtr source_expression_)
ConstantNode::ConstantNode(ConstantValue constant_value_, QueryTreeNodePtr source_expression_)
: IQueryTreeNode(children_size)
, constant_value(std::move(constant_value_))
, value_string(applyVisitor(FieldVisitorToString(), constant_value->getValue()))
{
source_expression = std::move(source_expression_);
}
ConstantNode::ConstantNode(ConstantValuePtr constant_value_)
ConstantNode::ConstantNode(ConstantValue constant_value_)
: ConstantNode(constant_value_, nullptr /*source_expression*/)
{}
ConstantNode::ConstantNode(ColumnPtr constant_column_, DataTypePtr value_data_type_)
: ConstantNode(ConstantValue{std::move(constant_column_), value_data_type_})
{}
ConstantNode::ConstantNode(ColumnPtr constant_column_)
: ConstantNode(constant_column_, applyVisitor(FieldToDataType(), (*constant_column_)[0]))
{}
ConstantNode::ConstantNode(Field value_, DataTypePtr value_data_type_)
: ConstantNode(std::make_shared<ConstantValue>(convertFieldToTypeOrThrow(value_, *value_data_type_), value_data_type_))
: ConstantNode(ConstantValue{convertFieldToTypeOrThrow(value_, *value_data_type_), value_data_type_})
{}
ConstantNode::ConstantNode(Field value_)
: ConstantNode(value_, applyVisitor(FieldToDataType(), value_))
{}
String ConstantNode::getValueStringRepresentation() const
{
return applyVisitor(FieldVisitorToString(), getValue());
}
bool ConstantNode::requiresCastCall() const
{
const auto & constant_value_literal = constant_value->getValue();
const auto & constant_value_literal = getValue();
bool need_to_add_cast_function = false;
auto constant_value_literal_type = constant_value_literal.getType();
WhichDataType constant_value_type(constant_value->getType());
WhichDataType constant_value_type(constant_value.getType());
switch (constant_value_literal_type)
{
@ -116,9 +129,9 @@ void ConstantNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state
if (mask_id)
buffer << "[HIDDEN id: " << mask_id << "]";
else
buffer << constant_value->getValue().dump();
buffer << getValue().dump();
buffer << ", constant_value_type: " << constant_value->getType()->getName();
buffer << ", constant_value_type: " << constant_value.getType()->getName();
if (!mask_id && getSourceExpression())
{
@ -129,30 +142,30 @@ void ConstantNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state
void ConstantNode::convertToNullable()
{
constant_value = std::make_shared<ConstantValue>(constant_value->getValue(), makeNullableSafe(constant_value->getType()));
constant_value = { makeNullableSafe(constant_value.getColumn()), makeNullableSafe(constant_value.getType()) };
}
bool ConstantNode::isEqualImpl(const IQueryTreeNode & rhs, CompareOptions compare_options) const
{
const auto & rhs_typed = assert_cast<const ConstantNode &>(rhs);
if (value_string != rhs_typed.value_string || constant_value->getValue() != rhs_typed.constant_value->getValue())
const auto & column = constant_value.getColumn();
const auto & rhs_column = rhs_typed.constant_value.getColumn();
if (column->getDataType() != rhs_column->getDataType() || column->compareAt(0, 0, *rhs_column, 1) != 0)
return false;
return !compare_options.compare_types || constant_value->getType()->equals(*rhs_typed.constant_value->getType());
return !compare_options.compare_types || constant_value.getType()->equals(*rhs_typed.constant_value.getType());
}
void ConstantNode::updateTreeHashImpl(HashState & hash_state, CompareOptions compare_options) const
{
constant_value.getColumn()->updateHashFast(hash_state);
if (compare_options.compare_types)
{
auto type_name = constant_value->getType()->getName();
auto type_name = constant_value.getType()->getName();
hash_state.update(type_name.size());
hash_state.update(type_name);
}
hash_state.update(value_string.size());
hash_state.update(value_string);
}
QueryTreeNodePtr ConstantNode::cloneImpl() const
@ -162,8 +175,8 @@ QueryTreeNodePtr ConstantNode::cloneImpl() const
ASTPtr ConstantNode::toASTImpl(const ConvertToASTOptions & options) const
{
const auto & constant_value_literal = constant_value->getValue();
const auto & constant_value_type = constant_value->getType();
const auto constant_value_literal = getValue();
const auto & constant_value_type = constant_value.getType();
auto constant_value_ast = std::make_shared<ASTLiteral>(constant_value_literal);
if (!options.add_cast_for_constants)

View File

@ -4,6 +4,7 @@
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ConstantValue.h>
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeNullable.h>
namespace DB
@ -22,10 +23,19 @@ class ConstantNode final : public IQueryTreeNode
{
public:
/// Construct constant query tree node from constant value and source expression
explicit ConstantNode(ConstantValuePtr constant_value_, QueryTreeNodePtr source_expression);
explicit ConstantNode(ConstantValue constant_value_, QueryTreeNodePtr source_expression);
/// Construct constant query tree node from constant value
explicit ConstantNode(ConstantValuePtr constant_value_);
explicit ConstantNode(ConstantValue constant_value_);
/** Construct constant query tree node from column and data type.
*
* Throws exception if value cannot be converted to value data type.
*/
explicit ConstantNode(ColumnPtr constant_column_, DataTypePtr value_data_type_);
/// Construct constant query tree node from column, data type will be derived from field value
explicit ConstantNode(ColumnPtr constant_column_);
/** Construct constant query tree node from field and data type.
*
@ -37,16 +47,21 @@ public:
explicit ConstantNode(Field value_);
/// Get constant value
const Field & getValue() const
const ColumnPtr & getColumn() const
{
return constant_value->getValue();
return constant_value.getColumn();
}
/// Get constant value
Field getValue() const
{
Field out;
constant_value.getColumn()->get(0, out);
return out;
}
/// Get constant value string representation
const String & getValueStringRepresentation() const
{
return value_string;
}
String getValueStringRepresentation() const;
/// Returns true if constant node has source expression, false otherwise
bool hasSourceExpression() const
@ -73,7 +88,7 @@ public:
DataTypePtr getResultType() const override
{
return constant_value->getType();
return constant_value.getType();
}
/// Check if conversion to AST requires wrapping with _CAST function.
@ -101,8 +116,7 @@ protected:
ASTPtr toASTImpl(const ConvertToASTOptions & options) const override;
private:
ConstantValuePtr constant_value;
String value_string;
ConstantValue constant_value;
QueryTreeNodePtr source_expression;
size_t mask_id = 0;

View File

@ -1,28 +1,29 @@
#pragma once
#include <Columns/ColumnConst.h>
#include <Columns/IColumn.h>
#include <Core/Field.h>
#include <DataTypes/IDataType.h>
namespace DB
{
/** Immutable constant value representation during analysis stage.
* Some query nodes can be represented by constant (scalar subqueries, functions with constant arguments).
*/
class ConstantValue;
using ConstantValuePtr = std::shared_ptr<ConstantValue>;
class ConstantValue
{
public:
ConstantValue(Field value_, DataTypePtr data_type_)
: value(std::move(value_))
ConstantValue(ColumnPtr column_, DataTypePtr data_type_)
: column(wrapToColumnConst(column_))
, data_type(std::move(data_type_))
{}
const Field & getValue() const
ConstantValue(const Field & field_, DataTypePtr data_type_)
: column(data_type_->createColumnConst(1, field_))
, data_type(std::move(data_type_))
{}
const ColumnPtr & getColumn() const
{
return value;
return column;
}
const DataTypePtr & getType() const
@ -30,7 +31,15 @@ public:
return data_type;
}
private:
Field value;
static ColumnPtr wrapToColumnConst(ColumnPtr column_)
{
if (!isColumnConst(*column_))
return ColumnConst::create(column_, 1);
return column_;
}
ColumnPtr column;
DataTypePtr data_type;
};

View File

@ -75,7 +75,7 @@ ColumnsWithTypeAndName FunctionNode::getArgumentColumns() const
argument_column.type = argument->getResultType();
if (constant && !isNotCreatable(argument_column.type))
argument_column.column = argument_column.type->createColumnConst(1, constant->getValue());
argument_column.column = constant->getColumn();
argument_columns.push_back(std::move(argument_column));
}

View File

@ -155,8 +155,7 @@ private:
if (function_arguments_nodes_size == 1)
{
auto comparison_argument_constant_value = std::make_shared<ConstantValue>(constant_tuple[0], tuple_data_type_elements[0]);
auto comparison_argument_constant_node = std::make_shared<ConstantNode>(std::move(comparison_argument_constant_value));
auto comparison_argument_constant_node = std::make_shared<ConstantNode>(constant_tuple[0], tuple_data_type_elements[0]);
return makeComparisonFunction(function_arguments_nodes[0], std::move(comparison_argument_constant_node), comparison_function_name);
}
@ -165,8 +164,7 @@ private:
for (size_t i = 0; i < function_arguments_nodes_size; ++i)
{
auto equals_argument_constant_value = std::make_shared<ConstantValue>(constant_tuple[i], tuple_data_type_elements[i]);
auto equals_argument_constant_node = std::make_shared<ConstantNode>(std::move(equals_argument_constant_value));
auto equals_argument_constant_node = std::make_shared<ConstantNode>(constant_tuple[i], tuple_data_type_elements[i]);
auto equals_function = makeEqualsFunction(function_arguments_nodes[i], std::move(equals_argument_constant_node));
tuple_arguments_equals_functions.push_back(std::move(equals_function));
}

View File

@ -498,8 +498,7 @@ private:
if (collapse_to_false)
{
auto false_value = std::make_shared<ConstantValue>(0u, function_node.getResultType());
auto false_node = std::make_shared<ConstantNode>(std::move(false_value));
auto false_node = std::make_shared<ConstantNode>(0u, function_node.getResultType());
node = std::move(false_node);
return;
}

View File

@ -341,11 +341,11 @@ static FunctionNodePtr wrapExpressionNodeInFunctionWithSecondConstantStringArgum
auto function_node = std::make_shared<FunctionNode>(std::move(function_name));
auto constant_node_type = std::make_shared<DataTypeString>();
auto constant_value = std::make_shared<ConstantValue>(std::move(second_argument), std::move(constant_node_type));
auto constant_value = ConstantValue{second_argument, std::move(constant_node_type)};
ColumnsWithTypeAndName argument_columns;
argument_columns.push_back({nullptr, expression->getResultType(), {}});
argument_columns.push_back({constant_value->getType()->createColumnConst(1, constant_value->getValue()), constant_value->getType(), {}});
argument_columns.push_back({constant_value.getColumn(), constant_value.getType(), {}});
auto function = FunctionFactory::instance().tryGet(function_node->getFunctionName(), context);
auto function_base = function->build(argument_columns);

View File

@ -648,9 +648,6 @@ void QueryAnalyzer::evaluateScalarSubqueryIfNeeded(QueryTreeNodePtr & node, Iden
const auto & scalar_column_with_type = scalar_block.safeGetByPosition(0);
const auto & scalar_type = scalar_column_with_type.type;
Field scalar_value;
scalar_column_with_type.column->get(0, scalar_value);
const auto * scalar_type_name = scalar_block.safeGetByPosition(0).type->getFamilyName();
static const std::set<std::string_view> useless_literal_types = {"Array", "Tuple", "AggregateFunction", "Function", "Set", "LowCardinality"};
auto * nearest_query_scope = scope.getNearestQueryScope();
@ -659,10 +656,10 @@ void QueryAnalyzer::evaluateScalarSubqueryIfNeeded(QueryTreeNodePtr & node, Iden
if (!context->getSettingsRef()[Setting::enable_scalar_subquery_optimization] || !useless_literal_types.contains(scalar_type_name)
|| !context->hasQueryContext() || !nearest_query_scope)
{
auto constant_value = std::make_shared<ConstantValue>(std::move(scalar_value), scalar_type);
ConstantValue constant_value{ scalar_column_with_type.column, scalar_type };
auto constant_node = std::make_shared<ConstantNode>(constant_value, node);
if (constant_node->getValue().isNull())
if (scalar_column_with_type.column->isNullAt(0))
{
node = buildCastFunction(constant_node, constant_node->getResultType(), context);
node = std::make_shared<ConstantNode>(std::move(constant_value), node);
@ -685,8 +682,7 @@ void QueryAnalyzer::evaluateScalarSubqueryIfNeeded(QueryTreeNodePtr & node, Iden
std::string get_scalar_function_name = "__getScalar";
auto scalar_query_hash_constant_value = std::make_shared<ConstantValue>(std::move(scalar_query_hash_string), std::make_shared<DataTypeString>());
auto scalar_query_hash_constant_node = std::make_shared<ConstantNode>(std::move(scalar_query_hash_constant_value));
auto scalar_query_hash_constant_node = std::make_shared<ConstantNode>(std::move(scalar_query_hash_string), std::make_shared<DataTypeString>());
auto get_scalar_function_node = std::make_shared<FunctionNode>(get_scalar_function_name);
get_scalar_function_node->getArguments().getNodes().push_back(std::move(scalar_query_hash_constant_node));
@ -828,8 +824,7 @@ void QueryAnalyzer::convertLimitOffsetExpression(QueryTreeNodePtr & expression_n
"{} numeric constant expression is not representable as UInt64",
expression_description);
auto constant_value = std::make_shared<ConstantValue>(std::move(converted_value), std::make_shared<DataTypeUInt64>());
auto result_constant_node = std::make_shared<ConstantNode>(std::move(constant_value));
auto result_constant_node = std::make_shared<ConstantNode>(std::move(converted_value), std::make_shared<DataTypeUInt64>());
result_constant_node->getSourceExpression() = limit_offset_constant_node->getSourceExpression();
expression_node = std::move(result_constant_node);
@ -3008,7 +3003,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
const auto * constant_node = function_argument->as<ConstantNode>();
if (constant_node)
{
argument_column.column = constant_node->getResultType()->createColumnConst(1, constant_node->getValue());
argument_column.column = constant_node->getColumn();
argument_column.type = constant_node->getResultType();
argument_is_constant = true;
}
@ -3412,7 +3407,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
if (first_argument_constant_node && second_argument_constant_node)
{
const auto & first_argument_constant_type = first_argument_constant_node->getResultType();
const auto & second_argument_constant_literal = second_argument_constant_node->getValue();
const auto second_argument_constant_literal = second_argument_constant_node->getValue();
const auto & second_argument_constant_type = second_argument_constant_node->getResultType();
const auto & settings = scope.context->getSettingsRef();
@ -3439,7 +3434,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
argument_columns[1].type = std::make_shared<DataTypeSet>();
}
std::shared_ptr<ConstantValue> constant_value;
ConstantNodePtr constant_node;
try
{
@ -3497,9 +3492,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
column->byteSize() < 1_MiB)
{
/// Replace function node with result constant node
Field column_constant_value;
column->get(0, column_constant_value);
constant_value = std::make_shared<ConstantValue>(std::move(column_constant_value), result_type);
constant_node = std::make_shared<ConstantNode>(ConstantValue{ std::move(column), std::move(result_type) }, node);
}
}
@ -3511,8 +3504,8 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
throw;
}
if (constant_value)
node = std::make_shared<ConstantNode>(std::move(constant_value), node);
if (constant_node)
node = std::move(constant_node);
return result_projection_names;
}

View File

@ -210,8 +210,7 @@ QueryTreeNodePtr buildCastFunction(const QueryTreeNodePtr & expression,
bool resolve)
{
std::string cast_type = type->getName();
auto cast_type_constant_value = std::make_shared<ConstantValue>(std::move(cast_type), std::make_shared<DataTypeString>());
auto cast_type_constant_node = std::make_shared<ConstantNode>(std::move(cast_type_constant_value));
auto cast_type_constant_node = std::make_shared<ConstantNode>(std::move(cast_type), std::make_shared<DataTypeString>());
std::string cast_function_name = "_CAST";
auto cast_function_node = std::make_shared<FunctionNode>(cast_function_name);
@ -787,8 +786,7 @@ NameSet collectIdentifiersFullNames(const QueryTreeNodePtr & node)
QueryTreeNodePtr createCastFunction(QueryTreeNodePtr node, DataTypePtr result_type, ContextPtr context)
{
auto enum_literal = std::make_shared<ConstantValue>(result_type->getName(), std::make_shared<DataTypeString>());
auto enum_literal_node = std::make_shared<ConstantNode>(std::move(enum_literal));
auto enum_literal_node = std::make_shared<ConstantNode>(result_type->getName(), std::make_shared<DataTypeString>());
auto cast_function = FunctionFactory::instance().get("_CAST", std::move(context));
QueryTreeNodes arguments{ std::move(node), std::move(enum_literal_node) };

View File

@ -88,26 +88,26 @@ std::string functionName(const ASTPtr & node)
return node->as<ASTFunction &>().name;
}
const Field * tryGetConstantValue(const QueryTreeNodePtr & node)
std::optional<Field> tryGetConstantValue(const QueryTreeNodePtr & node)
{
if (const auto * constant = node->as<ConstantNode>())
return &constant->getValue();
return constant->getValue();
return nullptr;
return {};
}
const Field * tryGetConstantValue(const ASTPtr & node)
std::optional<Field> tryGetConstantValue(const ASTPtr & node)
{
if (const auto * constant = node->as<ASTLiteral>())
return &constant->value;
return nullptr;
return {};
}
template <typename Node>
const Field & getConstantValue(const Node & node)
Field getConstantValue(const Node & node)
{
const auto * constant = tryGetConstantValue(node);
const auto constant = tryGetConstantValue(node);
assert(constant);
return *constant;
}
@ -518,7 +518,7 @@ void ComparisonGraph<Node>::EqualComponent::buildConstants()
constant_index.reset();
for (size_t i = 0; i < nodes.size(); ++i)
{
if (tryGetConstantValue(nodes[i]) != nullptr)
if (tryGetConstantValue(nodes[i]))
{
constant_index = i;
return;
@ -566,7 +566,7 @@ std::optional<Node> ComparisonGraph<Node>::getEqualConst(const Node & node) cons
template <ComparisonGraphNodeType Node>
std::optional<std::pair<Field, bool>> ComparisonGraph<Node>::getConstUpperBound(const Node & node) const
{
if (const auto * constant = tryGetConstantValue(node))
if (const auto constant = tryGetConstantValue(node))
return std::make_pair(*constant, false);
const auto it = graph.node_hash_to_component.find(Graph::getHash(node));
@ -584,7 +584,7 @@ std::optional<std::pair<Field, bool>> ComparisonGraph<Node>::getConstUpperBound(
template <ComparisonGraphNodeType Node>
std::optional<std::pair<Field, bool>> ComparisonGraph<Node>::getConstLowerBound(const Node & node) const
{
if (const auto * constant = tryGetConstantValue(node))
if (const auto constant = tryGetConstantValue(node))
return std::make_pair(*constant, false);
const auto it = graph.node_hash_to_component.find(Graph::getHash(node));

View File

@ -168,7 +168,7 @@ public:
{
if (isTuple(constant->getResultType()))
{
const auto & tuple = constant->getValue().safeGet<Tuple &>();
const auto tuple = constant->getValue().safeGet<Tuple>();
Tuple new_tuple;
new_tuple.reserve(tuple.size());

View File

@ -49,7 +49,7 @@ public:
WriteBufferFromOwnString out;
result_type->getDefaultSerialization()->serializeText(inner_column, 0, out, FormatSettings());
node = std::make_shared<ConstantNode>(std::make_shared<ConstantValue>(out.str(), result_type));
node = std::make_shared<ConstantNode>(out.str(), std::move(result_type));
}
}
}