Cleanup code

This commit is contained in:
Dmitry Novik 2022-12-23 18:23:01 +00:00
parent 9da949cdb1
commit e481c0bae5
18 changed files with 59 additions and 116 deletions

View File

@ -116,13 +116,9 @@ public:
static DataTypePtr getKeyType(const DataTypes & types, const AggregateFunctionPtr & nested)
{
if (types.empty())
if (types.size() != 1)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Aggregate function {}Map requires at least one argument", nested->getName());
if (types.size() > 1)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Aggregate function {}Map requires only one map argument", nested->getName());
"Aggregate function {}Map requires one map argument, but {} found", nested->getName(), types.size());
const auto * map_type = checkAndGetDataType<DataTypeMap>(types[0].get());
if (!map_type)

View File

@ -36,8 +36,8 @@ public:
AggregateFunctionOrFill(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<AggregateFunctionOrFill>{arguments, params, createResultType(nested_function_->getResultType())}
, nested_function{nested_function_}
, size_of_data {nested_function->sizeOfData()}
, inner_nullable {nested_function->getResultType()->isNullable()}
, size_of_data{nested_function->sizeOfData()}
, inner_nullable{nested_function->getResultType()->isNullable()}
{
// nothing
}

View File

@ -428,10 +428,7 @@ public:
}
bool keepKey(const T & key) const { return static_cast<const Derived &>(*this).keepKey(key); }
String getName() const override { return getNameImpl(); }
private:
static String getNameImpl() { return Derived::getNameImpl(); }
String getName() const override { return Derived::getNameImpl(); }
};
template <typename T, bool overflow, bool tuple_argument>

View File

@ -65,9 +65,9 @@ class IAggregateFunction : public std::enable_shared_from_this<IAggregateFunctio
{
public:
IAggregateFunction(const DataTypes & argument_types_, const Array & parameters_, const DataTypePtr & result_type_)
: result_type(result_type_)
, argument_types(argument_types_)
: argument_types(argument_types_)
, parameters(parameters_)
, result_type(result_type_)
{}
/// Get main function name.
@ -401,9 +401,9 @@ public:
#endif
protected:
DataTypePtr result_type;
DataTypes argument_types;
Array parameters;
DataTypePtr result_type;
};

View File

@ -31,18 +31,21 @@ FunctionNode::FunctionNode(String function_name_)
children[arguments_child_index] = std::make_shared<ListNode>();
}
ColumnsWithTypeAndName FunctionNode::getArgumentTypes() const
ColumnsWithTypeAndName FunctionNode::getArgumentColumns() const
{
ColumnsWithTypeAndName argument_types;
for (const auto & arg : getArguments().getNodes())
const auto & arguments = getArguments().getNodes();
ColumnsWithTypeAndName argument_columns;
argument_columns.reserve(arguments.size());
for (const auto & arg : arguments)
{
ColumnWithTypeAndName argument;
argument.type = arg->getResultType();
if (auto * constant = arg->as<ConstantNode>())
argument.column = argument.type->createColumnConst(1, constant->getValue());
argument_types.push_back(argument);
argument_columns.push_back(argument);
}
return argument_types;
return argument_columns;
}
void FunctionNode::resolveAsFunction(FunctionBasePtr function_value)

View File

