#include #include #include #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } /// Function multiIf, which generalizes the function if. /// /// Syntax: multiIf(cond_1, then_1, ..., cond_N, then_N, else) /// where N >= 1. /// /// For all 1 <= i <= N, "cond_i" has type UInt8. /// Types of all the branches "then_i" and "else" are either of the following: /// - numeric types for which there exists a common type; /// - dates; /// - dates with time; /// - strings; /// - arrays of such types. /// /// Additionally the arguments, conditions or branches, support nullable types /// and the NULL value, with a NULL condition treated as false. class FunctionMultiIf final : public FunctionIfBase { public: static constexpr auto name = "multiIf"; static FunctionPtr create(const Context & context) { return std::make_shared(context); } FunctionMultiIf(const Context & context) : context(context) {} public: String getName() const override { return name; } bool isVariadic() const override { return true; } size_t getNumberOfArguments() const override { return 0; } bool useDefaultImplementationForNulls() const override { return false; } DataTypePtr getReturnTypeImpl(const DataTypes & args) const override { /// Arguments are the following: cond1, then1, cond2, then2, ... condN, thenN, else. auto for_conditions = [&args](auto && f) { size_t conditions_end = args.size() - 1; for (size_t i = 0; i < conditions_end; i += 2) f(args[i]); }; auto for_branches = [&args](auto && f) { size_t branches_end = args.size(); for (size_t i = 1; i < branches_end; i += 2) f(args[i]); f(args.back()); }; if (!(args.size() >= 3 && args.size() % 2 == 1)) throw Exception{"Invalid number of arguments for function " + getName(), ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH}; /// Conditions must be UInt8, Nullable(UInt8) or Null. If one of conditions is Nullable, the result is also Nullable. bool have_nullable_condition = false; for_conditions([&](const DataTypePtr & arg) { const IDataType * nested_type; if (arg->isNullable()) { have_nullable_condition = true; if (arg->onlyNull()) return; const DataTypeNullable & nullable_type = static_cast(*arg); nested_type = nullable_type.getNestedType().get(); } else { nested_type = arg.get(); } if (!WhichDataType(nested_type).isUInt8()) throw Exception{"Illegal type " + arg->getName() + " of argument (condition) " "of function " + getName() + ". Must be UInt8.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; }); DataTypes types_of_branches; types_of_branches.reserve(args.size() / 2 + 1); for_branches([&](const DataTypePtr & arg) { types_of_branches.emplace_back(arg); }); DataTypePtr common_type_of_branches = getLeastSupertype(types_of_branches); return have_nullable_condition ? makeNullable(common_type_of_branches) : common_type_of_branches; } void executeImpl(Block & block, const ColumnNumbers & args, size_t result, size_t input_rows_count) override { /** We will gather values from columns in branches to result column, * depending on values of conditions. */ struct Instruction { const IColumn * condition = nullptr; const IColumn * source = nullptr; bool condition_always_true = false; bool condition_is_nullable = false; bool source_is_constant = false; }; std::vector instructions; instructions.reserve(args.size() / 2 + 1); Columns converted_columns_holder; converted_columns_holder.reserve(instructions.size()); const DataTypePtr & return_type = block.getByPosition(result).type; for (size_t i = 0; i < args.size(); i += 2) { Instruction instruction; size_t source_idx = i + 1; if (source_idx == args.size()) { /// The last, "else" branch can be treated as a branch with always true condition "else if (true)". --source_idx; instruction.condition_always_true = true; } else { const ColumnWithTypeAndName & cond_col = block.getByPosition(args[i]); /// We skip branches that are always false. /// If we encounter a branch that is always true, we can finish. if (cond_col.column->onlyNull()) continue; if (cond_col.column->isColumnConst()) { Field value = typeid_cast(*cond_col.column).getField(); if (value.isNull()) continue; if (value.get() == 0) continue; instruction.condition_always_true = true; } else { if (cond_col.column->isColumnNullable()) instruction.condition_is_nullable = true; instruction.condition = cond_col.column.get(); } } const ColumnWithTypeAndName & source_col = block.getByPosition(args[source_idx]); if (source_col.type->equals(*return_type)) { instruction.source = source_col.column.get(); } else { /// Cast all columns to result type. converted_columns_holder.emplace_back(castColumn(source_col, return_type, context)); instruction.source = converted_columns_holder.back().get(); } if (instruction.source && instruction.source->isColumnConst()) instruction.source_is_constant = true; instructions.emplace_back(std::move(instruction)); if (instructions.back().condition_always_true) break; } size_t rows = input_rows_count; MutableColumnPtr res = return_type->createColumn(); for (size_t i = 0; i < rows; ++i) { for (const auto & instruction : instructions) { bool insert = false; if (instruction.condition_always_true) insert = true; else if (!instruction.condition_is_nullable) insert = static_cast(*instruction.condition).getData()[i]; else { const ColumnNullable & condition_nullable = static_cast(*instruction.condition); const ColumnUInt8 & condition_nested = static_cast(condition_nullable.getNestedColumn()); const NullMap & condition_null_map = condition_nullable.getNullMapData(); insert = !condition_null_map[i] && condition_nested.getData()[i]; } if (insert) { if (!instruction.source_is_constant) res->insertFrom(*instruction.source, i); else res->insertFrom(static_cast(*instruction.source).getDataColumn(), 0); break; } } } block.getByPosition(result).column = std::move(res); } private: const Context & context; }; void registerFunctionMultiIf(FunctionFactory & factory) { factory.registerFunction(); /// These are obsolete function names. factory.registerFunction("caseWithoutExpr"); factory.registerFunction("caseWithoutExpression"); } }