Added support for tuple arguments of function if [#METR-22033].

This commit is contained in:
Alexey Milovidov 2016-07-09 06:54:57 +03:00
parent fb3a28f646
commit 27360b8166
4 changed files with 145 additions and 8 deletions

View File

@ -4,11 +4,13 @@
#include <DB/DataTypes/DataTypeArray.h>
#include <DB/DataTypes/DataTypeString.h>
#include <DB/DataTypes/DataTypeFixedString.h>
#include <DB/DataTypes/DataTypeTuple.h>
#include <DB/Columns/ColumnVector.h>
#include <DB/Columns/ColumnString.h>
#include <DB/Columns/ColumnConst.h>
#include <DB/Columns/ColumnArray.h>
#include <DB/Columns/ColumnFixedString.h>
#include <DB/Columns/ColumnTuple.h>
#include <DB/Functions/IFunction.h>
#include <DB/Functions/NumberTraits.h>
#include <DB/Functions/DataTypeTraits.h>
@ -817,6 +819,7 @@ struct StringArrayIfImpl
}
};
class FunctionIf : public IFunction
{
public:
@ -1259,6 +1262,49 @@ private:
return false;
}
bool executeTuple(const ColumnUInt8 * cond_col, Block & block, const ColumnNumbers & arguments, size_t result)
{
/// Calculate function for each corresponding elements of tuples.
const ColumnWithTypeAndName & arg1 = block.getByPosition(arguments[1]);
const ColumnWithTypeAndName & arg2 = block.getByPosition(arguments[2]);
const ColumnTuple * col1 = static_cast<const ColumnTuple *>(arg1.column.get());
const ColumnTuple * col2 = static_cast<const ColumnTuple *>(arg2.column.get());
if (!col1 || !col2)
return false;
const DataTypeTuple & type1 = static_cast<const DataTypeTuple &>(*arg1.type);
const DataTypeTuple & type2 = static_cast<const DataTypeTuple &>(*arg2.type);
Block temporary_block;
temporary_block.insert(block.getByPosition(arguments[0]));
size_t tuple_size = type1.getElements().size();
for (size_t i = 0; i < tuple_size; ++i)
{
temporary_block.insert({nullptr,
getReturnType({std::make_shared<DataTypeUInt8>(), type1.getElements()[i], type2.getElements()[i]}),
{}});
temporary_block.insert({col1->getData().getByPosition(i).column, type1.getElements()[i], {}});
temporary_block.insert({col2->getData().getByPosition(i).column, type2.getElements()[i], {}});
/// temporary_block will be: cond, res_0, ..., res_i, then_i, else_i
execute(temporary_block, {0, i + 2, i + 3}, i + 1);
temporary_block.erase(i + 3);
temporary_block.erase(i + 2);
}
/// temporary_block is: cond, res_0, res_1, res_2...
temporary_block.erase(0);
block.getByPosition(result).column = std::make_shared<ColumnTuple>(temporary_block);
return true;
}
public:
/// Получить имя функции.
String getName() const override
@ -1281,6 +1327,9 @@ public:
const DataTypeArray * type_arr1 = typeid_cast<const DataTypeArray *>(arguments[1].get());
const DataTypeArray * type_arr2 = typeid_cast<const DataTypeArray *>(arguments[2].get());
const DataTypeTuple * type_tuple1 = typeid_cast<const DataTypeTuple *>(arguments[1].get());
const DataTypeTuple * type_tuple2 = typeid_cast<const DataTypeTuple *>(arguments[2].get());
if (arguments[1]->behavesAsNumber() && arguments[2]->behavesAsNumber())
{
DataTypePtr type_res;
@ -1303,6 +1352,21 @@ public:
/// NOTE Сообщения об ошибках будут относится к типам элементов массивов, что немного некорректно.
return std::make_shared<DataTypeArray>(getReturnType({arguments[0], type_arr1->getNestedType(), type_arr2->getNestedType()}));
}
else if (type_tuple1 && type_tuple2)
{
const size_t tuple_size = type_tuple1->getElements().size();
if (tuple_size != type_tuple2->getElements().size())
throw Exception("Different sizes of tuples in 'then' and 'else' argument of function if",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
DataTypes result_tuple(tuple_size);
for (size_t i = 0; i < tuple_size; ++i)
result_tuple[i] = getReturnType({arguments[0], type_tuple1->getElements()[i], type_tuple2->getElements()[i]});
return std::make_shared<DataTypeTuple>(std::move(result_tuple));
}
else if (arguments[1]->getName() != arguments[2]->getName())
{
const DataTypeString * type_string1 = typeid_cast<const DataTypeString *>(arguments[1].get());
@ -1340,14 +1404,16 @@ public:
const ColumnConst<UInt8> * cond_const_col = typeid_cast<const ColumnConst<UInt8> *>(&*block.getByPosition(arguments[0]).column);
ColumnPtr materialized_cond_col;
const ColumnWithTypeAndName & arg_then = block.getByPosition(arguments[1]);
const ColumnWithTypeAndName & arg_else = block.getByPosition(arguments[2]);
if (cond_const_col)
{
if (block.getByPosition(arguments[1]).type->getName() ==
block.getByPosition(arguments[2]).type->getName())
if (arg_then.type->getName() == arg_else.type->getName())
{
block.getByPosition(result).column = cond_const_col->getData()
? block.getByPosition(arguments[1]).column
: block.getByPosition(arguments[2]).column;
? arg_then.column
: arg_else.column;
return;
}
else
@ -1369,9 +1435,10 @@ public:
|| executeLeftType<Int64>(cond_col, block, arguments, result)
|| executeLeftType<Float32>(cond_col, block, arguments, result)
|| executeLeftType<Float64>(cond_col, block, arguments, result)
|| executeString(cond_col, block, arguments, result)))
throw Exception("Illegal columns " + block.getByPosition(arguments[1]).column->getName()
+ " and " + block.getByPosition(arguments[2]).column->getName()
|| executeString(cond_col, block, arguments, result)
|| executeTuple(cond_col, block, arguments, result)))
throw Exception("Illegal columns " + arg_then.column->getName()
+ " and " + arg_else.column->getName()
+ " of second (then) and third (else) arguments of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
}