@ -1,12 +1,14 @@
#pragma once
#include <memory>
#include <Core/IResolvedFunction.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Analyzer/ConstantValue.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h>
#include <Analyzer/ConstantValue.h>
#include <Common/typeid_cast.h>
#include <Core/ColumnsWithTypeAndName.h>
#include <Core/IResolvedFunction.h>
#include <Common/typeid_cast.h>
#include <Functions/IFunction.h>
namespace DB
{
@ -19,12 +21,6 @@ namespace ErrorCodes
class IFunctionOverloadResolver;
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;
class IFunctionBase;
using FunctionBasePtr = std::shared_ptr<const IFunctionBase>;
class IAggregateFunction;
using AggregateFunctionPtr = std::shared_ptr<const IAggregateFunction>;
/** Function node represents function in query tree.
* Function syntax: function_name(parameter_1, ...)(argument_1, ...).
* If function does not have parameters its syntax is function_name(argument_1, ...).
@ -63,66 +59,36 @@ public:
explicit FunctionNode(String function_name_);
/// Get function name
const String & getFunctionName() const
{
return function_name;
}
const String & getFunctionName() const { return function_name; }
/// Get parameters
const ListNode & getParameters() const
{
return children[parameters_child_index]->as<const ListNode &>();
}
const ListNode & getParameters() const { return children[parameters_child_index]->as<const ListNode &>(); }
/// Get parameters
ListNode & getParameters()
{
return children[parameters_child_index]->as<ListNode &>();
}
ListNode & getParameters() { return children[parameters_child_index]->as<ListNode &>(); }
/// Get parameters node
const QueryTreeNodePtr & getParametersNode() const
{
return children[parameters_child_index];
}
const QueryTreeNodePtr & getParametersNode() const { return children[parameters_child_index]; }
/// Get parameters node
QueryTreeNodePtr & getParametersNode()
{
return children[parameters_child_index];
}
QueryTreeNodePtr & getParametersNode() { return children[parameters_child_index]; }
/// Get arguments
const ListNode & getArguments() const
{
return children[arguments_child_index]->as<const ListNode &>();
}
const ListNode & getArguments() const { return children[arguments_child_index]->as<const ListNode &>(); }
/// Get arguments
ListNode & getArguments()
{
return children[arguments_child_index]->as<ListNode &>();
}
ListNode & getArguments() { return children[arguments_child_index]->as<ListNode &>(); }
/// Get arguments node
const QueryTreeNodePtr & getArgumentsNode() const
{
return children[arguments_child_index];
}
const QueryTreeNodePtr & getArgumentsNode() const { return children[arguments_child_index]; }
/// Get arguments node
QueryTreeNodePtr & getArgumentsNode()
{
return children[arguments_child_index];
}
QueryTreeNodePtr & getArgumentsNode() { return children[arguments_child_index]; }
ColumnsWithTypeAndName getArgumentTypes() const;
ColumnsWithTypeAndName getArgumentColumns() const;
/// Returns true if function node has window, false otherwise
bool hasWindow() const
{
return children[window_child_index] != nullptr;
}
bool hasWindow() const { return children[window_child_index] != nullptr; }
/** Get window node.
* Valid only for window function node.
@ -130,18 +96,12 @@ public:
* 1. It can be identifier node if window function is defined as expr OVER window_name.
* 2. It can be window node if window function is defined as expr OVER (window_name ...).
*/
const QueryTreeNodePtr & getWindowNode() const
{
return children[window_child_index];
}
const QueryTreeNodePtr & getWindowNode() const { return children[window_child_index]; }
/** Get window node.
* Valid only for window function node.
*/
QueryTreeNodePtr & getWindowNode()
{
return children[window_child_index];
}
QueryTreeNodePtr & getWindowNode() { return children[window_child_index]; }
/** Get non aggregate function.
* If function is not resolved nullptr returned.
@ -150,7 +110,7 @@ public:
{
if (kind != FunctionKind::ORDINARY)
return {};
return std::reinterpret_pointer_cast<const IFunctionBase>(function);
return std::static_pointer_cast<const IFunctionBase>(function);
}
/** Get aggregate function.
@ -161,32 +121,20 @@ public:
{
if (kind == FunctionKind::UNKNOWN || kind == FunctionKind::ORDINARY)
return {};
return std::reinterpret_pointer_cast<const IAggregateFunction>(function);
return std::static_pointer_cast<const IAggregateFunction>(function);
}
/// Is function node resolved
bool isResolved() const
{
return function != nullptr;
}
bool isResolved() const { return function != nullptr; }
/// Is function node window function
bool isWindowFunction() const
{
return hasWindow();
}
bool isWindowFunction() const { return hasWindow(); }
/// Is function node aggregate function
bool isAggregateFunction() const
{
return kind == FunctionKind::AGGREGATE;
}
bool isAggregateFunction() const { return kind == FunctionKind::AGGREGATE; }
/// Is function node ordinary function
bool isOrdinaryFunction() const
{
return kind == FunctionKind::ORDINARY;
}
bool isOrdinaryFunction() const { return kind == FunctionKind::ORDINARY; }
/** Resolve function node as non aggregate function.
* It is important that function name is updated with resolved function name.
@ -208,10 +156,7 @@ public:
*/
void resolveAsWindowFunction(AggregateFunctionPtr window_function_value);
QueryTreeNodeType getNodeType() const override
{
return QueryTreeNodeType::FUNCTION;
}
QueryTreeNodeType getNodeType() const override { return QueryTreeNodeType::FUNCTION; }
DataTypePtr getResultType() const override
{

View File

@ -155,7 +155,7 @@ public:
inline void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const
{
auto function = FunctionFactory::instance().get(function_name, context);
function_node.resolveAsFunction(function->build(function_node.getArgumentTypes()));
function_node.resolveAsFunction(function->build(function_node.getArgumentColumns()));
}
private:

View File

@ -193,7 +193,7 @@ private:
inline void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const
{
auto function = FunctionFactory::instance().get(function_name, context);
function_node.resolveAsFunction(function->build(function_node.getArgumentTypes()));
function_node.resolveAsFunction(function->build(function_node.getArgumentColumns()));
}
ContextPtr & context;

View File

@ -65,7 +65,7 @@ QueryTreeNodePtr createResolvedFunction(const ContextPtr & context, const String
auto function = FunctionFactory::instance().get(name, context);
function_node->getArguments().getNodes() = std::move(arguments);
function_node->resolveAsFunction(function->build(function_node->getArgumentTypes()));
function_node->resolveAsFunction(function->build(function_node->getArgumentColumns()));
return function_node;
}
@ -88,7 +88,7 @@ FunctionNodePtr createResolvedAggregateFunction(const String & name, const Query
{ argument->getResultType() },
parameters,
properties);
function_node->resolveAsAggregateFunction(aggregate_function);
function_node->resolveAsAggregateFunction(std::move(aggregate_function));
return function_node;
}

View File

@ -56,7 +56,7 @@ public:
auto multi_if_function = std::make_shared<FunctionNode>("multiIf");
multi_if_function->getArguments().getNodes() = std::move(multi_if_arguments);
multi_if_function->resolveAsFunction(multi_if_function_ptr->build(multi_if_function->getArgumentTypes()));
multi_if_function->resolveAsFunction(multi_if_function_ptr->build(multi_if_function->getArgumentColumns()));
node = std::move(multi_if_function);
}

