Allow constant folding throught __getScalar

This commit is contained in:
Nikolai Kochetov 2024-04-11 15:25:52 +00:00
parent 069fb3d42e
commit 3e16309e99
3 changed files with 25 additions and 2 deletions

View File

@ -5624,17 +5624,35 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
function_name,
scope.scope_node->formatASTForErrorMessage());
bool argument_is_constant = false;
const auto * constant_node = function_argument->as<ConstantNode>();
if (constant_node)
{
argument_column.column = constant_node->getResultType()->createColumnConst(1, constant_node->getValue());
argument_column.type = constant_node->getResultType();
argument_is_constant = true;
}
else
else if(const auto * get_scalar_function_node = function_argument->as<FunctionNode>();
get_scalar_function_node && get_scalar_function_node->getFunctionName() == "__getScalar")
{
all_arguments_constants = false;
/// Allow constant folding through getScalar
const auto * get_scalar_const_arg = get_scalar_function_node->getArguments().getNodes().at(0)->as<ConstantNode>();
if (get_scalar_const_arg && scope.context->hasQueryContext())
{
auto query_context = scope.context->getQueryContext();
auto scalar_string = toString(get_scalar_const_arg->getValue());
if (query_context->hasScalar(scalar_string))
{
auto scalar = query_context->getScalar(scalar_string);
argument_column.column = ColumnConst::create(scalar.getByPosition(0).column, 1);
argument_column.type = get_scalar_function_node->getResultType();
argument_is_constant = true;
}
}
}
all_arguments_constants &= argument_is_constant;
argument_types.push_back(argument_column.type);
argument_columns.emplace_back(std::move(argument_column));
}

View File

@ -24,3 +24,5 @@ SELECT * FROM numbers(10) LIMIT LENGTH('NNN') + COS(0), toDate('0000-00-02'); --
SELECT * FROM numbers(10) LIMIT a + 5 - a; -- { serverError 47 }
SELECT * FROM numbers(10) LIMIT a + b; -- { serverError 47 }
SELECT * FROM numbers(10) LIMIT 'Hello'; -- { serverError 440 }
SELECT number from numbers(10) order by number limit (select sum(number), count() from numbers(3)).1;