dbms: add support for comparing enums [#METR-19265]

This commit is contained in:
Andrey Mironov 2015-12-22 15:03:21 +03:00
parent 3fb8fedd19
commit 02ce1bede2
5 changed files with 140 additions and 131 deletions

View File

@ -5,6 +5,8 @@
#include <DB/DataTypes/DataTypeString.h>
#include <DB/DataTypes/DataTypeFixedString.h>
#include <DB/DataTypes/DataTypesNumberFixed.h>
#include <DB/DataTypes/DataTypeEnum.h>
namespace DB
{
@ -24,6 +26,8 @@ static IAggregateFunction * createWithNumericType(const IDataType & argument_typ
else if (typeid_cast<const DataTypeInt64 *>(&argument_type)) return new AggregateFunctionTemplate<Int64>;
else if (typeid_cast<const DataTypeFloat32 *>(&argument_type)) return new AggregateFunctionTemplate<Float32>;
else if (typeid_cast<const DataTypeFloat64 *>(&argument_type)) return new AggregateFunctionTemplate<Float64>;
else if (typeid_cast<const DataTypeEnum8 *>(&argument_type)) return new AggregateFunctionTemplate<UInt8>;
else if (typeid_cast<const DataTypeEnum16 *>(&argument_type)) return new AggregateFunctionTemplate<UInt16>;
else
return nullptr;
}
@ -41,6 +45,8 @@ static IAggregateFunction * createWithNumericType(const IDataType & argument_typ
else if (typeid_cast<const DataTypeInt64 *>(&argument_type)) return new AggregateFunctionTemplate<Int64, Data>;
else if (typeid_cast<const DataTypeFloat32 *>(&argument_type)) return new AggregateFunctionTemplate<Float32, Data>;
else if (typeid_cast<const DataTypeFloat64 *>(&argument_type)) return new AggregateFunctionTemplate<Float64, Data>;
else if (typeid_cast<const DataTypeEnum8 *>(&argument_type)) return new AggregateFunctionTemplate<UInt8, Data>;
else if (typeid_cast<const DataTypeEnum16 *>(&argument_type)) return new AggregateFunctionTemplate<UInt16, Data>;
else
return nullptr;
}
@ -59,6 +65,8 @@ static IAggregateFunction * createWithNumericType(const IDataType & argument_typ
else if (typeid_cast<const DataTypeInt64 *>(&argument_type)) return new AggregateFunctionTemplate<Int64, Data<Int64> >;
else if (typeid_cast<const DataTypeFloat32 *>(&argument_type)) return new AggregateFunctionTemplate<Float32, Data<Float32> >;
else if (typeid_cast<const DataTypeFloat64 *>(&argument_type)) return new AggregateFunctionTemplate<Float64, Data<Float64> >;
else if (typeid_cast<const DataTypeEnum8 *>(&argument_type)) return new AggregateFunctionTemplate<UInt8, Data<UInt8> >;
else if (typeid_cast<const DataTypeEnum16 *>(&argument_type)) return new AggregateFunctionTemplate<UInt16, Data<UInt16> >;
else
return nullptr;
}
@ -79,6 +87,8 @@ static IAggregateFunction * createWithTwoNumericTypesSecond(const IDataType & se
else if (typeid_cast<const DataTypeInt64 *>(&second_type)) return new AggregateFunctionTemplate<FirstType, Int64>;
else if (typeid_cast<const DataTypeFloat32 *>(&second_type)) return new AggregateFunctionTemplate<FirstType, Float32>;
else if (typeid_cast<const DataTypeFloat64 *>(&second_type)) return new AggregateFunctionTemplate<FirstType, Float64>;
else if (typeid_cast<const DataTypeEnum8 *>(&second_type)) return new AggregateFunctionTemplate<FirstType, UInt8>;
else if (typeid_cast<const DataTypeEnum16 *>(&second_type)) return new AggregateFunctionTemplate<FirstType, UInt16>;
else
return nullptr;
}
@ -96,6 +106,8 @@ static IAggregateFunction * createWithTwoNumericTypes(const IDataType & first_ty
else if (typeid_cast<const DataTypeInt64 *>(&first_type)) return createWithTwoNumericTypesSecond<Int64, AggregateFunctionTemplate>(second_type);
else if (typeid_cast<const DataTypeFloat32 *>(&first_type)) return createWithTwoNumericTypesSecond<Float32, AggregateFunctionTemplate>(second_type);
else if (typeid_cast<const DataTypeFloat64 *>(&first_type)) return createWithTwoNumericTypesSecond<Float64, AggregateFunctionTemplate>(second_type);
else if (typeid_cast<const DataTypeEnum8 *>(&first_type)) return createWithTwoNumericTypesSecond<UInt8, AggregateFunctionTemplate>(second_type);
else if (typeid_cast<const DataTypeEnum16 *>(&first_type)) return createWithTwoNumericTypesSecond<UInt16, AggregateFunctionTemplate>(second_type);
else
return nullptr;
}

View File

@ -83,13 +83,19 @@ public:
fillMap();
}
std::string getName() const override
{
return name;
}
std::string getName() const override { return name; }
bool isNumeric() const override { return true; }
bool behavesAsNumber() const override { return true; }
/// Returns length of textual name for an enum element (used in FunctionVisibleWidth)
std::size_t getNameLength(const FieldType & value) const
{
return getNameForValue(value).size();
}
std::string getNameForValue(const FieldType & value) const
{
const auto it = std::lower_bound(std::begin(values), std::end(values), value, [] (const auto & left, const auto & right) {
return left.second < right;
@ -101,7 +107,19 @@ public:
ErrorCodes::LOGICAL_ERROR
};
return it->first.size();
return it->first;
}
FieldType getValue(const std::string & name) const
{
const auto it = map.find(StringRef{name});
if (it == std::end(map))
throw Exception{
"Unknown string '" + name + "' for " + getName(),
ErrorCodes::LOGICAL_ERROR
};
return it->second;
}
DataTypePtr clone() const override
@ -124,115 +142,43 @@ public:
void serializeText(const Field & field, WriteBuffer & ostr) const override
{
const FieldType x = get<typename NearestFieldType<FieldType>::Type>(field);
const auto it = std::lower_bound(std::begin(values), std::end(values), x, [] (const auto & left, const auto & right) {
return left.second < right;
});
if (it == std::end(values) || it->second != x)
throw Exception{
"Unexpected value " + toString(x) + " for " + getName(),
ErrorCodes::LOGICAL_ERROR
};
writeString(it->first, ostr);
writeString(getNameForValue(x), ostr);
}
void deserializeText(Field & field, ReadBuffer & istr) const override
{
std::string name;
readString(name, istr);
const auto it = map.find(StringRef{name});
if (it == std::end(map))
throw Exception{
"Unknown string '" + name + "' for " + getName(),
ErrorCodes::LOGICAL_ERROR
};
field = nearestFieldType(it->second);
field = nearestFieldType(getValue(name));
}
void serializeTextEscaped(const Field & field, WriteBuffer & ostr) const override
{
const FieldType x = get<typename NearestFieldType<FieldType>::Type>(field);
const auto it = std::lower_bound(std::begin(values), std::end(values), x, [] (const auto & left, const auto & right) {
return left.second < right;
});
if (it == std::end(values) || it->second != x)
throw Exception{
"Unexpected value " + toString(x) + " for " + getName(),
ErrorCodes::LOGICAL_ERROR
};
writeEscapedString(it->first, ostr);
writeEscapedString(getNameForValue(x), ostr);
}
void deserializeTextEscaped(Field & field, ReadBuffer & istr) const override
{
field.assignString("", 0);
std::string name;
readEscapedString(name, istr);
const auto it = map.find(StringRef{name});
if (it == std::end(map))
throw Exception{
"Unknown string '" + name + "' for " + getName(),
ErrorCodes::LOGICAL_ERROR
};
field = nearestFieldType(it->second);
field = nearestFieldType(getValue(name));
}
void serializeTextQuoted(const Field & field, WriteBuffer & ostr) const override
{
const FieldType x = get<typename NearestFieldType<FieldType>::Type>(field);
const auto it = std::lower_bound(std::begin(values), std::end(values), x, [] (const auto & left, const auto & right) {
return left.second < right;
});
if (it == std::end(values) || it->second != x)
throw Exception{
"Unexpected value " + toString(x) + " for " + getName(),
ErrorCodes::LOGICAL_ERROR
};
writeQuotedString(it->first, ostr);
writeQuotedString(getNameForValue(x), ostr);
}
void deserializeTextQuoted(Field & field, ReadBuffer & istr) const override
{
field.assignString("", 0);
std::string name;
readQuotedString(name, istr);
const auto it = map.find(StringRef{name});
if (it == std::end(map))
throw Exception{
"Unknown string '" + name + "' for " + getName(),
ErrorCodes::LOGICAL_ERROR
};
field = nearestFieldType(it->second);
field = nearestFieldType(getValue(name));
}
void serializeTextJSON(const Field & field, WriteBuffer & ostr) const override
{
const FieldType x = get<typename NearestFieldType<FieldType>::Type>(field);
const auto it = std::lower_bound(std::begin(values), std::end(values), x, [] (const auto & left, const auto & right) {
return left.second < right;
});
if (it == std::end(values) || it->second != x)
throw Exception{
"Unexpected value " + toString(x) + " for " + getName(),
ErrorCodes::LOGICAL_ERROR
};
writeJSONString(it->first, ostr);
writeJSONString(getNameForValue(x), ostr);
}
/** Потоковая сериализация массивов устроена по-особенному:

View File

@ -14,6 +14,7 @@
#include <DB/Functions/FunctionsLogical.h>
#include <DB/Functions/IFunction.h>
#include <DB/DataTypes/DataTypeEnum.h>
namespace DB
@ -576,31 +577,32 @@ private:
}
}
void executeDateOrDateTimeWithConstString(Block & block, size_t result,
const IColumn * col_left_untyped, const IColumn * col_right_untyped,
bool left_is_num, bool right_is_num)
void executeDateOrDateTimeOrEnumWithConstString(
Block & block, size_t result, const IColumn * col_left_untyped, const IColumn * col_right_untyped,
const DataTypePtr & left_type, const DataTypePtr & right_type, bool left_is_num, bool right_is_num)
{
/// Особый случай - сравнение дат и дат-с-временем со строковой константой.
const IColumn * column_date_or_datetime = left_is_num ? col_left_untyped : col_right_untyped;
/// Уже не такой и особый случай - сравнение дат, дат-с-временем и перечислений со строковой константой.
const IColumn * column_string_untyped = !left_is_num ? col_left_untyped : col_right_untyped;
const IColumn * column_number = left_is_num ? col_left_untyped : col_right_untyped;
const IDataType * number_type = left_is_num ? left_type.get() : right_type.get();
bool is_date = false;
bool is_date_time = false;
bool is_enum8 = false;
bool is_enum16 = false;
is_date = typeid_cast<const ColumnVector<DataTypeDate::FieldType> *>(column_date_or_datetime)
|| typeid_cast<const ColumnConst<DataTypeDate::FieldType> *>(column_date_or_datetime);
const auto legal_types = (is_date = typeid_cast<const DataTypeDate *>(number_type))
|| (is_date_time = typeid_cast<const DataTypeDateTime *>(number_type))
|| (is_enum8 = typeid_cast<const DataTypeEnum8 *>(number_type))
|| (is_enum16 = typeid_cast<const DataTypeEnum16 *>(number_type));
if (!is_date)
is_date_time = typeid_cast<const ColumnVector<DataTypeDateTime::FieldType> *>(column_date_or_datetime)
|| typeid_cast<const ColumnConst<DataTypeDateTime::FieldType> *>(column_date_or_datetime);
const ColumnConstString * column_string = typeid_cast<const ColumnConstString *>(column_string_untyped);
if (!column_string
|| (!is_date && !is_date_time))
throw Exception("Illegal columns " + col_left_untyped->getName() + " and " + col_right_untyped->getName()
+ " of arguments of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
const auto column_string = typeid_cast<const ColumnConstString *>(column_string_untyped);
if (!column_string || !legal_types)
throw Exception{
"Illegal columns " + col_left_untyped->getName() + " and " + col_right_untyped->getName()
+ " of arguments of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN
};
if (is_date)
{
@ -628,6 +630,28 @@ private:
left_is_num ? col_left_untyped : &parsed_const_date_time,
left_is_num ? &parsed_const_date_time : col_right_untyped);
}
else if (is_enum8)
executeEnumWithConstString<DataTypeUInt8::FieldType>(block, result, column_number, column_string,
number_type, left_is_num);
else if (is_enum16)
executeEnumWithConstString<DataTypeEnum16::FieldType>(block, result, column_number, column_string,
number_type, left_is_num);
}
/// Comparison between DataTypeEnum<T> and string constant containing the name of an enum element
template <typename FieldType>
void executeEnumWithConstString(
Block & block, const size_t result, const IColumn * column_number, const ColumnConstString * column_string,
const IDataType * type_untyped, const bool left_is_num)
{
const auto type = static_cast<const DataTypeEnum<FieldType> *>(type_untyped);
const Field x = nearestFieldType(type->getValue(column_string->getData()));
const auto enum_col = type->createConstColumn(block.rowsInFirstColumn(), x);
executeNumLeftType<FieldType>(block, result,
left_is_num ? column_number : enum_col.get(),
left_is_num ? enum_col.get() : column_number);
}
void executeTuple(Block & block, size_t result, const IColumn * c0, const IColumn * c1)
@ -749,6 +773,8 @@ public:
bool left_is_date = false;
bool left_is_date_time = false;
bool left_is_enum8 = false;
bool left_is_enum16 = false;
bool left_is_string = false;
bool left_is_fixed_string = false;
const DataTypeTuple * left_tuple = nullptr;
@ -756,12 +782,18 @@ public:
false
|| (left_is_date = typeid_cast<const DataTypeDate *>(arguments[0].get()))
|| (left_is_date_time = typeid_cast<const DataTypeDateTime *>(arguments[0].get()))
|| (left_is_enum8 = typeid_cast<const DataTypeEnum8 *>(arguments[0].get()))
|| (left_is_enum16 = typeid_cast<const DataTypeEnum16 *>(arguments[0].get()))
|| (left_is_string = typeid_cast<const DataTypeString *>(arguments[0].get()))
|| (left_is_fixed_string = typeid_cast<const DataTypeFixedString *>(arguments[0].get()))
|| (left_tuple = typeid_cast<const DataTypeTuple *>(arguments[0].get()));
const bool left_is_enum = left_is_enum8 || left_is_enum16;
bool right_is_date = false;
bool right_is_date_time = false;
bool right_is_enum8 = false;
bool right_is_enum16 = false;
bool right_is_string = false;
bool right_is_fixed_string = false;
const DataTypeTuple * right_tuple = nullptr;
@ -769,18 +801,28 @@ public:
false
|| (right_is_date = typeid_cast<const DataTypeDate *>(arguments[1].get()))
|| (right_is_date_time = typeid_cast<const DataTypeDateTime *>(arguments[1].get()))
|| (right_is_enum8 = typeid_cast<const DataTypeEnum8 *>(arguments[1].get()))
|| (right_is_enum16 = typeid_cast<const DataTypeEnum16 *>(arguments[1].get()))
|| (right_is_string = typeid_cast<const DataTypeString *>(arguments[1].get()))
|| (right_is_fixed_string = typeid_cast<const DataTypeFixedString *>(arguments[1].get()))
|| (right_tuple = typeid_cast<const DataTypeTuple *>(arguments[1].get()));
const bool right_is_enum = right_is_enum8 || right_is_enum16;
if (!( (arguments[0]->behavesAsNumber() && arguments[1]->behavesAsNumber())
|| ((left_is_string || left_is_fixed_string) && (right_is_string || right_is_fixed_string))
|| (left_is_date && right_is_date)
|| (left_is_date && right_is_string) /// Можно сравнивать дату и дату-с-временем с константной строкой.
|| (left_is_date && right_is_string) /// Можно сравнивать дату, дату-с-временем и перечисление с константной строкой.
|| (left_is_string && right_is_date)
|| (left_is_date_time && right_is_date_time)
|| (left_is_date_time && right_is_string)
|| (left_is_string && right_is_date_time)
|| (left_is_date_time && right_is_date_time)
|| (left_is_date_time && right_is_string)
|| (left_is_string && right_is_date_time)
|| (left_is_enum && right_is_enum && arguments[0]->getName() == arguments[1]->getName()) /// only equivalent enum type values can be compared against
|| (left_is_enum && right_is_string)
|| (left_is_string && right_is_enum)
|| (left_tuple && right_tuple && left_tuple->getElements().size() == right_tuple->getElements().size())))
throw Exception("Illegal types of arguments (" + arguments[0]->getName() + ", " + arguments[1]->getName() + ")"
" of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -798,11 +840,13 @@ public:
/// Выполнить функцию над блоком.
void execute(Block & block, const ColumnNumbers & arguments, size_t result) override
{
const IColumn * col_left_untyped = block.getByPosition(arguments[0]).column.get();
const IColumn * col_right_untyped = block.getByPosition(arguments[1]).column.get();
const auto & col_with_name_and_type_left = block.getByPosition(arguments[0]);
const auto & col_with_name_and_type_right = block.getByPosition(arguments[1]);
const IColumn * col_left_untyped = col_with_name_and_type_left.column.get();
const IColumn * col_right_untyped = col_with_name_and_type_right.column.get();
bool left_is_num = col_left_untyped->isNumeric();
bool right_is_num = col_right_untyped->isNumeric();
const bool left_is_num = col_left_untyped->isNumeric();
const bool right_is_num = col_right_untyped->isNumeric();
if (left_is_num && right_is_num)
{
@ -820,22 +864,16 @@ public:
+ " of first argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
}
else if (typeid_cast<const ColumnTuple *>(col_left_untyped))
executeTuple(block, result, col_left_untyped, col_right_untyped);
else if (!left_is_num && !right_is_num)
executeString(block, result, col_left_untyped, col_right_untyped);
else
{
if (typeid_cast<const ColumnTuple *>(col_left_untyped))
{
executeTuple(block, result, col_left_untyped, col_right_untyped);
}
else if (!left_is_num && !right_is_num)
{
executeString(block, result, col_left_untyped, col_right_untyped);
}
else
{
executeDateOrDateTimeWithConstString(block, result, col_left_untyped, col_right_untyped, left_is_num, right_is_num);
}
}
}
executeDateOrDateTimeOrEnumWithConstString(
block, result, col_left_untyped, col_right_untyped,
col_with_name_and_type_left.type, col_with_name_and_type_right.type,
left_is_num, right_is_num);
}
};

View File

@ -132,7 +132,8 @@ DataTypePtr DataTypeFactory::get(const String & name) const
}
for (size_t i = 1; i < args_list.children.size(); ++i)
argument_types.push_back(get(args_list.children[i]->getColumnName()));
argument_types.push_back(get(
std::string{args_list.children[i]->range.first, args_list.children[i]->range.second}));
function = AggregateFunctionFactory().get(function_name, argument_types);
if (!params_row.empty())
@ -181,12 +182,11 @@ DataTypePtr DataTypeFactory::get(const String & name) const
return new DataTypeTuple(elems);
}
/// @todo ParserUnsignedInteger fails if number is at the end of line, append space
if (base_name == "Enum8")
return parseEnum<DataTypeEnum8>(name, base_name, parameters + ' ');
return parseEnum<DataTypeEnum8>(name, base_name, parameters);
if (base_name == "Enum16")
return parseEnum<DataTypeEnum16>(name, base_name, parameters + ' ');
return parseEnum<DataTypeEnum16>(name, base_name, parameters);
throw Exception("Unknown type " + base_name, ErrorCodes::UNKNOWN_TYPE);
}

View File

@ -22,6 +22,7 @@
#include <DB/DataTypes/DataTypeFixedString.h>
#include <DB/DataTypes/DataTypeDate.h>
#include <DB/DataTypes/DataTypeDateTime.h>
#include <DB/DataTypes/DataTypeEnum.h>
namespace DB
@ -285,11 +286,19 @@ static Field convertToType(const Field & src, const IDataType & type)
if (typeid_cast<const DataTypeFloat32 *>(&type)) return convertNumericType<Float32>(src, type);
if (typeid_cast<const DataTypeFloat64 *>(&type)) return convertNumericType<Float64>(src, type);
bool is_date = typeid_cast<const DataTypeDate *>(&type);
bool is_datetime = typeid_cast<const DataTypeDateTime *>(&type);
const bool is_date = typeid_cast<const DataTypeDate *>(&type);
bool is_datetime = false;
bool is_enum8 = false;
bool is_enum16 = false;
if (!is_date && !is_datetime)
throw Exception("Logical error: unknown numeric type " + type.getName(), ErrorCodes::LOGICAL_ERROR);
if (!is_date)
if (!(is_datetime = typeid_cast<const DataTypeDateTime *>(&type)))
if (!(is_enum8 = typeid_cast<const DataTypeEnum8 *>(&type)))
if (!(is_enum16 = typeid_cast<const DataTypeEnum16 *>(&type)))
throw Exception{
"Logical error: unknown numeric type " + type.getName(),
ErrorCodes::LOGICAL_ERROR
};
if (src.getType() == Field::Types::UInt64)
return src;
@ -309,7 +318,7 @@ static Field convertToType(const Field & src, const IDataType & type)
return Field(UInt64(date));
}
else
else if (is_datetime)
{
time_t date_time{};
readDateTimeText(date_time, in);
@ -318,6 +327,10 @@ static Field convertToType(const Field & src, const IDataType & type)
return Field(UInt64(date_time));
}
else if (is_enum8)
return Field(UInt64(static_cast<const DataTypeEnum8 &>(type).getValue(str)));
else if (is_enum16)
return Field(UInt64(static_cast<const DataTypeEnum16 &>(type).getValue(str)));
}
throw Exception("Type mismatch in IN section: " + type.getName() + " at left, "