View File

@ -52,7 +52,7 @@ QueryTreeNodePtr createCastFunction(QueryTreeNodePtr from, DataTypePtr result_ty
auto function_node = std::make_shared<FunctionNode>("_CAST");
function_node->getArguments().getNodes() = std::move(arguments);
function_node->resolveAsFunction(cast_function->build(function_node->getArgumentTypes()));
function_node->resolveAsFunction(cast_function->build(function_node->getArgumentColumns()));
return function_node;
}
@ -71,7 +71,7 @@ void changeIfArguments(
auto if_resolver = FunctionFactory::instance().get("if", context);
if_node.resolveAsFunction(if_resolver->build(if_node.getArgumentTypes()));
if_node.resolveAsFunction(if_resolver->build(if_node.getArgumentColumns()));
}
/// transform(value, array_from, array_to, default_value) will be transformed to transform(value, array_from, _CAST(array_to, Array(Enum...)), _CAST(default_value, Enum...))
@ -93,7 +93,7 @@ void changeTransformArguments(
auto transform_resolver = FunctionFactory::instance().get("transform", context);
transform_node.resolveAsFunction(transform_resolver->build(transform_node.getArgumentTypes()));
transform_node.resolveAsFunction(transform_resolver->build(transform_node.getArgumentColumns()));
}
void wrapIntoToString(FunctionNode & function_node, QueryTreeNodePtr arg, ContextPtr context)
@ -102,7 +102,7 @@ void wrapIntoToString(FunctionNode & function_node, QueryTreeNodePtr arg, Contex
QueryTreeNodes arguments{ std::move(arg) };
function_node.getArguments().getNodes() = std::move(arguments);
function_node.resolveAsFunction(to_string_function->build(function_node.getArgumentTypes()));
function_node.resolveAsFunction(to_string_function->build(function_node.getArgumentColumns()));
assert(isString(function_node.getResultType()));
}

