Updated implementation

This commit is contained in:
Maksim Kita 2023-09-19 13:06:19 +03:00
parent 29e4352c17
commit c7ddbab9bc
4 changed files with 275 additions and 10 deletions

231
src/DataTypes/Utils.cpp Normal file
View File

@ -0,0 +1,231 @@
#include <DataTypes/Utils.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeTuple.h>
namespace DB
{
bool canBeSafelyCasted(const DataTypePtr & from_type, const DataTypePtr & to_type)
{
auto from_which_type = WhichDataType(from_type->getTypeId());
bool to_type_was_nullable = isNullableOrLowCardinalityNullable(to_type);
auto to_type_unwrapped = removeNullable(removeLowCardinality(to_type));
if (from_type->equals(*to_type))
return true;
auto to_which_type = WhichDataType(to_type->getTypeId());
switch (from_which_type.idx)
{
case TypeIndex::UInt8:
case TypeIndex::UInt16:
case TypeIndex::UInt32:
case TypeIndex::UInt64:
case TypeIndex::UInt128:
case TypeIndex::UInt256:
{
if (to_which_type.isUInt() &&
to_type_unwrapped->getSizeOfValueInMemory() >= from_type->getSizeOfValueInMemory())
return true;
if (to_which_type.isString())
return true;
return false;
}
case TypeIndex::Int8:
case TypeIndex::Int16:
case TypeIndex::Int32:
case TypeIndex::Int64:
case TypeIndex::Int128:
case TypeIndex::Int256:
{
if (to_which_type.isInt() &&
to_type_unwrapped->getSizeOfValueInMemory() >= from_type->getSizeOfValueInMemory())
return true;
if (to_which_type.isString())
return true;
return false;
}
case TypeIndex::Float32:
{
if (to_which_type.isFloat64() || to_which_type.isString())
return true;
return false;
}
case TypeIndex::Float64:
case TypeIndex::Date:
case TypeIndex::Date32:
case TypeIndex::DateTime:
case TypeIndex::DateTime64:
case TypeIndex::FixedString:
case TypeIndex::Enum8:
case TypeIndex::Enum16:
case TypeIndex::IPv6:
{
if (to_which_type.isString())
return true;
return false;
}
case TypeIndex::Decimal32:
case TypeIndex::Decimal64:
case TypeIndex::Decimal128:
case TypeIndex::Decimal256:
{
if (to_which_type.isDecimal())
{
auto from_type_decimal_precision = getDecimalPrecision(*from_type);
auto to_type_decimal_precision = getDecimalPrecision(*to_type_unwrapped);
if (from_type_decimal_precision > to_type_decimal_precision)
return false;
auto from_type_decimal_scale = getDecimalScale(*from_type);
auto to_type_decimal_scale = getDecimalScale(*to_type_unwrapped);
if (from_type_decimal_scale > to_type_decimal_scale)
return false;
return true;
}
if (to_which_type.isString())
return true;
return false;
}
case TypeIndex::UUID:
{
if (to_which_type.isUInt128() || to_which_type.isString())
return true;
return false;
}
case TypeIndex::IPv4:
{
if (to_which_type.isUInt32() || to_which_type.isUInt64() || to_which_type.isString())
return true;
return false;
}
case TypeIndex::Nullable:
{
if (to_type_was_nullable)
{
const auto & from_type_nullable = assert_cast<const DataTypeNullable &>(*from_type);
return canBeSafelyCasted(from_type_nullable.getNestedType(), to_type_unwrapped);
}
if (to_which_type.isString())
return true;
return false;
}
case TypeIndex::LowCardinality:
{
const auto & from_type_low_cardinality = assert_cast<const DataTypeLowCardinality &>(*from_type);
return canBeSafelyCasted(from_type_low_cardinality.getDictionaryType(), to_type_unwrapped);
}
case TypeIndex::Array:
{
if (to_which_type.isArray())
{
const auto & from_type_array = assert_cast<const DataTypeArray &>(*from_type);
const auto & to_type_array = assert_cast<const DataTypeArray &>(*to_type_unwrapped);
return canBeSafelyCasted(from_type_array.getNestedType(), to_type_array.getNestedType());
}
if (to_which_type.isString())
return true;
return false;
}
case TypeIndex::Map:
{
if (to_which_type.isMap())
{
const auto & from_type_map = assert_cast<const DataTypeMap &>(*from_type);
const auto & to_type_map = assert_cast<const DataTypeMap &>(*to_type_unwrapped);
if (!canBeSafelyCasted(from_type_map.getKeyType(), to_type_map.getKeyType()))
return false;
if (!canBeSafelyCasted(from_type_map.getValueType(), to_type_map.getValueType()))
return false;
return true;
}
if (to_which_type.isArray())
{
// Map nested type is Array(Tuple(key_type, value_type))
const auto & from_type_map = assert_cast<const DataTypeMap &>(*from_type);
const auto & to_type_array = assert_cast<const DataTypeArray &>(*to_type_unwrapped);
const auto * to_type_nested_tuple_type = typeid_cast<const DataTypeTuple *>(to_type_array.getNestedType().get());
if (!to_type_nested_tuple_type)
return false;
const auto & to_type_tuple_elements = to_type_nested_tuple_type->getElements();
if (to_type_tuple_elements.size() != 2)
return false;
if (!canBeSafelyCasted(from_type_map.getKeyType(), to_type_tuple_elements[0]))
return false;
if (!canBeSafelyCasted(from_type_map.getValueType(), to_type_tuple_elements[1]))
return false;
return true;
}
if (to_which_type.isString())
return true;
return false;
}
case TypeIndex::Tuple:
{
if (to_which_type.isTuple())
{
const auto & from_type_tuple = assert_cast<const DataTypeTuple &>(*from_type);
const auto & to_type_tuple = assert_cast<const DataTypeTuple &>(*to_type_unwrapped);
const auto & from_tuple_type_elements = from_type_tuple.getElements();
const auto & to_tuple_type_elements = to_type_tuple.getElements();
size_t lhs_type_elements_size = from_tuple_type_elements.size();
if (lhs_type_elements_size != to_tuple_type_elements.size())
return false;
for (size_t i = 0; i < lhs_type_elements_size; ++i)
if (!canBeSafelyCasted(from_tuple_type_elements[i], to_tuple_type_elements[i]))
return false;
return true;
}
if (to_which_type.isString())
return true;
return false;
}
case TypeIndex::String:
case TypeIndex::Object:
case TypeIndex::Set:
case TypeIndex::Interval:
case TypeIndex::Function:
case TypeIndex::AggregateFunction:
case TypeIndex::Nothing:
return false;
}
return true;
}
}

