functions: function array now accept arguments of different types [METR-10799]

This commit is contained in:
Sergey Fedorov 2014-04-10 21:28:06 +04:00
parent 71d146d345
commit a67d4f7e11

View File

@ -10,6 +10,8 @@
#include <DB/Interpreters/HashMap.h>
#include <DB/Interpreters/ClearableHashMap.h>
#include <DB/Interpreters/AggregationCommon.h>
#include <DB/Functions/NumberTraits.h>
#include <DB/Functions/FunctionsConditional.h>
#include <unordered_map>
@ -30,26 +32,115 @@ namespace DB
* Например: arrayEnumerateUniq([10, 20, 10, 30]) = [1, 1, 2, 1]
*/
class FunctionArray : public IFunction
{
public:
private:
/// Получить имя функции.
String getName() const
{
return "array";
}
template <typename T0, typename T1>
bool checkRightType(DataTypePtr left, DataTypePtr right, DataTypePtr & type_res) const
{
if (dynamic_cast<const T1 *>(&*right))
{
typedef typename NumberTraits::ResultOfIf<typename T0::FieldType, typename T1::FieldType>::Type ResultType;
type_res = DataTypeFromFieldTypeOrError<ResultType>::getDataType();
if (!type_res)
throw Exception("Arguments of function " + getName() + " are not upscalable to a common type without loss of precision.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return true;
}
return false;
}
template <typename T0>
bool checkLeftType(DataTypePtr left, DataTypePtr right, DataTypePtr & type_res) const
{
if (dynamic_cast<const T0 *>(&*left))
{
if ( checkRightType<T0, DataTypeUInt8>(left, right, type_res)
|| checkRightType<T0, DataTypeUInt16>(left, right, type_res)
|| checkRightType<T0, DataTypeUInt32>(left, right, type_res)
|| checkRightType<T0, DataTypeUInt64>(left, right, type_res)
|| checkRightType<T0, DataTypeInt8>(left, right, type_res)
|| checkRightType<T0, DataTypeInt16>(left, right, type_res)
|| checkRightType<T0, DataTypeInt32>(left, right, type_res)
|| checkRightType<T0, DataTypeInt64>(left, right, type_res)
|| checkRightType<T0, DataTypeFloat32>(left, right, type_res)
|| checkRightType<T0, DataTypeFloat64>(left, right, type_res))
return true;
else
throw Exception("Illegal type " + right->getName() + " as argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
return false;
}
template <typename T0, typename T1>
bool tryAddField(DataTypePtr type_res, const Field & f, Array & arr) const
{
if (dynamic_cast<const T0 *>(&*type_res))
{
arr.push_back(apply_visitor(FieldVisitorConvertToNumber<typename T1::FieldType>(), f));
return true;
}
return false;
}
bool addField(DataTypePtr type_res, const Field & f, Array & arr) const
{
if ( tryAddField<DataTypeUInt8, DataTypeUInt64>(type_res, f, arr)
|| tryAddField<DataTypeUInt16, DataTypeUInt64>(type_res, f, arr)
|| tryAddField<DataTypeUInt32, DataTypeUInt64>(type_res, f, arr)
|| tryAddField<DataTypeUInt64, DataTypeUInt64>(type_res, f, arr)
|| tryAddField<DataTypeInt8, DataTypeInt64>(type_res, f, arr)
|| tryAddField<DataTypeInt16, DataTypeInt64>(type_res, f, arr)
|| tryAddField<DataTypeInt32, DataTypeInt64>(type_res, f, arr)
|| tryAddField<DataTypeInt64, DataTypeInt64>(type_res, f, arr)
|| tryAddField<DataTypeFloat32, DataTypeFloat64>(type_res, f, arr)
|| tryAddField<DataTypeFloat64, DataTypeFloat64>(type_res, f, arr) )
return true;
else
throw Exception("Illegal result type " + type_res->getName() + " of function " + getName(),
ErrorCodes::LOGICAL_ERROR);
}
DataTypePtr getLeastCommonType(DataTypePtr left, DataTypePtr right) const
{
DataTypePtr type_res;
if (!( checkLeftType<DataTypeUInt8>(left, right, type_res)
|| checkLeftType<DataTypeUInt16>(left, right, type_res)
|| checkLeftType<DataTypeUInt32>(left, right, type_res)
|| checkLeftType<DataTypeUInt64>(left, right, type_res)
|| checkLeftType<DataTypeInt8>(left, right, type_res)
|| checkLeftType<DataTypeInt16>(left, right, type_res)
|| checkLeftType<DataTypeInt32>(left, right, type_res)
|| checkLeftType<DataTypeInt64>(left, right, type_res)
|| checkLeftType<DataTypeFloat32>(left, right, type_res)
|| checkLeftType<DataTypeFloat64>(left, right, type_res)))
throw Exception("Internal error: unexpected type " + left->getName() + " as argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return type_res;
}
public:
/// Получить тип результата по типам аргументов. Если функция неприменима для данных аргументов - кинуть исключение.
DataTypePtr getReturnType(const DataTypes & arguments) const
{
if (arguments.empty())
throw Exception("Function array requires at least one argument.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
DataTypePtr result_type = arguments[0];
for (size_t i = 1, size = arguments.size(); i < size; ++i)
if (arguments[i]->getName() != arguments[0]->getName())
throw Exception("Arguments for function array must have same type.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
result_type = getLeastCommonType(result_type, arguments[i]);
return new DataTypeArray(arguments[0]);
return new DataTypeArray(result_type);
}
/// Выполнить функцию над блоком.
@ -60,11 +151,15 @@ public:
if (!block.getByPosition(arguments[i]).column->isConst())
throw Exception("Arguments for function array must be constant.", ErrorCodes::ILLEGAL_COLUMN);
DataTypePtr result_type = block.getByPosition(arguments[0]).type;
for (size_t i = 1, size = arguments.size(); i < size; ++i)
result_type = getLeastCommonType(result_type, block.getByPosition(arguments[i]).type);
Array arr;
for (size_t i = 0, size = arguments.size(); i < size; ++i)
arr.push_back((*block.getByPosition(arguments[i]).column)[0]);
addField(result_type, (*block.getByPosition(arguments[i]).column)[0], arr);
block.getByPosition(result).column = new ColumnConstArray(block.getByPosition(arguments[0]).column->size(), arr, new DataTypeArray(block.getByPosition(arguments[0]).type));
block.getByPosition(result).column = new ColumnConstArray(block.getByPosition(arguments[0]).column->size(), arr, new DataTypeArray(result_type));
}
};