From b724f49deb7a8cab629f3c3af05aa5dae773e5c9 Mon Sep 17 00:00:00 2001 From: taiyang-li <654010905@qq.com> Date: Mon, 16 Sep 2024 11:24:15 +0800 Subject: [PATCH] fix failed uts --- src/Functions/FunctionHelpers.cpp | 102 ++++++++++++++++++------------ src/Functions/IFunction.cpp | 100 ++++++++++++++--------------- 2 files changed, 108 insertions(+), 94 deletions(-) diff --git a/src/Functions/FunctionHelpers.cpp b/src/Functions/FunctionHelpers.cpp index dde06a56357..c84b3275b57 100644 --- a/src/Functions/FunctionHelpers.cpp +++ b/src/Functions/FunctionHelpers.cpp @@ -1,10 +1,10 @@ +#include +#include +#include +#include +#include #include #include -#include -#include -#include -#include -#include #include @@ -13,11 +13,11 @@ namespace DB namespace ErrorCodes { - extern const int ILLEGAL_COLUMN; - extern const int LOGICAL_ERROR; - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; - extern const int SIZES_OF_ARRAYS_DONT_MATCH; - extern const int ILLEGAL_TYPE_OF_ARGUMENT; +extern const int ILLEGAL_COLUMN; +extern const int LOGICAL_ERROR; +extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +extern const int SIZES_OF_ARRAYS_DONT_MATCH; +extern const int ILLEGAL_TYPE_OF_ARGUMENT; } const ColumnConst * checkAndGetColumnConstStringOrFixedString(const IColumn * column) @@ -27,8 +27,7 @@ const ColumnConst * checkAndGetColumnConstStringOrFixedString(const IColumn * co const ColumnConst * res = assert_cast(column); - if (checkColumn(&res->getDataColumn()) - || checkColumn(&res->getDataColumn())) + if (checkColumn(&res->getDataColumn()) || checkColumn(&res->getDataColumn())) return res; return {}; @@ -78,7 +77,7 @@ ColumnWithTypeAndName columnGetNested(const ColumnWithTypeAndName & col) { nullable_res = makeNullable(col.column); } - return ColumnWithTypeAndName{ nullable_res, nested_type, col.name }; + return ColumnWithTypeAndName{nullable_res, nested_type, col.name}; } else throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} for DataTypeNullable", col.dumpStructure()); @@ -102,18 +101,22 @@ String withOrdinalEnding(size_t i) { switch (i) { - case 0: return "1st"; - case 1: return "2nd"; - case 2: return "3rd"; - default: return std::to_string(i) + "th"; + case 0: + return "1st"; + case 1: + return "2nd"; + case 2: + return "3rd"; + default: + return std::to_string(i) + "th"; } - } -void validateArgumentsImpl(const IFunction & func, - const ColumnsWithTypeAndName & arguments, - size_t argument_offset, - const FunctionArgumentDescriptors & descriptors) +void validateArgumentsImpl( + const IFunction & func, + const ColumnsWithTypeAndName & arguments, + size_t argument_offset, + const FunctionArgumentDescriptors & descriptors) { for (size_t i = 0; i < descriptors.size(); ++i) { @@ -124,13 +127,14 @@ void validateArgumentsImpl(const IFunction & func, const auto & arg = arguments[i + argument_offset]; const auto & descriptor = descriptors[i]; if (int error_code = descriptor.isValid(arg.type, arg.column); error_code != 0) - throw Exception(error_code, - "A value of illegal type was provided as {} argument '{}' to function '{}'. Expected: {}, got: {}", - withOrdinalEnding(argument_offset + i), - descriptor.name, - func.getName(), - descriptor.type_name, - arg.type ? arg.type->getName() : ""); + throw Exception( + error_code, + "A value of illegal type was provided as {} argument '{}' to function '{}'. Expected: {}, got: {}", + withOrdinalEnding(argument_offset + i), + descriptor.name, + func.getName(), + descriptor.type_name, + arg.type ? arg.type->getName() : ""); } } @@ -150,26 +154,35 @@ int FunctionArgumentDescriptor::isValid(const DataTypePtr & data_type, const Col return 0; } -void validateFunctionArguments(const IFunction & func, - const ColumnsWithTypeAndName & arguments, - const FunctionArgumentDescriptors & mandatory_args, - const FunctionArgumentDescriptors & optional_args) +void validateFunctionArguments( + const IFunction & func, + const ColumnsWithTypeAndName & arguments, + const FunctionArgumentDescriptors & mandatory_args, + const FunctionArgumentDescriptors & optional_args) { if (arguments.size() < mandatory_args.size() || arguments.size() > mandatory_args.size() + optional_args.size()) { - auto argument_singular_or_plural = [](const auto & args) -> std::string_view { return args.size() == 1 ? "argument" : "arguments"; }; + auto argument_singular_or_plural + = [](const auto & args) -> std::string_view { return args.size() == 1 ? "argument" : "arguments"; }; String expected_args_string; if (!mandatory_args.empty() && !optional_args.empty()) - expected_args_string = fmt::format("{} mandatory {} and {} optional {}", mandatory_args.size(), argument_singular_or_plural(mandatory_args), optional_args.size(), argument_singular_or_plural(optional_args)); + expected_args_string = fmt::format( + "{} mandatory {} and {} optional {}", + mandatory_args.size(), + argument_singular_or_plural(mandatory_args), + optional_args.size(), + argument_singular_or_plural(optional_args)); else if (!mandatory_args.empty() && optional_args.empty()) - expected_args_string = fmt::format("{} {}", mandatory_args.size(), argument_singular_or_plural(mandatory_args)); /// intentionally not "_mandatory_ arguments" + expected_args_string = fmt::format( + "{} {}", mandatory_args.size(), argument_singular_or_plural(mandatory_args)); /// intentionally not "_mandatory_ arguments" else if (mandatory_args.empty() && !optional_args.empty()) expected_args_string = fmt::format("{} optional {}", optional_args.size(), argument_singular_or_plural(optional_args)); else expected_args_string = "0 arguments"; - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + throw Exception( + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "An incorrect number of arguments was specified for function '{}'. Expected {}, got {}", func.getName(), expected_args_string, @@ -205,7 +218,8 @@ checkAndGetNestedArrayOffset(const IColumn ** columns, size_t num_arguments) return {nested_columns, offsets->data()}; } -ColumnPtr wrapInNullable(const ColumnPtr & src, const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count) +ColumnPtr +wrapInNullable(const ColumnPtr & src, const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count) { ColumnPtr result_null_map_column; @@ -269,10 +283,7 @@ ColumnPtr wrapInNullable(const ColumnPtr & src, const ColumnPtr & null_map) return src; ColumnPtr result_null_map_column; - - /// If result is already nullable. ColumnPtr src_not_nullable = src; - if (const auto * nullable = checkAndGetColumn(src.get())) { src_not_nullable = nullable->getNestedColumnPtr(); @@ -287,9 +298,16 @@ ColumnPtr wrapInNullable(const ColumnPtr & src, const ColumnPtr & null_map) result_null_map_column = std::move(mutable_result_null_map_column); return ColumnNullable::create(src_not_nullable->convertToFullColumnIfConst(), result_null_map_column); } + else if (const auto * const_src = checkAndGetColumn(src.get())) + { + const NullMap & null_map_data = assert_cast(*null_map).getData(); + ColumnPtr result_null_map = ColumnUInt8::create(1, null_map_data[0] || const_src->isNullAt(0)); + const auto * nullable_data = checkAndGetColumn(&const_src->getDataColumn()); + auto data_not_nullable = nullable_data ? nullable_data->getNestedColumnPtr() : const_src->getDataColumnPtr(); + return ColumnConst::create(ColumnNullable::create(data_not_nullable, result_null_map), const_src->size()); + } else return ColumnNullable::create(src->convertToFullColumnIfConst(), null_map); - } NullPresence getNullPresense(const ColumnsWithTypeAndName & args) diff --git a/src/Functions/IFunction.cpp b/src/Functions/IFunction.cpp index 0b3c832add3..07892254636 100644 --- a/src/Functions/IFunction.cpp +++ b/src/Functions/IFunction.cpp @@ -1,28 +1,28 @@ #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include #include -#include -#include #include -#include #include +#include +#include +#include +#include #include +#include +#include +#include +#include #include #include #include -#include #include -#include -#include +#include +#include +#include +#include +#include #include "config.h" @@ -36,9 +36,9 @@ namespace DB namespace ErrorCodes { - extern const int LOGICAL_ERROR; - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; - extern const int ILLEGAL_COLUMN; +extern const int LOGICAL_ERROR; +extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +extern const int ILLEGAL_COLUMN; } namespace @@ -70,9 +70,7 @@ ColumnPtr replaceLowCardinalityColumnsByNestedAndGetDictionaryIndexes( const auto * low_cardinality_type = checkAndGetDataType(column.type.get()); if (!low_cardinality_type) - throw Exception(ErrorCodes::LOGICAL_ERROR, - "Incompatible type for LowCardinality column: {}", - column.type->getName()); + throw Exception(ErrorCodes::LOGICAL_ERROR, "Incompatible type for LowCardinality column: {}", column.type->getName()); if (can_be_executed_on_default_arguments) { @@ -125,10 +123,7 @@ ColumnPtr IExecutableFunction::defaultImplementationForConstantArguments( /// Check that these arguments are really constant. for (auto arg_num : arguments_to_remain_constants) if (arg_num < args.size() && !isColumnConst(*args[arg_num].column)) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, - "Argument at index {} for function {} must be constant", - arg_num, - getName()); + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Argument at index {} for function {} must be constant", arg_num, getName()); if (args.empty() || !useDefaultImplementationForConstants() || !allArgumentsAreConstants(args)) return nullptr; @@ -142,14 +137,16 @@ ColumnPtr IExecutableFunction::defaultImplementationForConstantArguments( { const ColumnWithTypeAndName & column = args[arg_num]; - if (arguments_to_remain_constants.end() != std::find(arguments_to_remain_constants.begin(), arguments_to_remain_constants.end(), arg_num)) + if (arguments_to_remain_constants.end() + != std::find(arguments_to_remain_constants.begin(), arguments_to_remain_constants.end(), arg_num)) { temporary_columns.emplace_back(ColumnWithTypeAndName{column.column->cloneResized(1), column.type, column.name}); } else { have_converted_columns = true; - temporary_columns.emplace_back(ColumnWithTypeAndName{ assert_cast(column.column.get())->getDataColumnPtr(), column.type, column.name }); + temporary_columns.emplace_back( + ColumnWithTypeAndName{assert_cast(column.column.get())->getDataColumnPtr(), column.type, column.name}); } } @@ -157,7 +154,8 @@ ColumnPtr IExecutableFunction::defaultImplementationForConstantArguments( * not in "arguments_to_remain_constants" set. Otherwise we get infinite recursion. */ if (!have_converted_columns) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + throw Exception( + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Number of arguments for function {} doesn't match: the function requires more arguments", getName()); @@ -194,9 +192,7 @@ ColumnPtr IExecutableFunction::defaultImplementationForNulls( } if (null_presence.has_null_constant) - { return result_type->createColumnConstWithDefaultValue(input_rows_count); - } if (null_presence.has_nullable) { @@ -212,8 +208,7 @@ ColumnPtr IExecutableFunction::defaultImplementationForNulls( { if (isColumnConst(*arg.column)) { - const auto & const_col = assert_cast(*arg.column); - if (const_col.isNullAt(0)) + if (arg.column->isNullAt(0)) { mask_info.has_ones = false; mask_info.has_zeros = true; @@ -233,7 +228,7 @@ ColumnPtr IExecutableFunction::defaultImplementationForNulls( if (!mask_info.has_ones) { - /// Do not actually execute function if each row contains at least one null value. + /// Don't need to evaluate function if each row contains at least one null value. return result_type->createColumnConstWithDefaultValue(input_rows_count); } else if (!mask_info.has_zeros || !short_circuit_default_implementation_for_nulls) @@ -249,11 +244,12 @@ ColumnPtr IExecutableFunction::defaultImplementationForNulls( auto null_map = ColumnUInt8::create(); null_map->getData() = std::move(mask); - return wrapInNullable(res, std::move(null_map)); + auto new_res = wrapInNullable(res, std::move(null_map)); + return new_res; } else { - /// If short circuiting is enabled, only execute the function on rows with all arguments not null + /// If short circuit is enabled, we only execute the function on rows with all arguments not null ColumnsWithTypeAndName temporary_columns = createBlockWithNestedColumns(args); auto temporary_result_type = removeNullable(result_type); @@ -262,7 +258,7 @@ ColumnPtr IExecutableFunction::defaultImplementationForNulls( for (auto & col : temporary_columns) col.column = col.column->filter(mask, size_hint); - auto res = executeWithoutLowCardinalityColumns(temporary_columns, temporary_result_type, input_rows_count, dry_run); + auto res = executeWithoutLowCardinalityColumns(temporary_columns, temporary_result_type, size_hint, dry_run); auto mutable_res = IColumn::mutate(std::move(res)); mutable_res->expand(mask, false); @@ -271,7 +267,8 @@ ColumnPtr IExecutableFunction::defaultImplementationForNulls( auto null_map = ColumnUInt8::create(); null_map->getData() = std::move(mask); - return wrapInNullable(std::move(mutable_res), std::move(null_map)); + auto new_res = wrapInNullable(std::move(mutable_res), std::move(null_map)); + return new_res; } } @@ -344,7 +341,8 @@ IExecutableFunction::IExecutableFunction() } } -ColumnPtr IExecutableFunction::executeWithoutSparseColumns(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const +ColumnPtr IExecutableFunction::executeWithoutSparseColumns( + const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const { ColumnPtr result; if (useDefaultImplementationForLowCardinalityColumns()) @@ -357,19 +355,16 @@ ColumnPtr IExecutableFunction::executeWithoutSparseColumns(const ColumnsWithType const auto & dictionary_type = res_low_cardinality_type->getDictionaryType(); ColumnPtr indexes = replaceLowCardinalityColumnsByNestedAndGetDictionaryIndexes( - columns_without_low_cardinality, can_be_executed_on_default_arguments, input_rows_count); + columns_without_low_cardinality, can_be_executed_on_default_arguments, input_rows_count); - size_t new_input_rows_count = columns_without_low_cardinality.empty() - ? input_rows_count - : columns_without_low_cardinality.front().column->size(); + size_t new_input_rows_count + = columns_without_low_cardinality.empty() ? input_rows_count : columns_without_low_cardinality.front().column->size(); checkFunctionArgumentSizes(columns_without_low_cardinality, new_input_rows_count); auto res = executeWithoutLowCardinalityColumns(columns_without_low_cardinality, dictionary_type, new_input_rows_count, dry_run); bool res_is_constant = isColumnConst(*res); - auto keys = res_is_constant - ? res->cloneResized(1)->convertToFullColumnIfConst() - : res; + auto keys = res_is_constant ? res->cloneResized(1)->convertToFullColumnIfConst() : res; auto res_mut_dictionary = DataTypeLowCardinality::createColumnUnique(*res_low_cardinality_type->getDictionaryType()); ColumnPtr res_indexes = res_mut_dictionary->uniqueInsertRangeFrom(*keys, 0, keys->size()); @@ -395,7 +390,8 @@ ColumnPtr IExecutableFunction::executeWithoutSparseColumns(const ColumnsWithType return result; } -ColumnPtr IExecutableFunction::execute(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const +ColumnPtr IExecutableFunction::execute( + const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const { checkFunctionArgumentSizes(arguments, input_rows_count); @@ -456,7 +452,7 @@ ColumnPtr IExecutableFunction::execute(const ColumnsWithTypeAndName & arguments, if (!result_type->canBeInsideSparseColumns() || !res->isDefaultAt(0) || res->getNumberOfDefaultRows() != 1) { const auto & offsets_data = assert_cast &>(*sparse_offsets).getData(); - return res->createWithOffsets(offsets_data, *createColumnConst(res, 0), input_rows_count, /*shift=*/ 1); + return res->createWithOffsets(offsets_data, *createColumnConst(res, 0), input_rows_count, /*shift=*/1); } return ColumnSparse::create(res, sparse_offsets, input_rows_count); @@ -483,7 +479,8 @@ void IFunctionOverloadResolver::checkNumberOfArguments(size_t number_of_argument size_t expected_number_of_arguments = getNumberOfArguments(); if (number_of_arguments != expected_number_of_arguments) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + throw Exception( + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Number of arguments for function {} doesn't match: passed {}, should be {}", getName(), number_of_arguments, @@ -522,9 +519,8 @@ DataTypePtr IFunctionOverloadResolver::getReturnType(const ColumnsWithTypeAndNam auto type_without_low_cardinality = getReturnTypeWithoutLowCardinality(args_without_low_cardinality); - if (canBeExecutedOnLowCardinalityDictionary() && has_low_cardinality - && num_full_low_cardinality_columns <= 1 && num_full_ordinary_columns == 0 - && type_without_low_cardinality->canBeInsideLowCardinality()) + if (canBeExecutedOnLowCardinalityDictionary() && has_low_cardinality && num_full_low_cardinality_columns <= 1 + && num_full_ordinary_columns == 0 && type_without_low_cardinality->canBeInsideLowCardinality()) return std::make_shared(type_without_low_cardinality); else return type_without_low_cardinality; @@ -631,7 +627,7 @@ llvm::Value * IFunction::compile(llvm::IRBuilderBase & builder, const ValuesWith ValuesWithType unwrapped_arguments; unwrapped_arguments.reserve(arguments.size()); - std::vector is_null_values; + std::vector is_null_values; for (size_t i = 0; i < arguments.size(); ++i) {