19
src/DataTypes/Utils.h Normal file
View File

@ -0,0 +1,19 @@
#pragma once
#include <DataTypes/IDataType.h>
namespace DB
{
/** Returns true if a value of from_type can be safely cast to to_type.
  *
  * The decision is made from the two types only; no column data is examined.
  * Nullable and LowCardinality wrappers are intended to be transparent for this check.
  *
  * Examples:
  * From type UInt8 to type UInt16 returns true.
  * From type UInt16 to type UInt8 returns false.
  * From type String to type LowCardinality(String) returns true.
  * From type LowCardinality(String) to type String returns true.
  * From type String to type UInt8 returns false.
  */
bool canBeSafelyCasted(const DataTypePtr & from_type, const DataTypePtr & to_type);
}

View File

@ -6,7 +6,6 @@
#include <Columns/ColumnNullable.h> #include <Columns/ColumnNullable.h>
#include <Columns/ColumnLowCardinality.h> #include <Columns/ColumnLowCardinality.h>
#include <Common/assert_cast.h> #include <Common/assert_cast.h>
#include <DataTypes/DataTypeNullable.h>
namespace DB namespace DB

View File

@ -3,6 +3,7 @@
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <DataTypes/FieldToDataType.h> #include <DataTypes/FieldToDataType.h>
#include <DataTypes/getLeastSupertype.h> #include <DataTypes/getLeastSupertype.h>
#include <DataTypes/Utils.h>
#include <Interpreters/TreeRewriter.h> #include <Interpreters/TreeRewriter.h>
#include <Interpreters/ExpressionAnalyzer.h> #include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/ExpressionActions.h> #include <Interpreters/ExpressionActions.h>
@ -1257,7 +1258,7 @@ bool KeyCondition::tryPrepareSetIndex(
if (!future_set) if (!future_set)
return false; return false;
const auto & set_types = future_set->getTypes(); const auto set_types = future_set->getTypes();
size_t set_types_size = set_types.size(); size_t set_types_size = set_types.size();
size_t indexes_mapping_size = indexes_mapping.size(); size_t indexes_mapping_size = indexes_mapping.size();
@ -1283,24 +1284,37 @@ bool KeyCondition::tryPrepareSetIndex(
for (size_t indexes_mapping_index = 0; indexes_mapping_index < indexes_mapping_size; ++indexes_mapping_index) for (size_t indexes_mapping_index = 0; indexes_mapping_index < indexes_mapping_size; ++indexes_mapping_index)
{ {
const auto & key_column_type = data_types[indexes_mapping_index];
size_t set_element_index = indexes_mapping[indexes_mapping_index].tuple_index; size_t set_element_index = indexes_mapping[indexes_mapping_index].tuple_index;
const auto & set_element_type = set_types[set_element_index]; auto set_element_type = set_types[set_element_index];
auto & set_column = set_columns[set_element_index]; auto set_column = set_columns[set_element_index];
bool is_set_column_nullable = set_element_type->isNullable(); if (canBeSafelyCasted(set_element_type, key_column_type))
bool is_set_column_low_cardinality_nullable = set_element_type->isLowCardinalityNullable(); {
set_columns[set_element_index] = castColumn({set_column, set_element_type, {}}, key_column_type);
continue;
}
if (!key_column_type->canBeInsideNullable())
return false;
const NullMap * set_column_null_map = nullptr; const NullMap * set_column_null_map = nullptr;
if (is_set_column_nullable || is_set_column_low_cardinality_nullable) if (isNullableOrLowCardinalityNullable(set_element_type))
{ {
if (is_set_column_low_cardinality_nullable) if (WhichDataType(set_element_type).isLowCardinality())
{
set_element_type = removeLowCardinality(set_element_type);
set_column = set_column->convertToFullColumnIfLowCardinality(); set_column = set_column->convertToFullColumnIfLowCardinality();
}
set_column_null_map = &assert_cast<const ColumnNullable &>(*set_column).getNullMapData(); set_element_type = removeNullable(set_element_type);
const auto & set_column_nullable = assert_cast<const ColumnNullable &>(*set_column);
set_column_null_map = &set_column_nullable.getNullMapData();
set_column = set_column_nullable.getNestedColumnPtr();
} }
auto nullable_set_column = castColumnAccurateOrNull({set_column, set_element_type, {}}, data_types[indexes_mapping_index]); auto nullable_set_column = castColumnAccurateOrNull({set_column, set_element_type, {}}, key_column_type);
const auto & nullable_set_column_typed = assert_cast<const ColumnNullable &>(*nullable_set_column); const auto & nullable_set_column_typed = assert_cast<const ColumnNullable &>(*nullable_set_column);
const auto & nullable_set_column_null_map = nullable_set_column_typed.getNullMapData(); const auto & nullable_set_column_null_map = nullable_set_column_typed.getNullMapData();
size_t nullable_set_column_null_map_size = nullable_set_column_null_map.size(); size_t nullable_set_column_null_map_size = nullable_set_column_null_map.size();
@ -1321,6 +1335,8 @@ bool KeyCondition::tryPrepareSetIndex(
set_column = nullable_set_column_typed.getNestedColumn().filter(filter, 0); set_column = nullable_set_column_typed.getNestedColumn().filter(filter, 0);
} }
set_columns[set_element_index] = std::move(set_column);
} }
out.set_index = std::make_shared<MergeTreeSetIndex>(set_columns, std::move(indexes_mapping)); out.set_index = std::make_shared<MergeTreeSetIndex>(set_columns, std::move(indexes_mapping));