Revert "Revert "Validate function arguments in query tree""

This commit is contained in:
Maksim Kita 2023-01-16 10:15:44 +01:00
parent b13498d9ba
commit 250c93614c
3 changed files with 42 additions and 0 deletions

View File

@ -2,6 +2,7 @@
#include <Common/SipHash.h>
#include <Common/FieldVisitorToString.h>
#include <DataTypes/IDataType.h>
#include <Analyzer/ConstantNode.h>
#include <IO/WriteBufferFromString.h>
@ -31,6 +32,15 @@ FunctionNode::FunctionNode(String function_name_)
children[arguments_child_index] = std::make_shared<ListNode>();
}
const DataTypes & FunctionNode::getArgumentTypes() const
{
if (!function)
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Function {} is not resolved",
function_name);
return function->getArgumentTypes();
}
ColumnsWithTypeAndName FunctionNode::getArgumentColumns() const
{
const auto & arguments = getArguments().getNodes();

View File

@ -85,6 +85,7 @@ public:
/// Get arguments node
QueryTreeNodePtr & getArgumentsNode() { return children[arguments_child_index]; }
const DataTypes & getArgumentTypes() const;
ColumnsWithTypeAndName getArgumentColumns() const;
/// Returns true if function node has window, false otherwise

View File

@ -25,6 +25,7 @@
#include <Analyzer/FunctionNode.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Common/Exception.h>
#include <DataTypes/IDataType.h>
namespace DB
{
@ -61,6 +62,36 @@ class ValidationChecker : public InDepthQueryTreeVisitor<ValidationChecker>
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Function {} is not resolved after running {} pass",
function->toAST()->formatForErrorMessage(), pass_name);
if (function->getFunctionName() == "in")
return;
const auto & expected_arg_types = function->getArgumentTypes();
auto actual_arg_columns = function->getArgumentColumns();
if (expected_arg_types.size() != actual_arg_columns.size())
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Function {} expects {} arguments but has {} after running {} pass",
function->toAST()->formatForErrorMessage(),
expected_arg_types.size(),
actual_arg_columns.size(),
pass_name);
for (size_t i = 0; i < expected_arg_types.size(); ++i)
{
// Skip lambdas
if (WhichDataType(expected_arg_types[i]).isFunction())
continue;
if (!expected_arg_types[i]->equals(*actual_arg_columns[i].type))
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Function {} expects {} argument to have {} type but receives {} after running {} pass",
function->toAST()->formatForErrorMessage(),
i,
expected_arg_types[i]->getName(),
actual_arg_columns[i].type->getName(),
pass_name);
}
}
public: