#pragma once #include #include #include #include #include #include #include namespace DB { /** Функция выбора по условию: if(cond, then, else). * cond - UInt8 * then, else - либо числа/даты/даты-с-временем, либо строки. */ template struct NumIfImpl { private: static PODArray & result_vector(Block & block, size_t result, size_t size) { ColumnVector * col_res = new ColumnVector; block.getByPosition(result).column = col_res; typename ColumnVector::Container_t & vec_res = col_res->getData(); vec_res.resize(size); return vec_res; } public: static void vector_vector( const PODArray & cond, const PODArray & a, const PODArray & b, Block & block, size_t result) { size_t size = cond.size(); PODArray & res = result_vector(block, result, size); for (size_t i = 0; i < size; ++i) res[i] = cond[i] ? static_cast(a[i]) : static_cast(b[i]); } static void vector_constant( const PODArray & cond, const PODArray & a, B b, Block & block, size_t result) { size_t size = cond.size(); PODArray & res = result_vector(block, result, size); for (size_t i = 0; i < size; ++i) res[i] = cond[i] ? static_cast(a[i]) : static_cast(b); } static void constant_vector( const PODArray & cond, A a, const PODArray & b, Block & block, size_t result) { size_t size = cond.size(); PODArray & res = result_vector(block, result, size); for (size_t i = 0; i < size; ++i) res[i] = cond[i] ? static_cast(a) : static_cast(b[i]); } static void constant_constant( const PODArray & cond, A a, B b, Block & block, size_t result) { size_t size = cond.size(); PODArray & res = result_vector(block, result, size); for (size_t i = 0; i < size; ++i) res[i] = cond[i] ? static_cast(a) : static_cast(b); } }; template struct NumIfImpl { private: static void throw_error() { throw Exception("Internal logic error: invalid types of arguments 2 and 3 of if", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } public: static void vector_vector( const PODArray & cond, const PODArray & a, const PODArray & b, Block & block, size_t result) { throw_error(); } static void vector_constant( const PODArray & cond, const PODArray & a, B b, Block & block, size_t result) { throw_error(); } static void constant_vector( const PODArray & cond, A a, const PODArray & b, Block & block, size_t result) { throw_error(); } static void constant_constant( const PODArray & cond, A a, B b, Block & block, size_t result) { throw_error(); } }; struct StringIfImpl { static void vector_vector( const PODArray & cond, const ColumnString::Chars_t & a_data, const ColumnString::Offsets_t & a_offsets, const ColumnString::Chars_t & b_data, const ColumnString::Offsets_t & b_offsets, ColumnString::Chars_t & c_data, ColumnString::Offsets_t & c_offsets) { size_t size = cond.size(); c_offsets.resize(size); c_data.reserve(std::max(a_data.size(), b_data.size())); ColumnString::Offset_t a_prev_offset = 0; ColumnString::Offset_t b_prev_offset = 0; ColumnString::Offset_t c_prev_offset = 0; for (size_t i = 0; i < size; ++i) { if (cond[i]) { size_t size_to_write = a_offsets[i] - a_prev_offset; c_data.resize(c_data.size() + size_to_write); memcpy(&c_data[c_prev_offset], &a_data[a_prev_offset], size_to_write); c_prev_offset += size_to_write; c_offsets[i] = c_prev_offset; } else { size_t size_to_write = b_offsets[i] - b_prev_offset; c_data.resize(c_data.size() + size_to_write); memcpy(&c_data[c_prev_offset], &b_data[b_prev_offset], size_to_write); c_prev_offset += size_to_write; c_offsets[i] = c_prev_offset; } a_prev_offset = a_offsets[i]; b_prev_offset = b_offsets[i]; } } static void vector_constant( const PODArray & cond, const ColumnString::Chars_t & a_data, const ColumnString::Offsets_t & a_offsets, const String & b, ColumnString::Chars_t & c_data, ColumnString::Offsets_t & c_offsets) { size_t size = cond.size(); c_offsets.resize(size); c_data.reserve(a_data.size()); ColumnString::Offset_t a_prev_offset = 0; ColumnString::Offset_t c_prev_offset = 0; for (size_t i = 0; i < size; ++i) { if (cond[i]) { size_t size_to_write = a_offsets[i] - a_prev_offset; c_data.resize(c_data.size() + size_to_write); memcpy(&c_data[c_prev_offset], &a_data[a_prev_offset], size_to_write); c_prev_offset += size_to_write; c_offsets[i] = c_prev_offset; } else { size_t size_to_write = b.size() + 1; c_data.resize(c_data.size() + size_to_write); memcpy(&c_data[c_prev_offset], b.data(), size_to_write); c_prev_offset += size_to_write; c_offsets[i] = c_prev_offset; } a_prev_offset = a_offsets[i]; } } static void constant_vector( const PODArray & cond, const String & a, const ColumnString::Chars_t & b_data, const ColumnString::Offsets_t & b_offsets, ColumnString::Chars_t & c_data, ColumnString::Offsets_t & c_offsets) { size_t size = cond.size(); c_offsets.resize(size); c_data.reserve(b_data.size()); ColumnString::Offset_t b_prev_offset = 0; ColumnString::Offset_t c_prev_offset = 0; for (size_t i = 0; i < size; ++i) { if (cond[i]) { size_t size_to_write = a.size() + 1; c_data.resize(c_data.size() + size_to_write); memcpy(&c_data[c_prev_offset], a.data(), size_to_write); c_prev_offset += size_to_write; c_offsets[i] = c_prev_offset; } else { size_t size_to_write = b_offsets[i] - b_prev_offset; c_data.resize(c_data.size() + size_to_write); memcpy(&c_data[c_prev_offset], &b_data[b_prev_offset], size_to_write); c_prev_offset += size_to_write; c_offsets[i] = c_prev_offset; } b_prev_offset = b_offsets[i]; } } static void constant_constant( const PODArray & cond, const String & a, const String & b, ColumnString::Chars_t & c_data, ColumnString::Offsets_t & c_offsets) { size_t size = cond.size(); c_offsets.resize(size); c_data.reserve((std::max(a.size(), b.size()) + 1) * size); ColumnString::Offset_t c_prev_offset = 0; for (size_t i = 0; i < size; ++i) { if (cond[i]) { size_t size_to_write = a.size() + 1; c_data.resize(c_data.size() + size_to_write); memcpy(&c_data[c_prev_offset], a.data(), size_to_write); c_prev_offset += size_to_write; c_offsets[i] = c_prev_offset; } else { size_t size_to_write = b.size() + 1; c_data.resize(c_data.size() + size_to_write); memcpy(&c_data[c_prev_offset], b.data(), size_to_write); c_prev_offset += size_to_write; c_offsets[i] = c_prev_offset; } } } }; template struct DataTypeFromFieldTypeOrError { static DataTypePtr getDataType() { return new typename DataTypeFromFieldType::Type; } }; template <> struct DataTypeFromFieldTypeOrError { static DataTypePtr getDataType() { return NULL; } }; class FunctionIf : public IFunction { private: template bool checkRightType(const DataTypes & arguments, DataTypePtr & type_res) const { if (dynamic_cast(&*arguments[2])) { typedef typename NumberTraits::ResultOfIf::Type ResultType; type_res = DataTypeFromFieldTypeOrError::getDataType(); if (!type_res) throw Exception("Arguments 2 and 3 of function " + getName() + " are not upscalable to a common type without loss of precision: " + arguments[1]->getName() + " and " + arguments[2]->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); return true; } return false; } template bool checkLeftType(const DataTypes & arguments, DataTypePtr & type_res) const { if (dynamic_cast(&*arguments[1])) { if ( checkRightType(arguments, type_res) || checkRightType(arguments, type_res) || checkRightType(arguments, type_res) || checkRightType(arguments, type_res) || checkRightType(arguments, type_res) || checkRightType(arguments, type_res) || checkRightType(arguments, type_res) || checkRightType(arguments, type_res) || checkRightType(arguments, type_res) || checkRightType(arguments, type_res)) return true; else throw Exception("Illegal type " + arguments[2]->getName() + " of third argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } return false; } template bool executeRightType( const ColumnVector * cond_col, Block & block, const ColumnNumbers & arguments, size_t result, const ColumnVector * col_left) { ColumnVector * col_right_vec = dynamic_cast *>(&*block.getByPosition(arguments[2]).column); ColumnConst * col_right_const = dynamic_cast *>(&*block.getByPosition(arguments[2]).column); if (!col_right_vec && !col_right_const) return false; typedef typename NumberTraits::ResultOfIf::Type ResultType; if (col_right_vec) NumIfImpl::vector_vector(cond_col->getData(), col_left->getData(), col_right_vec->getData(), block, result); else NumIfImpl::vector_constant(cond_col->getData(), col_left->getData(), col_right_const->getData(), block, result); return true; } template bool executeConstRightType( const ColumnVector * cond_col, Block & block, const ColumnNumbers & arguments, size_t result, const ColumnConst * col_left) { ColumnVector * col_right_vec = dynamic_cast *>(&*block.getByPosition(arguments[2]).column); ColumnConst * col_right_const = dynamic_cast *>(&*block.getByPosition(arguments[2]).column); if (!col_right_vec && !col_right_const) return false; typedef typename NumberTraits::ResultOfIf::Type ResultType; if (col_right_vec) NumIfImpl::constant_vector(cond_col->getData(), col_left->getData(), col_right_vec->getData(), block, result); else NumIfImpl::constant_constant(cond_col->getData(), col_left->getData(), col_right_const->getData(), block, result); return true; } template bool executeLeftType(const ColumnVector * cond_col, Block & block, const ColumnNumbers & arguments, size_t result) { if (ColumnVector * col_left = dynamic_cast *>(&*block.getByPosition(arguments[1]).column)) { if ( executeRightType(cond_col, block, arguments, result, col_left) || executeRightType(cond_col, block, arguments, result, col_left) || executeRightType(cond_col, block, arguments, result, col_left) || executeRightType(cond_col, block, arguments, result, col_left) || executeRightType(cond_col, block, arguments, result, col_left) || executeRightType(cond_col, block, arguments, result, col_left) || executeRightType(cond_col, block, arguments, result, col_left) || executeRightType(cond_col, block, arguments, result, col_left) || executeRightType(cond_col, block, arguments, result, col_left) || executeRightType(cond_col, block, arguments, result, col_left)) return true; else throw Exception("Illegal column " + block.getByPosition(arguments[2]).column->getName() + " of third argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); } else if (ColumnConst * col_left = dynamic_cast *>(&*block.getByPosition(arguments[1]).column)) { if ( executeConstRightType(cond_col, block, arguments, result, col_left) || executeConstRightType(cond_col, block, arguments, result, col_left) || executeConstRightType(cond_col, block, arguments, result, col_left) || executeConstRightType(cond_col, block, arguments, result, col_left) || executeConstRightType(cond_col, block, arguments, result, col_left) || executeConstRightType(cond_col, block, arguments, result, col_left) || executeConstRightType(cond_col, block, arguments, result, col_left) || executeConstRightType(cond_col, block, arguments, result, col_left) || executeConstRightType(cond_col, block, arguments, result, col_left) || executeConstRightType(cond_col, block, arguments, result, col_left)) return true; else throw Exception("Illegal column " + block.getByPosition(arguments[2]).column->getName() + " of third argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); } return false; } bool executeString(const ColumnVector * cond_col, Block & block, const ColumnNumbers & arguments, size_t result) { ColumnString * col_then = dynamic_cast(&*block.getByPosition(arguments[1]).column); ColumnString * col_else = dynamic_cast(&*block.getByPosition(arguments[2]).column); ColumnConstString * col_then_const = dynamic_cast(&*block.getByPosition(arguments[1]).column); ColumnConstString * col_else_const = dynamic_cast(&*block.getByPosition(arguments[2]).column); ColumnString * col_res = new ColumnString; block.getByPosition(result).column = col_res; ColumnString::Chars_t & res_vec = col_res->getChars(); ColumnString::Offsets_t & res_offsets = col_res->getOffsets(); if (col_then && col_else) StringIfImpl::vector_vector( cond_col->getData(), col_then->getChars(), col_then->getOffsets(), col_else->getChars(), col_else->getOffsets(), res_vec, res_offsets); else if (col_then && col_else_const) StringIfImpl::vector_constant( cond_col->getData(), col_then->getChars(), col_then->getOffsets(), col_else_const->getData(), res_vec, res_offsets); else if (col_then_const && col_else) StringIfImpl::constant_vector( cond_col->getData(), col_then_const->getData(), col_else->getChars(), col_else->getOffsets(), res_vec, res_offsets); else if (col_then_const && col_else_const) StringIfImpl::constant_constant( cond_col->getData(), col_then_const->getData(), col_else_const->getData(), res_vec, res_offsets); else return false; return true; } public: /// Получить имя функции. String getName() const { return "if"; } /// Получить типы результата по типам аргументов. Если функция неприменима для данных аргументов - кинуть исключение. DataTypePtr getReturnType(const DataTypes & arguments) const { if (arguments.size() != 3) throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size()) + ", should be 3.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); if (!dynamic_cast(&*arguments[0])) throw Exception("Illegal type of first argument (condition) of function if. Must be UInt8.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); if (arguments[1]->behavesAsNumber() && arguments[2]->behavesAsNumber()) { DataTypePtr type_res; if (!( checkLeftType(arguments, type_res) || checkLeftType(arguments, type_res) || checkLeftType(arguments, type_res) || checkLeftType(arguments, type_res) || checkLeftType(arguments, type_res) || checkLeftType(arguments, type_res) || checkLeftType(arguments, type_res) || checkLeftType(arguments, type_res) || checkLeftType(arguments, type_res) || checkLeftType(arguments, type_res))) throw Exception("Internal error: unexpected type " + arguments[1]->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); return type_res; } else if (arguments[1]->getName() != arguments[2]->getName()) { throw Exception("Incompatible second and third arguments for function " + getName() + ": " + arguments[1]->getName() + " and " + arguments[2]->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } return arguments[1]; } /// Выполнить функцию над блоком. void execute(Block & block, const ColumnNumbers & arguments, size_t result) { const ColumnVector * cond_col = dynamic_cast *>(&*block.getByPosition(arguments[0]).column); const ColumnConst * cond_const_col = dynamic_cast *>(&*block.getByPosition(arguments[0]).column); ColumnPtr materialized_cond_col; if (cond_const_col) { if (block.getByPosition(arguments[1]).type->getName() == block.getByPosition(arguments[2]).type->getName()) { block.getByPosition(result).column = cond_const_col->getData() ? block.getByPosition(arguments[1]).column : block.getByPosition(arguments[2]).column; return; } else { materialized_cond_col = cond_const_col->convertToFullColumn(); cond_col = dynamic_cast *>(&*materialized_cond_col); } } if (cond_col) { if (!( executeLeftType(cond_col, block, arguments, result) || executeLeftType(cond_col, block, arguments, result) || executeLeftType(cond_col, block, arguments, result) || executeLeftType(cond_col, block, arguments, result) || executeLeftType(cond_col, block, arguments, result) || executeLeftType(cond_col, block, arguments, result) || executeLeftType(cond_col, block, arguments, result) || executeLeftType(cond_col, block, arguments, result) || executeLeftType(cond_col, block, arguments, result) || executeLeftType(cond_col, block, arguments, result) || executeString(cond_col, block, arguments, result))) throw Exception("Illegal columns " + block.getByPosition(arguments[1]).column->getName() + " and " + block.getByPosition(arguments[2]).column->getName() + " of second (then) and third (else) arguments of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); } else throw Exception("Illegal column " + cond_col->getName() + " of first argument of function " + getName() + ". Must be ColumnUInt8 or ColumnConstUInt8.", ErrorCodes::ILLEGAL_COLUMN); } }; }