View File

@ -27,7 +27,7 @@ public:
return;
auto result_type = function_node->getResultType();
function_node->resolveAsFunction(if_function_ptr->build(function_node->getArgumentTypes()));
function_node->resolveAsFunction(if_function_ptr->build(function_node->getArgumentColumns()));
}
private:

View File

@ -4327,7 +4327,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, argument_types, parameters, properties);
function_node.resolveAsWindowFunction(aggregate_function);
function_node.resolveAsWindowFunction(std::move(aggregate_function));
bool window_node_is_identifier = function_node.getWindowNode()->getNodeType() == QueryTreeNodeType::IDENTIFIER;
ProjectionName window_projection_name = resolveWindow(function_node.getWindowNode(), scope);
@ -4386,7 +4386,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, argument_types, parameters, properties);
function_node.resolveAsAggregateFunction(aggregate_function);
function_node.resolveAsAggregateFunction(std::move(aggregate_function));
return result_projection_names;
}

View File

@ -121,7 +121,7 @@ public:
auto & not_function_arguments = not_function->getArguments().getNodes();
not_function_arguments.push_back(std::move(nested_if_function_arguments_nodes[0]));
not_function->resolveAsFunction(FunctionFactory::instance().get("not", context)->build(not_function->getArgumentTypes()));
not_function->resolveAsFunction(FunctionFactory::instance().get("not", context)->build(not_function->getArgumentColumns()));
function_node_arguments_nodes[0] = std::move(not_function);
function_node_arguments_nodes.resize(1);

View File

@ -75,7 +75,6 @@ public:
function_node->getAggregateFunction()->getParameters(),
properties);
auto function_result_type = function_node->getResultType();
function_node->resolveAsAggregateFunction(std::move(aggregate_function));
}
};

View File

@ -59,7 +59,7 @@ class ValidationChecker : public InDepthQueryTreeVisitor<ValidationChecker>
if (!function->isResolved())
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Function {} is not resolved after running {} pass",
function->dumpTree(), pass_name);
function->toAST()->dumpTree(), pass_name);
}
public:

View File

@ -12,6 +12,9 @@ using DataTypes = std::vector<DataTypePtr>;
struct Array;
/* Generic class for all functions.
* Represents interface for function signature.
*/
class IResolvedFunction
{
public:

View File

@ -350,7 +350,7 @@ void Planner::buildQueryPlanIfNeeded()
auto function_node = std::make_shared<FunctionNode>("and");
auto and_function = FunctionFactory::instance().get("and", query_context);
function_node->getArguments().getNodes() = {query_node.getPrewhere(), query_node.getWhere()};
function_node->resolveAsFunction(and_function->build(function_node->getArgumentTypes()));
function_node->resolveAsFunction(and_function->build(function_node->getArgumentColumns()));
query_node.getWhere() = std::move(function_node);
query_node.getPrewhere() = {};
}