diff --git a/dbms/src/Functions/FunctionsNull.cpp b/dbms/src/Functions/FunctionsNull.cpp index 131bf4280c3..09f583fdcf4 100644 --- a/dbms/src/Functions/FunctionsNull.cpp +++ b/dbms/src/Functions/FunctionsNull.cpp @@ -190,6 +190,15 @@ std::string FunctionIfNull::getName() const return name; } + +static const DataTypePtr getNestedDataType(const DataTypePtr & type) +{ + if (type->isNullable()) + return static_cast(*type).getNestedType(); + + return type; +} + bool FunctionIfNull::hasSpecialSupportForNulls() const { return true; @@ -197,20 +206,29 @@ bool FunctionIfNull::hasSpecialSupportForNulls() const DataTypePtr FunctionIfNull::getReturnTypeImpl(const DataTypes & arguments) const { - return FunctionIf{}.getReturnTypeImpl({std::make_shared(), arguments[0], arguments[1]}); + if (arguments[0]->isNull()) + return arguments[1]; + + return FunctionIf{}.getReturnTypeImpl({std::make_shared(), getNestedDataType(arguments[0]), arguments[1]}); } void FunctionIfNull::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) { - /// ifNull(col1, col2) == multiIf(isNotNull(col1), col1, col2) + if (block.getByPosition(arguments[0]).column->isNull()) + block.getByPosition(result).column = block.getByPosition(arguments[1]).column; + + /// ifNull(col1, col2) == if(isNotNull(col1), assumeNotNull(col1), col2) Block temp_block = block; - size_t res_pos = temp_block.columns(); + size_t is_not_null_pos = temp_block.columns(); temp_block.insert({nullptr, std::make_shared(), ""}); + size_t assume_not_null_pos = temp_block.columns(); + temp_block.insert({nullptr, getNestedDataType(block.getByPosition(arguments[0]).type), ""}); - FunctionIsNotNull{}.executeImpl(temp_block, {arguments[0]}, res_pos); - FunctionIf{}.executeImpl(temp_block, {res_pos, arguments[0], arguments[1]}, result); + FunctionIsNotNull{}.executeImpl(temp_block, {arguments[0]}, is_not_null_pos); + FunctionAssumeNotNull{}.executeImpl(temp_block, {arguments[0]}, assume_not_null_pos); + FunctionIf{}.executeImpl(temp_block, {is_not_null_pos, assume_not_null_pos, arguments[1]}, result); block.safeGetByPosition(result).column = std::move(temp_block.safeGetByPosition(result).column); } @@ -285,13 +303,7 @@ DataTypePtr FunctionAssumeNotNull::getReturnTypeImpl(const DataTypes & arguments { if (arguments[0]->isNull()) throw Exception{"NULL is an invalid value for function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; - else if (arguments[0]->isNullable()) - { - const DataTypeNullable & nullable_type = static_cast(*arguments[0]); - return nullable_type.getNestedType(); - } - else - return arguments[0]; + return getNestedDataType(arguments[0]); } void FunctionAssumeNotNull::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result)