Avoid dynamic_cast

This commit is contained in:
Dmitry Novik 2022-12-01 18:51:57 +00:00
parent 6bc152c7a0
commit 4b74337188
2 changed files with 36 additions and 26 deletions

View File

@ -37,42 +37,27 @@ ColumnsWithTypeAndName FunctionNode::getArgumentTypes() const
return argument_types; return argument_types;
} }
FunctionBasePtr FunctionNode::getFunction() const
{
return std::dynamic_pointer_cast<const IFunctionBase>(function);
}
AggregateFunctionPtr FunctionNode::getAggregateFunction() const
{
return std::dynamic_pointer_cast<const IAggregateFunction>(function);
}
bool FunctionNode::isAggregateFunction() const
{
return typeid_cast<AggregateFunctionPtr>(function) != nullptr && !isWindowFunction();
}
bool FunctionNode::isOrdinaryFunction() const
{
return typeid_cast<FunctionBasePtr>(function) != nullptr;
}
void FunctionNode::resolveAsFunction(FunctionBasePtr function_value) void FunctionNode::resolveAsFunction(FunctionBasePtr function_value)
{ {
function_name = function_value->getName(); function_name = function_value->getName();
function = std::move(function_value); function = std::move(function_value);
kind = FunctionKind::ORDINARY;
} }
void FunctionNode::resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value) void FunctionNode::resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value)
{ {
function_name = aggregate_function_value->getName(); function_name = aggregate_function_value->getName();
function = std::move(aggregate_function_value); function = std::move(aggregate_function_value);
kind = FunctionKind::AGGREGATE;
} }
void FunctionNode::resolveAsWindowFunction(AggregateFunctionPtr window_function_value) void FunctionNode::resolveAsWindowFunction(AggregateFunctionPtr window_function_value)
{ {
if (!hasWindow())
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Trying to resolve FunctionNode without window definition as a window function {}", window_function_value->getName());
resolveAsAggregateFunction(window_function_value); resolveAsAggregateFunction(window_function_value);
kind = FunctionKind::WINDOW;
} }
void FunctionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const void FunctionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const

View File

@ -41,6 +41,14 @@ using AggregateFunctionPtr = std::shared_ptr<const IAggregateFunction>;
class FunctionNode; class FunctionNode;
using FunctionNodePtr = std::shared_ptr<FunctionNode>; using FunctionNodePtr = std::shared_ptr<FunctionNode>;
enum class FunctionKind
{
UNKNOWN,
ORDINARY,
AGGREGATE,
WINDOW,
};
class FunctionNode final : public IQueryTreeNode class FunctionNode final : public IQueryTreeNode
{ {
public: public:
@ -133,13 +141,23 @@ public:
/** Get non aggregate function. /** Get non aggregate function.
* If function is not resolved nullptr returned. * If function is not resolved nullptr returned.
*/ */
FunctionBasePtr getFunction() const; FunctionBasePtr getFunction() const
{
if (kind != FunctionKind::ORDINARY)
return {};
return std::reinterpret_pointer_cast<const IFunctionBase>(function);
}
/** Get aggregate function. /** Get aggregate function.
* If function is not resolved nullptr returned. * If function is not resolved nullptr returned.
* If function is resolved as non aggregate function nullptr returned. * If function is resolved as non aggregate function nullptr returned.
*/ */
AggregateFunctionPtr getAggregateFunction() const; AggregateFunctionPtr getAggregateFunction() const
{
if (kind == FunctionKind::UNKNOWN || kind == FunctionKind::ORDINARY)
return {};
return std::reinterpret_pointer_cast<const IAggregateFunction>(function);
}
/// Is function node resolved /// Is function node resolved
bool isResolved() const bool isResolved() const
@ -150,14 +168,20 @@ public:
/// Is function node window function /// Is function node window function
bool isWindowFunction() const bool isWindowFunction() const
{ {
return getWindowNode() != nullptr; return kind == FunctionKind::WINDOW;
} }
/// Is function node aggregate function /// Is function node aggregate function
bool isAggregateFunction() const; bool isAggregateFunction() const
{
return kind == FunctionKind::AGGREGATE;
}
/// Is function node ordinary function /// Is function node ordinary function
bool isOrdinaryFunction() const; bool isOrdinaryFunction() const
{
return kind == FunctionKind::ORDINARY;
}
/** Resolve function node as non aggregate function. /** Resolve function node as non aggregate function.
* It is important that function name is updated with resolved function name. * It is important that function name is updated with resolved function name.
@ -202,6 +226,7 @@ protected:
private: private:
String function_name; String function_name;
FunctionKind kind = FunctionKind::UNKNOWN;
IResolvedFunctionPtr function; IResolvedFunctionPtr function;
static constexpr size_t parameters_child_index = 0; static constexpr size_t parameters_child_index = 0;