diff --git a/dbms/include/DB/Functions/IFunction.h b/dbms/include/DB/Functions/IFunction.h index 2694f0a55f2..45660b2258b 100644 --- a/dbms/include/DB/Functions/IFunction.h +++ b/dbms/include/DB/Functions/IFunction.h @@ -141,7 +141,7 @@ protected: /// Returns the copy of a given block in which each column specified in /// the "arguments" parameter is replaced with its respective nested /// column if it is nullable. - static Block extractNonNullableBlock(const Block & block, const ColumnNumbers & arguments); + static Block extractNonNullableBlock(const Block & block, const ColumnNumbers args); private: /// Internal method used for implementing both the execute() methods. diff --git a/dbms/src/Functions/FunctionsConditional.cpp b/dbms/src/Functions/FunctionsConditional.cpp index 5a2dc34781d..fb044c2aba2 100644 --- a/dbms/src/Functions/FunctionsConditional.cpp +++ b/dbms/src/Functions/FunctionsConditional.cpp @@ -208,22 +208,22 @@ void FunctionMultiIf::executeImpl(Block & block, const ColumnNumbers & args, siz { /// Keep track of which columns are nullable. std::vector nullable_cols_map; - nullable_cols_map.reserve(args.size()); + nullable_cols_map.resize(args.size()); for (const auto & arg : args) { const auto & col = block.unsafeGetByPosition(arg).column; bool may_have_null = col->isNullable(); - nullable_cols_map.push_back(static_cast(may_have_null)); + nullable_cols_map[arg] = may_have_null ? 1 : 0; } /// Keep track of which columns are null. std::vector null_cols_map; - null_cols_map.reserve(args.size()); + null_cols_map.resize(args.size()); for (const auto & arg : args) { const auto & col = block.unsafeGetByPosition(arg).column; bool has_null = col->isNull(); - null_cols_map.push_back(static_cast(has_null)); + null_cols_map[arg] = has_null ? 1 : 0; } auto null_map = std::make_shared(row_count); diff --git a/dbms/src/Functions/IFunction.cpp b/dbms/src/Functions/IFunction.cpp index 0808d4b99ad..736801b131d 100644 --- a/dbms/src/Functions/IFunction.cpp +++ b/dbms/src/Functions/IFunction.cpp @@ -11,20 +11,20 @@ namespace DB namespace { -void createNullValuesByteMap(Block & block, size_t result) +void createNullValuesByteMap(Block & block, const ColumnNumbers & args, size_t result) { ColumnNullable & res_col = static_cast(*block.unsafeGetByPosition(result).column); - for (size_t i = 0; i < block.columns(); ++i) + for (const auto & arg : args) { - if (i == result) + if (arg == result) continue; - const ColumnWithTypeAndName & elem = block.unsafeGetByPosition(i); + const ColumnWithTypeAndName & elem = block.unsafeGetByPosition(arg); if (elem.column && elem.column.get()->isNullable()) { - const ColumnNullable & concrete_col = static_cast(*elem.column); - res_col.updateNullValuesByteMap(concrete_col); + const ColumnNullable & nullable_col = static_cast(*elem.column); + res_col.updateNullValuesByteMap(nullable_col); } } } @@ -198,19 +198,19 @@ void IFunction::getLambdaArgumentTypes(DataTypes & arguments) const getLambdaArgumentTypesImpl(arguments); } -Block IFunction::extractNonNullableBlock(const Block & block, const ColumnNumbers & arguments) +/// Return a copy of a given block in which the specified columns are replaced by +/// their respective nested columns if they are nullable. +Block IFunction::extractNonNullableBlock(const Block & block, const ColumnNumbers args) { + std::sort(args.begin(), args.end()); + Block non_nullable_block; - ColumnNumbers args2 = arguments; - std::sort(args2.begin(), args2.end()); - - size_t pos = 0; for (size_t i = 0; i < block.columns(); ++i) { const auto & col = block.unsafeGetByPosition(i); - bool found = std::binary_search(args2.begin(), args2.end(), pos) && col.column && col.type; + bool found = std::binary_search(args.begin(), args.end(), i) && col.column && col.type; if (found && col.column.get()->isNullable()) { @@ -220,12 +220,10 @@ Block IFunction::extractNonNullableBlock(const Block & block, const ColumnNumber auto nullable_type = static_cast(col.type.get()); DataTypePtr nested_type = nullable_type->getNestedType(); - non_nullable_block.insert(pos, {nested_col, nested_type, col.name}); + non_nullable_block.insert(i, {nested_col, nested_type, col.name}); } else - non_nullable_block.insert(pos, col); - - ++pos; + non_nullable_block.insert(i, col); } return non_nullable_block; @@ -249,8 +247,8 @@ void IFunction::perform(Block & block, const ColumnNumbers & arguments, size_t r ColumnWithTypeAndName & dest_col = block.getByPosition(result); dest_col.column = std::make_shared(source_col.column); ColumnNullable & nullable_col = static_cast(*dest_col.column); - nullable_col.getNullValuesByteMap() = std::make_shared(dest_col.column->size()); - createNullValuesByteMap(block, result); + nullable_col.getNullValuesByteMap() = std::make_shared(dest_col.column->size(), 0); + createNullValuesByteMap(block, arguments, result); } else performer(block, arguments, result);