View File

@ -156,7 +156,9 @@ void Block::erase(size_t position)
+ toString(index_by_position.size() - 1), ErrorCodes::POSITION_OUT_OF_BOUND);
Container_t::iterator it = index_by_position[position];
index_by_name.erase(index_by_name.find(it->name));
auto index_by_name_it = index_by_name.find(it->name);
if (index_by_name.end() != index_by_name_it)
index_by_name.erase(index_by_name_it);
data.erase(it);
for (size_t i = position, size = index_by_position.size() - 1; i < size; ++i)

View File

@ -0,0 +1,60 @@
(0,'! 0')
(10,'! 1')
(2,'2')
(30,'! 3')
(40,'! 4')
(5,'5')
(60,'! 6')
(70,'! 7')
(8,'8')
(90,'! 9')
(0,'! 0')
(10,'! 1')
(20,'! 2')
(30,'! 3')
(40,'! 4')
(50,'! 5')
(60,'! 6')
(70,'! 7')
(80,'! 8')
(90,'! 9')
(0,'0')
(1,'1')
(2,'2')
(3,'3')
(4,'4')
(5,'5')
(6,'6')
(7,'7')
(8,'8')
(9,'9')
(2,'World')
(2,'World')
(1,'Hello')
(2,'World')
(2,'World')
(1,'Hello')
(2,'World')
(2,'World')
(1,'Hello')
(2,'World')
(0,'World')
(0,'World')
(2,'Hello')
(0,'World')
(0,'World')
(5,'Hello')
(0,'World')
(0,'World')
(8,'Hello')
(0,'World')
(0,'1')
(0,'2')
(2,'Hello')
(0,'8')
(0,'16')
(5,'Hello')
(0,'64')
(0,'128')
(8,'Hello')
(0,'512')

View File

@ -0,0 +1,8 @@
SELECT number % 3 = 2 ? (number, toString(number)) : (number * 10, concat('! ', toString(number))) FROM system.numbers LIMIT 10;
SELECT 0 ? (number, toString(number)) : (number * 10, concat('! ', toString(number))) FROM system.numbers LIMIT 10;
SELECT 1 ? (number, toString(number)) : (number * 10, concat('! ', toString(number))) FROM system.numbers LIMIT 10;
SELECT number % 3 = 2 ? (1, 'Hello') : (2, 'World') FROM system.numbers LIMIT 10;
SELECT number % 3 = 2 ? (number, 'Hello') : (0, 'World') FROM system.numbers LIMIT 10;
SELECT number % 3 = 2 ? (number, 'Hello') : (0, toString(exp2(number))) FROM system.numbers LIMIT 10;