#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include

namespace DB
{
namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NOT_IMPLEMENTED;
    extern const int LOGICAL_ERROR;
}

namespace
{

using namespace GatherUtils;

/** Selection function by condition: if(cond, then, else).
  * cond - UInt8
  * then, else - numeric types for which there is a common type, or dates, datetimes, or strings, or arrays of these types.
  * For better performance, try to use branch-free code for numeric types (i.e. cond ? a : b --> !!cond * a + !cond * b).
  */

template <typename ResultType>
concept is_native_int_or_decimal_v
    = std::is_integral_v<ResultType> || (is_decimal<ResultType> && sizeof(ResultType) <= 8);

// This macro performs a branch-free conditional assignment for floating point types.
// It uses bitwise operations to avoid branching, which can be beneficial for performance.
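/** Illustration only, not used by FunctionIf. This is a standalone sketch of the branch-free
    selections described above. It assumes that any nonzero UInt8 condition byte means "true".
    It uses plain std::int64_t, double and std::vector in place of ResultType and PaddedPODArray.
    The names selectInt, selectDouble and selectConstConst are hypothetical. selectConstConst
    follows the two-element lookup table that fillConstantConstant uses for wide integer and
    Decimal types.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Integer form: !!c is 0 or 1 and !c is its complement, so exactly one of
// the two products survives and no branch is needed.
inline std::int64_t selectInt(std::uint8_t c, std::int64_t a, std::int64_t b)
{
    return static_cast<std::int64_t>(!!c) * a + static_cast<std::int64_t>(!c) * b;
}

// Floating-point form: mask is all zeros when c != 0 and all ones when
// c == 0, so the bitwise blend copies the bits of either a or b.
inline double selectDouble(std::uint8_t c, double a, double b)
{
    const std::uint64_t mask = static_cast<std::uint64_t>(static_cast<std::int64_t>(!!c) - 1);
    std::uint64_t bits_a;
    std::uint64_t bits_b;
    std::memcpy(&bits_a, &a, sizeof(bits_a));
    std::memcpy(&bits_b, &b, sizeof(bits_b));
    const std::uint64_t bits = (~mask & bits_a) | (mask & bits_b);
    double out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

// Constant/constant form for wide types: index a two-element table by !c.
template <typename T>
void selectConstConst(const std::vector<std::uint8_t> & cond, T a, T b, std::vector<T> & res)
{
    const T ab[2] = {a, b};
    res.resize(cond.size());
    for (std::size_t i = 0; i < cond.size(); ++i)
        res[i] = ab[!cond[i]];
}
```

In the integer form the condition is multiplied as !!c and !c, so any nonzero byte selects the
then-value. The float sketch normalizes with !!c before it builds the mask, for the same reason.
*/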
#define BRANCHFREE_IF_FLOAT(TYPE, vc, va, vb, vr) \
    using UIntType = typename NumberTraits::Construct<false, false, sizeof(TYPE)>::Type; \
    using IntType = typename NumberTraits::Construct<true, false, sizeof(TYPE)>::Type; \
    auto mask = static_cast<UIntType>(static_cast<IntType>(!!(vc)) - 1); \
    auto new_a = static_cast<TYPE>(va); \
    auto new_b = static_cast<TYPE>(vb); \
    UIntType uint_a; \
    std::memcpy(&uint_a, &new_a, sizeof(UIntType)); \
    UIntType uint_b; \
    std::memcpy(&uint_b, &new_b, sizeof(UIntType)); \
    UIntType tmp = (~mask & uint_a) | (mask & uint_b); \
    std::memcpy(&(vr), &tmp, sizeof(UIntType));

template <typename ResultType, typename ArrayCond, typename ArrayA, typename ArrayB, typename ArrayResult>
inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const ArrayB & b, ArrayResult & res)
{
    size_t size = cond.size();
    bool a_is_short = a.size() < size;
    bool b_is_short = b.size() < size;

    if (a_is_short && b_is_short)
    {
        size_t a_index = 0, b_index = 0;
        for (size_t i = 0; i < size; ++i)
        {
            if constexpr (is_native_int_or_decimal_v<ResultType>)
                res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
            else if constexpr (std::is_floating_point_v<ResultType>)
            {
                BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b[b_index], res[i])
            }
            else
                res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b[b_index]);

            a_index += !!cond[i];
            b_index += !cond[i];
        }
    }
    else if (a_is_short)
    {
        size_t a_index = 0;
        for (size_t i = 0; i < size; ++i)
        {
            if constexpr (is_native_int_or_decimal_v<ResultType>)
                res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[i]);
            else if constexpr (std::is_floating_point_v<ResultType>)
            {
                BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b[i], res[i])
            }
            else
                res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b[i]);

            a_index += !!cond[i];
        }
    }
    else if (b_is_short)
    {
        size_t b_index = 0;
        for (size_t i = 0; i < size; ++i)
        {
            if constexpr (is_native_int_or_decimal_v<ResultType>)
                res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
            else if constexpr (std::is_floating_point_v<ResultType>)
            {
                BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[b_index], res[i])
            }
            else
                res[i] = cond[i] ?
                    static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[b_index]);

            b_index += !cond[i];
        }
    }
    else
    {
        for (size_t i = 0; i < size; ++i)
        {
            if constexpr (is_native_int_or_decimal_v<ResultType>)
                res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[i]);
            else if constexpr (std::is_floating_point_v<ResultType>)
            {
                BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[i], res[i])
            }
            else
            {
                res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[i]);
            }
        }
    }
}

template <typename ResultType, typename ArrayCond, typename ArrayA, typename B, typename ArrayResult>
inline void fillVectorConstant(const ArrayCond & cond, const ArrayA & a, B b, ArrayResult & res)
{
    size_t size = cond.size();
    bool a_is_short = a.size() < size;
    if (a_is_short)
    {
        size_t a_index = 0;
        for (size_t i = 0; i < size; ++i)
        {
            if constexpr (is_native_int_or_decimal_v<ResultType>)
                res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b);
            else if constexpr (std::is_floating_point_v<ResultType>)
            {
                BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b, res[i])
            }
            else
                res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b);

            a_index += !!cond[i];
        }
    }
    else
    {
        for (size_t i = 0; i < size; ++i)
        {
            if constexpr (is_native_int_or_decimal_v<ResultType>)
                res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b);
            else if constexpr (std::is_floating_point_v<ResultType>)
            {
                BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b, res[i])
            }
            else
                res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b);
        }
    }
}

template <typename ResultType, typename ArrayCond, typename A, typename ArrayB, typename ArrayResult>
inline void fillConstantVector(const ArrayCond & cond, A a, const ArrayB & b, ArrayResult & res)
{
    size_t size = cond.size();
    bool b_is_short = b.size() < size;
    if (b_is_short)
    {
        size_t b_index = 0;
        for (size_t i = 0; i < size; ++i)
        {
            if constexpr (is_native_int_or_decimal_v<ResultType>)
                res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
            else if constexpr (std::is_floating_point_v<ResultType>)
            {
                BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[b_index], res[i])
            }
            else
                res[i] = cond[i] ?
                    static_cast<ResultType>(a) : static_cast<ResultType>(b[b_index]);

            b_index += !cond[i];
        }
    }
    else
    {
        for (size_t i = 0; i < size; ++i)
        {
            if constexpr (is_native_int_or_decimal_v<ResultType>)
                res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[i]);
            else if constexpr (std::is_floating_point_v<ResultType>)
            {
                BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[i], res[i])
            }
            else
                res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[i]);
        }
    }
}

template <typename ResultType, typename ArrayCond, typename A, typename B, typename ArrayResult>
inline void fillConstantConstant(const ArrayCond & cond, A a, B b, ArrayResult & res)
{
    size_t size = cond.size();

    /// We manually optimize the loop for types like (U)Int128|256 or Decimal128/256 to avoid branches
    if constexpr (is_over_big_int<ResultType>)
    {
        alignas(64) const ResultType ab[2] = {static_cast<ResultType>(a), static_cast<ResultType>(b)};
        for (size_t i = 0; i < size; ++i)
        {
            res[i] = ab[!cond[i]];
        }
    }
    else if constexpr (std::is_same_v<ResultType, Decimal32> || std::is_same_v<ResultType, Decimal64>)
    {
        ResultType new_a = static_cast<ResultType>(a);
        ResultType new_b = static_cast<ResultType>(b);
        for (size_t i = 0; i < size; ++i)
        {
            /// Reuse new_a and new_b to achieve auto-vectorization
            res[i] = cond[i] ? new_a : new_b;
        }
    }
    else
    {
        for (size_t i = 0; i < size; ++i)
            res[i] = cond[i] ?
                static_cast<ResultType>(a) : static_cast<ResultType>(b);
    }
}

template <typename A, typename B, typename ResultType>
struct NumIfImpl
{
    using ArrayCond = PaddedPODArray<UInt8>;
    using ArrayA = typename ColumnVector<A>::Container;
    using ArrayB = typename ColumnVector<B>::Container;
    using ColVecResult = ColumnVector<ResultType>;
    using ArrayResult = typename ColVecResult::Container;

    static ColumnPtr vectorVector(const ArrayCond & cond, const ArrayA & a, const ArrayB & b, UInt32)
    {
        size_t size = cond.size();
        auto col_res = ColVecResult::create(size);
        ArrayResult & res = col_res->getData();

        fillVectorVector<ResultType>(cond, a, b, res);
        return col_res;
    }

    static ColumnPtr vectorConstant(const ArrayCond & cond, const ArrayA & a, B b, UInt32)
    {
        size_t size = cond.size();
        auto col_res = ColVecResult::create(size);
        ArrayResult & res = col_res->getData();

        fillVectorConstant<ResultType>(cond, a, b, res);
        return col_res;
    }

    static ColumnPtr constantVector(const ArrayCond & cond, A a, const ArrayB & b, UInt32)
    {
        size_t size = cond.size();
        auto col_res = ColVecResult::create(size);
        ArrayResult & res = col_res->getData();

        fillConstantVector<ResultType>(cond, a, b, res);
        return col_res;
    }

    static ColumnPtr constantConstant(const ArrayCond & cond, A a, B b, UInt32)
    {
        size_t size = cond.size();
        auto col_res = ColVecResult::create(size);
        ArrayResult & res = col_res->getData();

        fillConstantConstant<ResultType>(cond, a, b, res);
        return col_res;
    }
};

template <typename A, typename B, typename R>
struct NumIfImpl<Decimal<A>, Decimal<B>, Decimal<R>>
{
    using ResultType = Decimal<R>;
    using ArrayCond = PaddedPODArray<UInt8>;
    using ArrayA = typename ColumnDecimal<Decimal<A>>::Container;
    using ArrayB = typename ColumnDecimal<Decimal<B>>::Container;
    using ColVecResult = ColumnDecimal<ResultType>;
    using Block = ColumnsWithTypeAndName;
    using ArrayResult = typename ColVecResult::Container;

    static ColumnPtr vectorVector(const ArrayCond & cond, const ArrayA & a, const ArrayB & b, UInt32 scale)
    {
        size_t size = cond.size();
        auto col_res = ColVecResult::create(size, scale);
        ArrayResult & res = col_res->getData();

        fillVectorVector<ResultType>(cond, a, b, res);
        return col_res;
    }

    static ColumnPtr vectorConstant(const ArrayCond & cond, const ArrayA & a, B b,
        UInt32 scale)
    {
        size_t size = cond.size();
        auto col_res = ColVecResult::create(size, scale);
        ArrayResult & res = col_res->getData();

        fillVectorConstant<ResultType>(cond, a, b, res);
        return col_res;
    }

    static ColumnPtr constantVector(const ArrayCond & cond, A a, const ArrayB & b, UInt32 scale)
    {
        size_t size = cond.size();
        auto col_res = ColVecResult::create(size, scale);
        ArrayResult & res = col_res->getData();

        fillConstantVector<ResultType>(cond, a, b, res);
        return col_res;
    }

    static ColumnPtr constantConstant(const ArrayCond & cond, A a, B b, UInt32 scale)
    {
        size_t size = cond.size();
        auto col_res = ColVecResult::create(size, scale);
        ArrayResult & res = col_res->getData();

        fillConstantConstant<ResultType>(cond, a, b, res);
        return col_res;
    }
};

class FunctionIf : public FunctionIfBase
{
public:
    static constexpr auto name = "if";
    static FunctionPtr create(ContextPtr context)
    {
        return std::make_shared<FunctionIf>(
            context->getSettingsRef().allow_experimental_variant_type && context->getSettingsRef().use_variant_as_common_type);
    }

    explicit FunctionIf(bool use_variant_when_no_common_type_ = false)
        : FunctionIfBase(), use_variant_when_no_common_type(use_variant_when_no_common_type_)
    {
    }

private:
    bool use_variant_when_no_common_type = false;

    template <typename T0, typename T1>
    static UInt32 decimalScale(const ColumnsWithTypeAndName & arguments [[maybe_unused]])
    {
        if constexpr (is_decimal<T0> && is_decimal<T1>)
        {
            UInt32 left_scale = getDecimalScale(*arguments[1].type);
            UInt32 right_scale = getDecimalScale(*arguments[2].type);
            if (left_scale != right_scale)
                throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Conditional functions with different Decimal scales");
            return left_scale;
        }
        else
            return std::numeric_limits<UInt32>::max();
    }

    template <typename T0, typename T1, typename ColVecT0, typename ColVecT1>
    ColumnPtr executeRightType(
        [[maybe_unused]] const ColumnUInt8 * cond_col,
        [[maybe_unused]] const ColumnsWithTypeAndName & arguments,
        [[maybe_unused]] const ColVecT0 * col_left) const
    {
        using ResultType = typename NumberTraits::ResultOfIf<T0, T1>::Type;

        if constexpr (std::is_same_v<ResultType, NumberTraits::Error>)
        {
            return nullptr;
        }
        else
        {
            const IColumn * col_right_untyped =
                arguments[2].column.get();
            UInt32 scale = decimalScale<T0, T1>(arguments);

            if (const auto * col_right_vec = checkAndGetColumn<ColVecT1>(col_right_untyped))
            {
                return NumIfImpl<T0, T1, ResultType>::vectorVector(
                    cond_col->getData(), col_left->getData(), col_right_vec->getData(), scale);
            }
            else if (const auto * col_right_const = checkAndGetColumnConst<ColVecT1>(col_right_untyped))
            {
                return NumIfImpl<T0, T1, ResultType>::vectorConstant(
                    cond_col->getData(), col_left->getData(), col_right_const->template getValue<T1>(), scale);
            }

            return nullptr;
        }
    }

    template <typename T0, typename T1, typename ColVecT0, typename ColVecT1>
    ColumnPtr executeConstRightType(
        [[maybe_unused]] const ColumnUInt8 * cond_col,
        [[maybe_unused]] const ColumnsWithTypeAndName & arguments,
        [[maybe_unused]] const ColumnConst * col_left) const
    {
        using ResultType = typename NumberTraits::ResultOfIf<T0, T1>::Type;

        if constexpr (std::is_same_v<ResultType, NumberTraits::Error>)
        {
            return nullptr;
        }
        else
        {
            const IColumn * col_right_untyped = arguments[2].column.get();
            UInt32 scale = decimalScale<T0, T1>(arguments);

            if (const auto * col_right_vec = checkAndGetColumn<ColVecT1>(col_right_untyped))
            {
                return NumIfImpl<T0, T1, ResultType>::constantVector(
                    cond_col->getData(), col_left->template getValue<T0>(), col_right_vec->getData(), scale);
            }
            else if (const auto * col_right_const = checkAndGetColumnConst<ColVecT1>(col_right_untyped))
            {
                return NumIfImpl<T0, T1, ResultType>::constantConstant(
                    cond_col->getData(), col_left->template getValue<T0>(), col_right_const->template getValue<T1>(), scale);
            }

            return nullptr;
        }
    }

    template <typename T0, typename T1, typename ColVecT0, typename ColVecT1>
    ColumnPtr executeRightTypeArray(
        [[maybe_unused]] const ColumnUInt8 * cond_col,
        [[maybe_unused]] const ColumnsWithTypeAndName & arguments,
        [[maybe_unused]] const DataTypePtr result_type,
        [[maybe_unused]] const ColumnArray * col_left_array,
        [[maybe_unused]] size_t input_rows_count) const
    {
        using ResultType = typename NumberTraits::ResultOfIf<T0, T1>::Type;

        if constexpr (std::is_same_v<ResultType, NumberTraits::Error>)
        {
            return nullptr;
        }
        else
        {
            const IColumn * col_right_untyped = arguments[2].column.get();

            if (const auto * col_right_array = checkAndGetColumn<ColumnArray>(col_right_untyped))
            {
                const ColVecT1 * col_right_vec = checkAndGetColumn<ColVecT1>(&col_right_array->getData());
                if (!col_right_vec)
                    return nullptr;

                auto res = result_type->createColumn();
                auto & arr_res = assert_cast<ColumnArray &>(*res);

                conditional(
                    NumericArraySource<T0>(*col_left_array),
                    NumericArraySource<T1>(*col_right_array),
                    NumericArraySink<ResultType>(arr_res.getData(), arr_res.getOffsets(), input_rows_count),
                    cond_col->getData());

                return res;
            }
            else if (const auto * col_right_const_array = checkAndGetColumnConst<ColumnArray>(col_right_untyped))
            {
                const ColumnArray * col_right_const_array_data = checkAndGetColumn<ColumnArray>(&col_right_const_array->getDataColumn());
                if (!checkColumn<ColVecT1>(&col_right_const_array_data->getData()))
                    return nullptr;

                auto res = result_type->createColumn();
                auto & arr_res = assert_cast<ColumnArray &>(*res);

                conditional(
                    NumericArraySource<T0>(*col_left_array),
                    ConstSource<NumericArraySource<T1>>(*col_right_const_array),
                    NumericArraySink<ResultType>(arr_res.getData(), arr_res.getOffsets(), input_rows_count),
                    cond_col->getData());

                return res;
            }

            return nullptr;
        }
    }

    template <typename T0, typename T1, typename ColVecT0, typename ColVecT1>
    ColumnPtr executeConstRightTypeArray(
        [[maybe_unused]] const ColumnUInt8 * cond_col,
        [[maybe_unused]] const ColumnsWithTypeAndName & arguments,
        [[maybe_unused]] const DataTypePtr & result_type,
        [[maybe_unused]] const ColumnConst * col_left_const_array,
        [[maybe_unused]] size_t input_rows_count) const
    {
        using ResultType = typename NumberTraits::ResultOfIf<T0, T1>::Type;

        if constexpr (std::is_same_v<ResultType, NumberTraits::Error>)
        {
            return nullptr;
        }
        else
        {
            const IColumn * col_right_untyped = arguments[2].column.get();

            if (const auto * col_right_array = checkAndGetColumn<ColumnArray>(col_right_untyped))
            {
                const ColVecT1 * col_right_vec = checkAndGetColumn<ColVecT1>(&col_right_array->getData());
                if (!col_right_vec)
                    return nullptr;

                auto res = result_type->createColumn();
                auto & arr_res = assert_cast<ColumnArray &>(*res);

                conditional(
                    ConstSource<NumericArraySource<T0>>(*col_left_const_array),
                    NumericArraySource<T1>(*col_right_array),
                    NumericArraySink<ResultType>(arr_res.getData(), arr_res.getOffsets(), input_rows_count),
                    cond_col->getData());

                return res;
            }
            else if (const auto * col_right_const_array = checkAndGetColumnConst<ColumnArray>(col_right_untyped))
            {
                const ColumnArray * col_right_const_array_data =
                    checkAndGetColumn<ColumnArray>(&col_right_const_array->getDataColumn());
                if (!checkColumn<ColVecT1>(&col_right_const_array_data->getData()))
                    return nullptr;

                auto res = result_type->createColumn();
                auto & arr_res = assert_cast<ColumnArray &>(*res);

                conditional(
                    ConstSource<NumericArraySource<T0>>(*col_left_const_array),
                    ConstSource<NumericArraySource<T1>>(*col_right_const_array),
                    NumericArraySink<ResultType>(arr_res.getData(), arr_res.getOffsets(), input_rows_count),
                    cond_col->getData());

                return res;
            }

            return nullptr;
        }
    }

    template <typename T0, typename T1>
    ColumnPtr executeTyped(
        const ColumnUInt8 * cond_col, const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
    {
        using ColVecT0 = ColumnVectorOrDecimal<T0>;
        using ColVecT1 = ColumnVectorOrDecimal<T1>;

        const IColumn * col_left_untyped = arguments[1].column.get();

        ColumnPtr right_column = nullptr;
        if (const auto * col_left = checkAndGetColumn<ColVecT0>(col_left_untyped))
        {
            right_column = executeRightType<T0, T1, ColVecT0, ColVecT1>(cond_col, arguments, col_left);
        }
        else if (const auto * col_const_left = checkAndGetColumnConst<ColVecT0>(col_left_untyped))
        {
            right_column = executeConstRightType<T0, T1, ColVecT0, ColVecT1>(cond_col, arguments, col_const_left);
        }
        else if (const auto * col_arr_left = checkAndGetColumn<ColumnArray>(col_left_untyped))
        {
            if (auto col_arr_left_elems = checkAndGetColumn<ColVecT0>(&col_arr_left->getData()))
            {
                right_column = executeRightTypeArray<T0, T1, ColVecT0, ColVecT1>(
                    cond_col, arguments, result_type, col_arr_left, input_rows_count);
            }
        }
        else if (const auto * col_const_arr_left = checkAndGetColumnConst<ColumnArray>(col_left_untyped))
        {
            if (checkColumn<ColVecT0>(&assert_cast<const ColumnArray &>(col_const_arr_left->getDataColumn()).getData()))
            {
                right_column = executeConstRightTypeArray<T0, T1, ColVecT0, ColVecT1>(
                    cond_col, arguments, result_type, col_const_arr_left, input_rows_count);
            }
        }

        return right_column;
    }

    static ColumnPtr executeString(const ColumnUInt8 * cond_col, const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type)
    {
        const IColumn * col_then_untyped = arguments[1].column.get();
        const IColumn * col_else_untyped = arguments[2].column.get();

        const ColumnString * col_then = checkAndGetColumn<ColumnString>(col_then_untyped);
        const ColumnString *
            col_else = checkAndGetColumn<ColumnString>(col_else_untyped);
        const ColumnFixedString * col_then_fixed = checkAndGetColumn<ColumnFixedString>(col_then_untyped);
        const ColumnFixedString * col_else_fixed = checkAndGetColumn<ColumnFixedString>(col_else_untyped);
        const ColumnConst * col_then_const = checkAndGetColumnConst<ColumnString>(col_then_untyped);
        const ColumnConst * col_else_const = checkAndGetColumnConst<ColumnString>(col_else_untyped);
        const ColumnConst * col_then_const_fixed = checkAndGetColumnConst<ColumnFixedString>(col_then_untyped);
        const ColumnConst * col_else_const_fixed = checkAndGetColumnConst<ColumnFixedString>(col_else_untyped);

        const PaddedPODArray<UInt8> & cond_data = cond_col->getData();
        size_t rows = cond_data.size();

        if (isFixedString(result_type))
        {
            /// The result is FixedString.

            auto col_res_untyped = result_type->createColumn();
            ColumnFixedString * col_res = assert_cast<ColumnFixedString *>(col_res_untyped.get());
            auto sink = FixedStringSink(*col_res, rows);

            if (col_then_fixed && col_else_fixed)
                conditional(FixedStringSource(*col_then_fixed), FixedStringSource(*col_else_fixed), sink, cond_data);
            else if (col_then_fixed && col_else_const_fixed)
                conditional(FixedStringSource(*col_then_fixed), ConstSource<FixedStringSource>(*col_else_const_fixed), sink, cond_data);
            else if (col_then_const_fixed && col_else_fixed)
                conditional(ConstSource<FixedStringSource>(*col_then_const_fixed), FixedStringSource(*col_else_fixed), sink, cond_data);
            else if (col_then_const_fixed && col_else_const_fixed)
                conditional(ConstSource<FixedStringSource>(*col_then_const_fixed), ConstSource<FixedStringSource>(*col_else_const_fixed), sink, cond_data);
            else
                return nullptr;

            return col_res_untyped;
        }

        if (isString(result_type))
        {
            /// The result is String.
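/** Illustration only, not used by FunctionIf. The String branch that follows pairs one source per
    argument with a StringSink, and conditional() merges them row by row. Each source is a String,
    a FixedString, or a constant of either, which gives 16 (then, else) combinations. The sketch
    below is a hypothetical and much-simplified model of that source/sink pattern, not the
    GatherUtils API. VectorSource, ConstSourceModel and conditionalModel are invented names,
    values are std::string, and "short" sources are not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for a full (non-constant) string column source.
struct VectorSource
{
    const std::vector<std::string> & rows;
    std::size_t pos = 0;

    const std::string & current() const { return rows[pos]; }
    void next() { ++pos; }
};

// Hypothetical stand-in for a constant source: one value for every row.
struct ConstSourceModel
{
    std::string value;

    const std::string & current() const { return value; }
    void next() {} // a constant never advances
};

// Walk both sources in lockstep and append the selected value to the sink.
template <typename SourceA, typename SourceB>
std::vector<std::string> conditionalModel(SourceA then_source, SourceB else_source, const std::vector<std::uint8_t> & cond)
{
    std::vector<std::string> sink;
    sink.reserve(cond.size());
    for (std::uint8_t c : cond)
    {
        sink.push_back(c ? then_source.current() : else_source.current());
        then_source.next();
        else_source.next();
    }
    return sink;
}
```

Because both source kinds expose the same current()/next() shape, a single templated merge loop
serves every combination. Only the dispatch on concrete column types has to be spelled out.
*/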
            auto col_res = ColumnString::create();
            auto sink = StringSink(*col_res, rows);

            if (col_then && col_else)
                conditional(StringSource(*col_then), StringSource(*col_else), sink, cond_data);
            else if (col_then && col_else_const)
                conditional(StringSource(*col_then), ConstSource<StringSource>(*col_else_const), sink, cond_data);
            else if (col_then_const && col_else)
                conditional(ConstSource<StringSource>(*col_then_const), StringSource(*col_else), sink, cond_data);
            else if (col_then_const && col_else_const)
                conditional(ConstSource<StringSource>(*col_then_const), ConstSource<StringSource>(*col_else_const), sink, cond_data);
            else if (col_then && col_else_fixed)
                conditional(StringSource(*col_then), FixedStringSource(*col_else_fixed), sink, cond_data);
            else if (col_then_fixed && col_else)
                conditional(FixedStringSource(*col_then_fixed), StringSource(*col_else), sink, cond_data);
            else if (col_then_const && col_else_fixed)
                conditional(ConstSource<StringSource>(*col_then_const), FixedStringSource(*col_else_fixed), sink, cond_data);
            else if (col_then_fixed && col_else_const)
                conditional(FixedStringSource(*col_then_fixed), ConstSource<StringSource>(*col_else_const), sink, cond_data);
            else if (col_then && col_else_const_fixed)
                conditional(StringSource(*col_then), ConstSource<FixedStringSource>(*col_else_const_fixed), sink, cond_data);
            else if (col_then_const_fixed && col_else)
                conditional(ConstSource<FixedStringSource>(*col_then_const_fixed), StringSource(*col_else), sink, cond_data);
            else if (col_then_const && col_else_const_fixed)
                conditional(ConstSource<StringSource>(*col_then_const), ConstSource<FixedStringSource>(*col_else_const_fixed), sink, cond_data);
            else if (col_then_const_fixed && col_else_const)
                conditional(ConstSource<FixedStringSource>(*col_then_const_fixed), ConstSource<StringSource>(*col_else_const), sink, cond_data);
            else if (col_then_fixed && col_else_fixed)
                conditional(FixedStringSource(*col_then_fixed), FixedStringSource(*col_else_fixed), sink, cond_data);
            else if (col_then_fixed && col_else_const_fixed)
                conditional(FixedStringSource(*col_then_fixed), ConstSource<FixedStringSource>(*col_else_const_fixed), sink, cond_data);
            else if (col_then_const_fixed && col_else_fixed)
                conditional(ConstSource<FixedStringSource>(*col_then_const_fixed), FixedStringSource(*col_else_fixed), sink, cond_data);
            else if (col_then_const_fixed && col_else_const_fixed)
                conditional(ConstSource<FixedStringSource>(*col_then_const_fixed), ConstSource<FixedStringSource>(*col_else_const_fixed), sink, cond_data);
            else
                return nullptr;

            return col_res;
        }
        return nullptr;
    }

    static ColumnPtr executeGenericArray(const ColumnUInt8 * cond_col, const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type)
    {
        /// For the generic implementation, both arrays must be of the same type.
        if (!arguments[1].type->equals(*arguments[2].type))
            return nullptr;

        const IColumn * col_then_untyped = arguments[1].column.get();
        const IColumn * col_else_untyped = arguments[2].column.get();

        const ColumnArray * col_arr_then = checkAndGetColumn<ColumnArray>(col_then_untyped);
        const ColumnArray * col_arr_else = checkAndGetColumn<ColumnArray>(col_else_untyped);
        const ColumnConst * col_arr_then_const = checkAndGetColumnConst<ColumnArray>(col_then_untyped);
        const ColumnConst * col_arr_else_const = checkAndGetColumnConst<ColumnArray>(col_else_untyped);

        const PaddedPODArray<UInt8> & cond_data = cond_col->getData();
        size_t rows = cond_data.size();

        if ((col_arr_then || col_arr_then_const) && (col_arr_else || col_arr_else_const))
        {
            auto res = result_type->createColumn();
            auto * col_res = assert_cast<ColumnArray *>(res.get());

            if (col_arr_then && col_arr_else)
                conditional(GenericArraySource(*col_arr_then), GenericArraySource(*col_arr_else), GenericArraySink(col_res->getData(), col_res->getOffsets(), rows), cond_data);
            else if (col_arr_then && col_arr_else_const)
                conditional(GenericArraySource(*col_arr_then), ConstSource<GenericArraySource>(*col_arr_else_const), GenericArraySink(col_res->getData(), col_res->getOffsets(), rows), cond_data);
            else if (col_arr_then_const && col_arr_else)
                conditional(ConstSource<GenericArraySource>(*col_arr_then_const), GenericArraySource(*col_arr_else), GenericArraySink(col_res->getData(), col_res->getOffsets(), rows), cond_data);
            else if (col_arr_then_const && col_arr_else_const)
                conditional(ConstSource<GenericArraySource>(*col_arr_then_const),
                    ConstSource<GenericArraySource>(*col_arr_else_const), GenericArraySink(col_res->getData(), col_res->getOffsets(), rows), cond_data);
            else
                return nullptr;

            return res;
        }

        return nullptr;
    }

    ColumnPtr executeTuple(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
    {
        /// Calculate the function for each pair of corresponding tuple elements.
        const ColumnWithTypeAndName & arg1 = arguments[1];
        const ColumnWithTypeAndName & arg2 = arguments[2];

        Columns col1_contents;
        Columns col2_contents;

        if (const ColumnTuple * tuple1 = typeid_cast<const ColumnTuple *>(arg1.column.get()))
            col1_contents = tuple1->getColumnsCopy();
        else if (const ColumnConst * const_tuple = checkAndGetColumnConst<ColumnTuple>(arg1.column.get()))
            col1_contents = convertConstTupleToConstantElements(*const_tuple);
        else
            return nullptr;

        if (const ColumnTuple * tuple2 = typeid_cast<const ColumnTuple *>(arg2.column.get()))
            col2_contents = tuple2->getColumnsCopy();
        else if (const ColumnConst * const_tuple = checkAndGetColumnConst<ColumnTuple>(arg2.column.get()))
            col2_contents = convertConstTupleToConstantElements(*const_tuple);
        else
            return nullptr;

        const DataTypeTuple & type1 = static_cast<const DataTypeTuple &>(*arg1.type);
        const DataTypeTuple & type2 = static_cast<const DataTypeTuple &>(*arg2.type);
        const DataTypeTuple & tuple_result = static_cast<const DataTypeTuple &>(*result_type);

        ColumnsWithTypeAndName temporary_columns(3);
        temporary_columns[0] = arguments[0];

        size_t tuple_size = type1.getElements().size();
        Columns tuple_columns(tuple_size);

        for (size_t i = 0; i < tuple_size; ++i)
        {
            temporary_columns[1] = {col1_contents[i], type1.getElements()[i], {}};
            temporary_columns[2] = {col2_contents[i], type2.getElements()[i], {}};

            tuple_columns[i] = executeImpl(temporary_columns, tuple_result.getElements()[i], input_rows_count);
        }

        return ColumnTuple::create(tuple_columns);
    }

    ColumnPtr executeMap(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
    {
        auto extract_kv_from_map = [](const ColumnMap * map)
        {
            const ColumnTuple & tuple = map->getNestedData();
            const auto & keys =
                tuple.getColumnPtr(0);
            const auto & values = tuple.getColumnPtr(1);
            const auto & offsets = map->getNestedColumn().getOffsetsPtr();
            return std::make_pair(ColumnArray::create(keys, offsets), ColumnArray::create(values, offsets));
        };

        /// Extract keys and values from both arguments
        Columns key_cols(2);
        Columns value_cols(2);
        for (size_t i = 0; i < 2; ++i)
        {
            const auto & arg = arguments[i + 1];
            if (const ColumnMap * map = checkAndGetColumn<ColumnMap>(arg.column.get()))
            {
                auto [key_col, value_col] = extract_kv_from_map(map);
                key_cols[i] = std::move(key_col);
                value_cols[i] = std::move(value_col);
            }
            else if (const ColumnConst * const_map = checkAndGetColumnConst<ColumnMap>(arg.column.get()))
            {
                const ColumnMap * map_data = assert_cast<const ColumnMap *>(&const_map->getDataColumn());
                auto [key_col, value_col] = extract_kv_from_map(map_data);

                size_t size = const_map->size();
                key_cols[i] = ColumnConst::create(std::move(key_col), size);
                value_cols[i] = ColumnConst::create(std::move(value_col), size);
            }
            else
                return nullptr;
        }

        /// Compose temporary columns for keys and values
        ColumnsWithTypeAndName key_columns(3);
        key_columns[0] = arguments[0];
        ColumnsWithTypeAndName value_columns(3);
        value_columns[0] = arguments[0];
        for (size_t i = 0; i < 2; ++i)
        {
            const auto & arg = arguments[i + 1];
            const DataTypeMap & type = static_cast<const DataTypeMap &>(*arg.type);
            const auto & key_type = type.getKeyType();
            const auto & value_type = type.getValueType();
            key_columns[i + 1] = {key_cols[i], std::make_shared<DataTypeArray>(key_type), {}};
            value_columns[i + 1] = {value_cols[i], std::make_shared<DataTypeArray>(value_type), {}};
        }

        /// Calculate the function separately over the key arrays and over the value arrays
        const DataTypeMap & map_result_type = static_cast<const DataTypeMap &>(*result_type);
        auto key_result_type = std::make_shared<DataTypeArray>(map_result_type.getKeyType());
        auto value_result_type = std::make_shared<DataTypeArray>(map_result_type.getValueType());

        ColumnPtr key_result = executeImpl(key_columns, key_result_type, input_rows_count);
        ColumnPtr value_result = executeImpl(value_columns, value_result_type, input_rows_count);

        /// key_result and
        /// value_result are not constant columns, otherwise we would not reach this point in executeMap
        const auto * key_array = checkAndGetColumn<ColumnArray>(key_result.get());
        const auto * value_array = checkAndGetColumn<ColumnArray>(value_result.get());
        if (!key_array)
            throw Exception(
                ErrorCodes::LOGICAL_ERROR,
                "Key result column should be {} instead of {} in executeMap of function {}",
                key_result_type->getName(),
                key_result->getName(),
                getName());
        if (!value_array)
            throw Exception(
                ErrorCodes::LOGICAL_ERROR,
                "Value result column should be {} instead of {} in executeMap of function {}",
                value_result_type->getName(),
                value_result->getName(),
                getName());
        if (!key_array->hasEqualOffsets(*value_array))
            throw Exception(
                ErrorCodes::LOGICAL_ERROR, "Key array and value array must have equal sizes in executeMap of function {}", getName());

        auto nested_column = ColumnArray::create(
            ColumnTuple::create(Columns{key_array->getDataPtr(), value_array->getDataPtr()}), key_array->getOffsetsPtr());
        return ColumnMap::create(std::move(nested_column));
    }

    static ColumnPtr executeGeneric(
        const ColumnUInt8 * cond_col, const ColumnsWithTypeAndName & arguments, size_t input_rows_count, bool use_variant_when_no_common_type)
    {
        /// Convert both columns to the common type (if needed).
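/** Illustration only, not used by FunctionIf. After both branches are cast to the common type, the
    loops below copy rows one by one with insertFrom. A branch column can be shorter than the
    condition, which presumably happens when that branch was evaluated only for the rows that
    select it. Such a column is read with a running index instead of the row number. The sketch
    models that gather with std::vector<int>. gatherIf is a hypothetical name, and constant
    branches (read at index 0) are not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// A "short" branch column holds only the rows where that branch is chosen,
// so its next value is taken with a running index; a full column is
// indexed by the row number.
std::vector<int> gatherIf(const std::vector<std::uint8_t> & cond,
                          const std::vector<int> & then_col,
                          const std::vector<int> & else_col)
{
    const bool then_is_short = then_col.size() < cond.size();
    const bool else_is_short = else_col.size() < cond.size();
    std::size_t then_index = 0;
    std::size_t else_index = 0;

    std::vector<int> result;
    result.reserve(cond.size());
    for (std::size_t i = 0; i < cond.size(); ++i)
    {
        if (cond[i])
            result.push_back(then_col[then_is_short ? then_index++ : i]);
        else
            result.push_back(else_col[else_is_short ? else_index++ : i]);
    }
    return result;
}
```

For example, with cond = {1, 0, 0, 1}, a full then-column {10, 11, 12, 13} and a short
else-column {20, 21}, the result is {10, 20, 21, 13}.
*/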
        const ColumnWithTypeAndName & arg1 = arguments[1];
        const ColumnWithTypeAndName & arg2 = arguments[2];
        DataTypePtr common_type;
        if (use_variant_when_no_common_type)
            common_type = getLeastSupertypeOrVariant(DataTypes{arg1.type, arg2.type});
        else
            common_type = getLeastSupertype(DataTypes{arg1.type, arg2.type});

        ColumnPtr col_then = castColumn(arg1, common_type);
        ColumnPtr col_else = castColumn(arg2, common_type);

        MutableColumnPtr result_column = common_type->createColumn();
        result_column->reserve(input_rows_count);

        bool then_is_const = isColumnConst(*col_then);
        bool else_is_const = isColumnConst(*col_else);

        bool then_is_short = col_then->size() < cond_col->size();
        bool else_is_short = col_else->size() < cond_col->size();

        const auto & cond_array = cond_col->getData();

        if (then_is_const && else_is_const)
        {
            const IColumn & then_nested_column = assert_cast<const ColumnConst &>(*col_then).getDataColumn();
            const IColumn & else_nested_column = assert_cast<const ColumnConst &>(*col_else).getDataColumn();
            for (size_t i = 0; i < input_rows_count; ++i)
            {
                if (cond_array[i])
                    result_column->insertFrom(then_nested_column, 0);
                else
                    result_column->insertFrom(else_nested_column, 0);
            }
        }
        else if (then_is_const)
        {
            const IColumn & then_nested_column = assert_cast<const ColumnConst &>(*col_then).getDataColumn();

            size_t else_index = 0;
            for (size_t i = 0; i < input_rows_count; ++i)
            {
                if (cond_array[i])
                    result_column->insertFrom(then_nested_column, 0);
                else
                    result_column->insertFrom(*col_else, else_is_short ? else_index++ : i);
            }
        }
        else if (else_is_const)
        {
            const IColumn & else_nested_column = assert_cast<const ColumnConst &>(*col_else).getDataColumn();

            size_t then_index = 0;
            for (size_t i = 0; i < input_rows_count; ++i)
            {
                if (cond_array[i])
                    result_column->insertFrom(*col_then, then_is_short ? then_index++ : i);
                else
                    result_column->insertFrom(else_nested_column, 0);
            }
        }
        else
        {
            size_t then_index = 0, else_index = 0;
            for (size_t i = 0; i < input_rows_count; ++i)
            {
                if (cond_array[i])
                    result_column->insertFrom(*col_then, then_is_short ?
                        then_index++ : i);
                else
                    result_column->insertFrom(*col_else, else_is_short ? else_index++ : i);
            }
        }

        return result_column;
    }

    ColumnPtr executeForConstAndNullableCondition(
        const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /*input_rows_count*/) const
    {
        const ColumnWithTypeAndName & arg_cond = arguments[0];
        bool cond_is_null = arg_cond.column->onlyNull();

        ColumnPtr not_const_condition = arg_cond.column;
        bool cond_is_const = false;
        bool cond_is_true = false;
        bool cond_is_false = false;
        if (const auto * const_arg = checkAndGetColumn<ColumnConst>(*arg_cond.column))
        {
            cond_is_const = true;
            not_const_condition = const_arg->getDataColumnPtr();
            ColumnPtr data_column = const_arg->getDataColumnPtr();
            if (const auto * const_nullable_arg = checkAndGetColumn<ColumnNullable>(*data_column))
            {
                data_column = const_nullable_arg->getNestedColumnPtr();
                if (!data_column->empty())
                    cond_is_null = const_nullable_arg->getNullMapData()[0];
            }

            if (!data_column->empty())
            {
                cond_is_true = !cond_is_null && checkAndGetColumn<ColumnUInt8>(*data_column)->getBool(0);
                cond_is_false = !cond_is_null && !cond_is_true;
            }
        }

        const auto & column1 = arguments[1];
        const auto & column2 = arguments[2];

        if (cond_is_true)
            return castColumn(column1, result_type);
        else if (cond_is_false || cond_is_null)
            return castColumn(column2, result_type);

        if (const auto * nullable = checkAndGetColumn<ColumnNullable>(*not_const_condition))
        {
            ColumnPtr new_cond_column = nullable->getNestedColumnPtr();
            size_t column_size = arg_cond.column->size();

            if (checkAndGetColumn<ColumnUInt8>(*new_cond_column))
            {
                auto nested_column_copy = new_cond_column->cloneResized(new_cond_column->size());
                typeid_cast<ColumnUInt8 *>(nested_column_copy.get())->applyZeroMap(nullable->getNullMapData());
                new_cond_column = std::move(nested_column_copy);

                if (cond_is_const)
                    new_cond_column = ColumnConst::create(new_cond_column, column_size);
            }
            else
                throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of {} condition", arg_cond.column->getName(), getName());

            ColumnsWithTypeAndName temporary_columns
            {
                {
                    new_cond_column,
                    removeNullable(arg_cond.type),
                    arg_cond.name
                },
                column1,
                column2,
            };

            return executeImpl(temporary_columns, result_type, new_cond_column->size());
        }
        return nullptr;
    }

    template <typename AnyColumnPtr>
    static ColumnPtr materializeColumnIfConst(const AnyColumnPtr & column)
    {
        return column->convertToFullColumnIfConst();
    }

    static ColumnPtr makeNullableColumnIfNot(const ColumnPtr & column)
    {
        auto materialized = materializeColumnIfConst(column);

        if (isColumnNullable(*materialized))
            return materialized;

        return ColumnNullable::create(materialized, ColumnUInt8::create(column->size(), 0));
    }

    /// Return nested column recursively removing Nullable, examples:
    /// Nullable(size = 1, Int32(size = 1), UInt8(size = 1)) -> Int32(size = 1)
    /// Const(size = 0, Nullable(size = 1, Int32(size = 1), UInt8(size = 1))) ->
    ///     Const(size = 0, Int32(size = 1))
    static ColumnPtr recursiveGetNestedColumnWithoutNullable(const ColumnPtr & column)
    {
        if (const auto * nullable = checkAndGetColumn<ColumnNullable>(*column))
        {
            /// Nullable cannot contain Nullable
            return nullable->getNestedColumnPtr();
        }
        else if (const auto * column_const = checkAndGetColumn<ColumnConst>(*column))
        {
            /// Save Constant, but remove Nullable
            return ColumnConst::create(recursiveGetNestedColumnWithoutNullable(column_const->getDataColumnPtr()), column->size());
        }

        return column;
    }

    ColumnPtr executeForNullableThenElse(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
    {
        /// If result type is Variant, we don't need to remove Nullable.
        if (isVariant(result_type))
            return nullptr;

        const ColumnWithTypeAndName & arg_cond = arguments[0];
        const ColumnWithTypeAndName & arg_then = arguments[1];
        const ColumnWithTypeAndName & arg_else = arguments[2];

        const auto * then_is_nullable = checkAndGetColumn<ColumnNullable>(*arg_then.column);
        const auto * else_is_nullable = checkAndGetColumn<ColumnNullable>(*arg_else.column);

        if (!then_is_nullable && !else_is_nullable)
            return nullptr;

        /** Calculate null mask of result and nested column separately.
          */
        ColumnPtr result_null_mask;

        {
            ColumnsWithTypeAndName temporary_columns(
            {
                arg_cond,
                {
                    then_is_nullable
                        ? then_is_nullable->getNullMapColumnPtr()
                        : DataTypeUInt8().createColumnConstWithDefaultValue(input_rows_count),
                    std::make_shared<DataTypeUInt8>(),
                    ""
                },
                {
                    else_is_nullable
                        ? else_is_nullable->getNullMapColumnPtr()
                        : DataTypeUInt8().createColumnConstWithDefaultValue(input_rows_count),
                    std::make_shared<DataTypeUInt8>(),
                    ""
                }
            });

            result_null_mask = executeImpl(temporary_columns, std::make_shared<DataTypeUInt8>(), input_rows_count);
        }

        ColumnPtr result_nested_column;

        {
            ColumnsWithTypeAndName temporary_columns(
            {
                arg_cond,
                {
                    recursiveGetNestedColumnWithoutNullable(arg_then.column),
                    removeNullable(arg_then.type),
                    ""
                },
                {
                    recursiveGetNestedColumnWithoutNullable(arg_else.column),
                    removeNullable(arg_else.type),
                    ""
                }
            });

            result_nested_column = executeImpl(temporary_columns, removeNullable(result_type), temporary_columns.front().column->size());
        }

        return ColumnNullable::create(
            materializeColumnIfConst(result_nested_column), materializeColumnIfConst(result_null_mask));
    }

    ColumnPtr executeForNullThenElse(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
    {
        const ColumnWithTypeAndName & arg_cond = arguments[0];
        const ColumnWithTypeAndName & arg_then = arguments[1];
        const ColumnWithTypeAndName & arg_else = arguments[2];

        bool then_is_null = arg_then.column->onlyNull();
        bool else_is_null = arg_else.column->onlyNull();

        if (!then_is_null && !else_is_null)
            return nullptr;

        if (then_is_null && else_is_null)
            return result_type->createColumnConstWithDefaultValue(input_rows_count);

        bool then_is_short = arg_then.column->size() < arg_cond.column->size();
        bool else_is_short = arg_else.column->size() < arg_cond.column->size();

        const ColumnUInt8 * cond_col = typeid_cast<const ColumnUInt8 *>(arg_cond.column.get());
        const ColumnConst * cond_const_col = checkAndGetColumnConst<ColumnVector<UInt8>>(arg_cond.column.get());

        /// If `then` is NULL, build a Nullable column from `else` whose null map is OR-ed with the condition.
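        /// Example with a full condition column:
        ///   if([1, 0, 1], NULL, [a, b, c]) -> [NULL, b, NULL], null map = [1, 0, 1]
        /// If `else` is already Nullable, the condition is OR-ed into its existing null map.
        /// With a constant condition the result is all NULLs (true) or `else` passed through
        /// makeNullableColumnIfNot() (false).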
        if (then_is_null)
        {
            ColumnPtr arg_else_column;
            /// If the type of arg_else differs from the result type (ignoring Nullable),
            /// cast it to the result type.
            if (removeNullable(arg_else.type)->getName() != removeNullable(result_type)->getName())
                arg_else_column = castColumn(arg_else, result_type);
            else
                arg_else_column = arg_else.column;

            if (cond_col)
            {
                arg_else_column = arg_else_column->convertToFullColumnIfConst();
                auto result_column = IColumn::mutate(std::move(arg_else_column));
                if (else_is_short)
                    result_column->expand(cond_col->getData(), true);
                if (isColumnNullable(*result_column))
                {
                    assert_cast<ColumnNullable &>(*result_column).applyNullMap(assert_cast<const ColumnUInt8 &>(*arg_cond.column));
                    return result_column;
                }
                else if (auto * variant_column = typeid_cast<ColumnVariant *>(result_column.get()))
                {
                    variant_column->applyNullMap(assert_cast<const ColumnUInt8 &>(*arg_cond.column).getData());
                    return result_column;
                }
                else
                    return ColumnNullable::create(materializeColumnIfConst(result_column), arg_cond.column);
            }
            else if (cond_const_col)
            {
                if (cond_const_col->getValue<UInt8>())
                    return result_type->createColumn()->cloneResized(input_rows_count);
                else
                    return makeNullableColumnIfNot(arg_else_column);
            }
            else
                throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}. "
                    "Must be ColumnUInt8 or ColumnConstUInt8.", arg_cond.column->getName(), getName());
        }

        /// If `else` is NULL, build a Nullable column from `then` whose null map is OR-ed with the negated condition.
        if (else_is_null)
        {
            ColumnPtr arg_then_column;
            /// If the type of arg_then differs from the result type (ignoring Nullable),
            /// cast it to the result type.
            if (removeNullable(arg_then.type)->getName() != removeNullable(result_type)->getName())
                arg_then_column = castColumn(arg_then, result_type);
            else
                arg_then_column = arg_then.column;

            if (cond_col)
            {
                arg_then_column = arg_then_column->convertToFullColumnIfConst();
                auto result_column = IColumn::mutate(std::move(arg_then_column));
                if (then_is_short)
                    result_column->expand(cond_col->getData(), false);
                if (isColumnNullable(*result_column))
                {
                    assert_cast<ColumnNullable &>(*result_column).applyNegatedNullMap(assert_cast<const ColumnUInt8 &>(*arg_cond.column));
                    return result_column;
                }
                else if (auto * variant_column = typeid_cast<ColumnVariant *>(result_column.get()))
                {
                    variant_column->applyNegatedNullMap(assert_cast<const ColumnUInt8 &>(*arg_cond.column).getData());
                    return result_column;
                }
                else
                {
                    size_t size = input_rows_count;
                    const auto & null_map_data = cond_col->getData();

                    auto negated_null_map = ColumnUInt8::create();
                    auto & negated_null_map_data = negated_null_map->getData();
                    negated_null_map_data.resize(size);

                    for (size_t i = 0; i < size; ++i)
                        negated_null_map_data[i] = !null_map_data[i];

                    return ColumnNullable::create(materializeColumnIfConst(result_column), std::move(negated_null_map));
                }
            }
            else if (cond_const_col)
            {
                if (cond_const_col->getValue<UInt8>())
                    return makeNullableColumnIfNot(arg_then_column);
                else
                    return result_type->createColumn()->cloneResized(input_rows_count);
            }
            else
                throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}. "
                    "Must be ColumnUInt8 or ColumnConstUInt8.", arg_cond.column->getName(), getName());
        }

        return nullptr;
    }

    static void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments)
    {
        int last_short_circuit_argument_index = checkShortCircuitArguments(arguments);
        if (last_short_circuit_argument_index == -1)
            return;

        executeColumnIfNeeded(arguments[0]);

        /// If the condition is const or NULL, choose the branch to execute from its single value
        /// instead of building a full mask from it.
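        /// Otherwise a filter mask is built from the condition: a lazily executed `then` is
        /// computed only for rows where the mask is 1 and, after inverseMask(), a lazily executed
        /// `else` only for the remaining rows. For example, cond = [1, 0, 1] evaluates `then` on
        /// rows 0 and 2 and `else` on row 1, which is why the branches may reach executeImpl()
        /// with fewer rows than the condition (see then_is_short / else_is_short).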
        if ((isColumnConst(*arguments[0].column) || arguments[0].column->onlyNull()) && !arguments[0].column->empty())
        {
            bool value = arguments[0].column->getBool(0);
            executeColumnIfNeeded(arguments[1], !value);
            executeColumnIfNeeded(arguments[2], value);
            return;
        }

        IColumn::Filter mask(arguments[0].column->size(), 1);
        auto mask_info = extractMask(mask, arguments[0].column);
        maskedExecute(arguments[1], mask, mask_info);

        inverseMask(mask, mask_info);
        maskedExecute(arguments[2], mask, mask_info);
    }

public:
    String getName() const override { return name; }
    size_t getNumberOfArguments() const override { return 3; }
    bool useDefaultImplementationForNulls() const override { return false; }
    bool useDefaultImplementationForNothing() const override { return false; }
    bool isShortCircuit(ShortCircuitSettings & settings, size_t /*number_of_arguments*/) const override
    {
        settings.arguments_with_disabled_lazy_execution.insert(0);
        settings.enable_lazy_execution_for_common_descendants_of_arguments = false;
        settings.force_enable_lazy_execution = false;
        return true;
    }
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
    ColumnNumbers getArgumentsThatDontImplyNullableReturnType(size_t /*number_of_arguments*/) const override { return {0}; }
    bool canBeExecutedOnLowCardinalityDictionary() const override { return false; }

    /// Get the result type from the argument types. Throw an exception if the function does not apply to these arguments.
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
    {
        if (arguments[0]->onlyNull())
            return arguments[2];

        if (arguments[0]->isNullable())
            return getReturnTypeImpl({removeNullable(arguments[0]), arguments[1], arguments[2]});

        if (!WhichDataType(arguments[0]).isUInt8())
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of first argument (condition) of function if. "
                "Must be UInt8.", arguments[0]->getName());

        if (use_variant_when_no_common_type)
            return getLeastSupertypeOrVariant(DataTypes{arguments[1], arguments[2]});

        return getLeastSupertype(DataTypes{arguments[1], arguments[2]});
    }

    ColumnPtr executeImpl(const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count) const override
    {
        ColumnsWithTypeAndName arguments = args;
        executeShortCircuitArguments(arguments);
        ColumnPtr res;
        if ((res = executeForConstAndNullableCondition(arguments, result_type, input_rows_count))
            || (res = executeForNullThenElse(arguments, result_type, input_rows_count))
            || (res = executeForNullableThenElse(arguments, result_type, input_rows_count)))
            return res;

        const ColumnWithTypeAndName & arg_cond = arguments[0];
        const ColumnWithTypeAndName & arg_then = arguments[1];
        const ColumnWithTypeAndName & arg_else = arguments[2];

        /// `then` and `else` are the same column (identical pointers).
        if (arg_then.column.get() == arg_else.column.get())
        {
            /// Just return it as the result.
            return arg_then.column;
        }

        const ColumnUInt8 * cond_col = typeid_cast<const ColumnUInt8 *>(arg_cond.column.get());
        const ColumnConst * cond_const_col = checkAndGetColumnConst<ColumnVector<UInt8>>(arg_cond.column.get());
        ColumnPtr materialized_cond_col;

        if (cond_const_col)
        {
            UInt8 value = cond_const_col->getValue<UInt8>();
            const ColumnWithTypeAndName & arg = value ? arg_then : arg_else;
            if (arg.type->equals(*result_type))
                return arg.column;
            else
                return castColumn(arg, result_type);
        }

        if (!cond_col)
            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}. "
                "Must be ColumnUInt8 or ColumnConstUInt8.", arg_cond.column->getName(), getName());

        auto call = [&](const auto & types) -> bool
        {
            using Types = std::decay_t<decltype(types)>;
            using T0 = typename Types::LeftType;
            using T1 = typename Types::RightType;

            res = executeTyped<T0, T1>(cond_col, arguments, result_type, input_rows_count);
            return res != nullptr;
        };

        DataTypePtr left_type = arg_then.type;
        DataTypePtr right_type = arg_else.type;

        if (const auto * left_array = checkAndGetDataType<DataTypeArray>(arg_then.type.get()))
            left_type = left_array->getNestedType();

        if (const auto * right_array = checkAndGetDataType<DataTypeArray>(arg_else.type.get()))
            right_type = right_array->getNestedType();

        /// Special case when one column is Integer and the other is UInt64 that can actually be Int64.
        /// The result type for this case is Int64, so we change the UInt64 type to Int64
        /// so that NumberTraits::ResultOfIf returns Int64 instead of Int128.
        if (isNativeInteger(left_type) && isUInt64ThatCanBeInt64(right_type))
            right_type = std::make_shared<DataTypeInt64>();
        else if (isNativeInteger(right_type) && isUInt64ThatCanBeInt64(left_type))
            left_type = std::make_shared<DataTypeInt64>();

        TypeIndex left_id = left_type->getTypeId();
        TypeIndex right_id = right_type->getTypeId();

        /// TODO optimize for map type
        /// TODO optimize for nullable type
        if (!(callOnBasicTypes(left_id, right_id, call)
            || (res = executeTyped(cond_col, arguments, result_type, input_rows_count))
            || (res = executeString(cond_col, arguments, result_type))
            || (res = executeGenericArray(cond_col, arguments, result_type))
            || (res = executeTuple(arguments, result_type, input_rows_count))
            || (res = executeMap(arguments, result_type, input_rows_count))))
        {
            return executeGeneric(cond_col, arguments, input_rows_count, use_variant_when_no_common_type);
        }

        return res;
    }

    ColumnPtr getConstantResultForNonConstArguments(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type) const override
    {
        const ColumnWithTypeAndName & arg_cond = arguments[0];
        if (!arg_cond.column ||
            !isColumnConst(*arg_cond.column))
            return {};

        const ColumnConst * cond_const_col = checkAndGetColumnConst<ColumnVector<UInt8>>(arg_cond.column.get());
        if (!cond_const_col)
            return {};

        bool condition_value = cond_const_col->getValue<UInt8>();

        const ColumnWithTypeAndName & arg_then = arguments[1];
        const ColumnWithTypeAndName & arg_else = arguments[2];
        const ColumnWithTypeAndName & potential_const_column = condition_value ? arg_then : arg_else;

        if (!potential_const_column.column || !isColumnConst(*potential_const_column.column))
            return {};

        auto result = castColumn(potential_const_column, result_type);
        if (!isColumnConst(*result))
            return {};

        return result;
    }
};

}

REGISTER_FUNCTION(If)
{
    factory.registerFunction({}, FunctionFactory::CaseInsensitive);
}

FunctionOverloadResolverPtr createInternalFunctionIfOverloadResolver(bool allow_experimental_variant_type, bool use_variant_as_common_type)
{
    return std::make_unique(std::make_shared(allow_experimental_variant_type && use_variant_as_common_type));
}

}