diff --git a/src/Analyzer/FunctionNode.cpp b/src/Analyzer/FunctionNode.cpp index 36e4f1fc69b..323e6f02952 100644 --- a/src/Analyzer/FunctionNode.cpp +++ b/src/Analyzer/FunctionNode.cpp @@ -37,42 +37,27 @@ ColumnsWithTypeAndName FunctionNode::getArgumentTypes() const return argument_types; } - -FunctionBasePtr FunctionNode::getFunction() const -{ - return std::dynamic_pointer_cast(function); -} - -AggregateFunctionPtr FunctionNode::getAggregateFunction() const -{ - return std::dynamic_pointer_cast(function); -} - -bool FunctionNode::isAggregateFunction() const -{ - return typeid_cast(function) != nullptr && !isWindowFunction(); -} - -bool FunctionNode::isOrdinaryFunction() const -{ - return typeid_cast(function) != nullptr; -} - void FunctionNode::resolveAsFunction(FunctionBasePtr function_value) { function_name = function_value->getName(); function = std::move(function_value); + kind = FunctionKind::ORDINARY; } void FunctionNode::resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value) { function_name = aggregate_function_value->getName(); function = std::move(aggregate_function_value); + kind = FunctionKind::AGGREGATE; } void FunctionNode::resolveAsWindowFunction(AggregateFunctionPtr window_function_value) { + if (!hasWindow()) + throw Exception(ErrorCodes::LOGICAL_ERROR, + "Trying to resolve FunctionNode without window definition as a window function {}", window_function_value->getName()); resolveAsAggregateFunction(window_function_value); + kind = FunctionKind::WINDOW; } void FunctionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const diff --git a/src/Analyzer/FunctionNode.h b/src/Analyzer/FunctionNode.h index 3a8c5445375..feeecfdb146 100644 --- a/src/Analyzer/FunctionNode.h +++ b/src/Analyzer/FunctionNode.h @@ -41,6 +41,14 @@ using AggregateFunctionPtr = std::shared_ptr; class FunctionNode; using FunctionNodePtr = std::shared_ptr; +enum class FunctionKind +{ + UNKNOWN, + ORDINARY, + AGGREGATE, + WINDOW, +}; + class FunctionNode final : public IQueryTreeNode { public: @@ -133,13 +141,23 @@ public: /** Get non aggregate function. * If function is not resolved nullptr returned. */ - FunctionBasePtr getFunction() const; + FunctionBasePtr getFunction() const + { + if (kind != FunctionKind::ORDINARY) + return {}; + return std::reinterpret_pointer_cast(function); + } /** Get aggregate function. * If function is not resolved nullptr returned. * If function is resolved as non aggregate function nullptr returned. */ - AggregateFunctionPtr getAggregateFunction() const; + AggregateFunctionPtr getAggregateFunction() const + { + if (kind == FunctionKind::UNKNOWN || kind == FunctionKind::ORDINARY) + return {}; + return std::reinterpret_pointer_cast(function); + } /// Is function node resolved bool isResolved() const @@ -150,14 +168,20 @@ public: /// Is function node window function bool isWindowFunction() const { - return getWindowNode() != nullptr; + return kind == FunctionKind::WINDOW; } /// Is function node aggregate function - bool isAggregateFunction() const; + bool isAggregateFunction() const + { + return kind == FunctionKind::AGGREGATE; + } /// Is function node ordinary function - bool isOrdinaryFunction() const; + bool isOrdinaryFunction() const + { + return kind == FunctionKind::ORDINARY; + } /** Resolve function node as non aggregate function. * It is important that function name is updated with resolved function name. @@ -202,6 +226,7 @@ protected: private: String function_name; + FunctionKind kind = FunctionKind::UNKNOWN; IResolvedFunctionPtr function; static constexpr size_t parameters_child_index = 0;