Add IDataType::forEachChild and use it in nested types validation

This commit is contained in:
avogar 2024-02-15 13:19:02 +00:00
parent 109720d162
commit efa823400b
14 changed files with 104 additions and 70 deletions

View File

@ -69,6 +69,11 @@ String DataTypeArray::doGetPrettyName(size_t indent) const
return s.str(); return s.str();
} }
void DataTypeArray::forEachChild(const ChildCallback & callback) const
{
callback(*nested);
nested->forEachChild(callback);
}
static DataTypePtr create(const ASTPtr & arguments) static DataTypePtr create(const ASTPtr & arguments)
{ {

View File

@ -43,6 +43,7 @@ public:
MutableColumnPtr createColumn() const override; MutableColumnPtr createColumn() const override;
void forEachChild(const ChildCallback & callback) const override;
Field getDefault() const override; Field getDefault() const override;

View File

@ -153,6 +153,12 @@ SerializationPtr DataTypeLowCardinality::doGetDefaultSerialization() const
return std::make_shared<SerializationLowCardinality>(dictionary_type); return std::make_shared<SerializationLowCardinality>(dictionary_type);
} }
void DataTypeLowCardinality::forEachChild(const ChildCallback & callback) const
{
callback(*dictionary_type);
dictionary_type->forEachChild(callback);
}
static DataTypePtr create(const ASTPtr & arguments) static DataTypePtr create(const ASTPtr & arguments)
{ {

View File

@ -60,6 +60,8 @@ public:
static MutableColumnUniquePtr createColumnUnique(const IDataType & keys_type); static MutableColumnUniquePtr createColumnUnique(const IDataType & keys_type);
static MutableColumnUniquePtr createColumnUnique(const IDataType & keys_type, MutableColumnPtr && keys); static MutableColumnUniquePtr createColumnUnique(const IDataType & keys_type, MutableColumnPtr && keys);
void forEachChild(const ChildCallback & callback) const override;
private: private:
SerializationPtr doGetDefaultSerialization() const override; SerializationPtr doGetDefaultSerialization() const override;

View File

@ -143,6 +143,14 @@ DataTypePtr DataTypeMap::getNestedTypeWithUnnamedTuple() const
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeTuple>(from_tuple.getElements())); return std::make_shared<DataTypeArray>(std::make_shared<DataTypeTuple>(from_tuple.getElements()));
} }
void DataTypeMap::forEachChild(const DB::IDataType::ChildCallback & callback) const
{
callback(*key_type);
key_type->forEachChild(callback);
callback(*value_type);
value_type->forEachChild(callback);
}
static DataTypePtr create(const ASTPtr & arguments) static DataTypePtr create(const ASTPtr & arguments)
{ {
if (!arguments || arguments->children.size() != 2) if (!arguments || arguments->children.size() != 2)

View File

@ -54,6 +54,8 @@ public:
static bool checkKeyType(DataTypePtr key_type); static bool checkKeyType(DataTypePtr key_type);
void forEachChild(const ChildCallback & callback) const override;
private: private:
void assertKeyType() const; void assertKeyType() const;
}; };

View File

@ -61,6 +61,12 @@ SerializationPtr DataTypeNullable::doGetDefaultSerialization() const
return std::make_shared<SerializationNullable>(nested_data_type->getDefaultSerialization()); return std::make_shared<SerializationNullable>(nested_data_type->getDefaultSerialization());
} }
void DataTypeNullable::forEachChild(const ChildCallback & callback) const
{
callback(*nested_data_type);
nested_data_type->forEachChild(callback);
}
static DataTypePtr create(const ASTPtr & arguments) static DataTypePtr create(const ASTPtr & arguments)
{ {

View File

@ -43,6 +43,9 @@ public:
bool canBePromoted() const override { return nested_data_type->canBePromoted(); } bool canBePromoted() const override { return nested_data_type->canBePromoted(); }
const DataTypePtr & getNestedType() const { return nested_data_type; } const DataTypePtr & getNestedType() const { return nested_data_type; }
void forEachChild(const ChildCallback & callback) const override;
private: private:
SerializationPtr doGetDefaultSerialization() const override; SerializationPtr doGetDefaultSerialization() const override;

View File

@ -376,6 +376,15 @@ SerializationInfoPtr DataTypeTuple::getSerializationInfo(const IColumn & column)
return std::make_shared<SerializationInfoTuple>(std::move(infos), names, SerializationInfo::Settings{}); return std::make_shared<SerializationInfoTuple>(std::move(infos), names, SerializationInfo::Settings{});
} }
void DataTypeTuple::forEachChild(const ChildCallback & callback) const
{
for (const auto & elem : elems)
{
callback(*elem);
elem->forEachChild(callback);
}
}
static DataTypePtr create(const ASTPtr & arguments) static DataTypePtr create(const ASTPtr & arguments)
{ {

View File

@ -70,6 +70,8 @@ public:
String getNameByPosition(size_t i) const; String getNameByPosition(size_t i) const;
bool haveExplicitNames() const { return have_explicit_names; } bool haveExplicitNames() const { return have_explicit_names; }
void forEachChild(const ChildCallback & callback) const override;
}; };
} }

View File

@ -175,6 +175,15 @@ SerializationPtr DataTypeVariant::doGetDefaultSerialization() const
return std::make_shared<SerializationVariant>(std::move(serializations), std::move(variant_names), SerializationVariant::getVariantsDeserializeTextOrder(variants), getName()); return std::make_shared<SerializationVariant>(std::move(serializations), std::move(variant_names), SerializationVariant::getVariantsDeserializeTextOrder(variants), getName());
} }
void DataTypeVariant::forEachChild(const DB::IDataType::ChildCallback & callback) const
{
for (const auto & variant : variants)
{
callback(*variant);
variant->forEachChild(callback);
}
}
static DataTypePtr create(const ASTPtr & arguments) static DataTypePtr create(const ASTPtr & arguments)
{ {
if (!arguments || arguments->children.empty()) if (!arguments || arguments->children.empty())

View File

@ -54,6 +54,8 @@ public:
/// Check if Variant has provided type in the list of variants and return its discriminator. /// Check if Variant has provided type in the list of variants and return its discriminator.
std::optional<ColumnVariant::Discriminator> tryGetVariantDiscriminator(const DataTypePtr & type) const; std::optional<ColumnVariant::Discriminator> tryGetVariantDiscriminator(const DataTypePtr & type) const;
void forEachChild(const ChildCallback & callback) const override;
private: private:
std::string doGetName() const override; std::string doGetName() const override;
std::string doGetPrettyName(size_t indent) const override; std::string doGetPrettyName(size_t indent) const override;

View File

@ -111,6 +111,10 @@ public:
const SubcolumnCallback & callback, const SubcolumnCallback & callback,
const SubstreamData & data); const SubstreamData & data);
/// Call callback for each nested type recursively.
using ChildCallback = std::function<void(const IDataType &)>;
virtual void forEachChild(const ChildCallback &) const {}
Names getSubcolumnNames() const; Names getSubcolumnNames() const;
virtual MutableSerializationInfoPtr createSerializationInfo(const SerializationInfo::Settings & settings) const; virtual MutableSerializationInfoPtr createSerializationInfo(const SerializationInfo::Settings & settings) const;

View File

@ -7,11 +7,6 @@
#include <DataTypes/DataTypeLowCardinality.h> #include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypeFixedString.h> #include <DataTypes/DataTypeFixedString.h>
#include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeVariant.h>
#include <DataTypes/DataTypeMap.h>
namespace DB namespace DB
{ {
@ -24,11 +19,13 @@ namespace ErrorCodes
} }
void validateDataType(const DataTypePtr & type, const DataTypeValidationSettings & settings) void validateDataType(const DataTypePtr & type_to_check, const DataTypeValidationSettings & settings)
{ {
auto validate_callback = [&](const IDataType & data_type)
{
if (!settings.allow_suspicious_low_cardinality_types) if (!settings.allow_suspicious_low_cardinality_types)
{ {
if (const auto * lc_type = typeid_cast<const DataTypeLowCardinality *>(type.get())) if (const auto * lc_type = typeid_cast<const DataTypeLowCardinality *>(&data_type))
{ {
if (!isStringOrFixedString(*removeNullable(lc_type->getDictionaryType()))) if (!isStringOrFixedString(*removeNullable(lc_type->getDictionaryType())))
throw Exception( throw Exception(
@ -41,67 +38,45 @@ void validateDataType(const DataTypePtr & type, const DataTypeValidationSettings
if (!settings.allow_experimental_object_type) if (!settings.allow_experimental_object_type)
{ {
if (type->hasDynamicSubcolumns()) if (data_type.hasDynamicSubcolumns())
{ {
throw Exception( throw Exception(
ErrorCodes::ILLEGAL_COLUMN, ErrorCodes::ILLEGAL_COLUMN,
"Cannot create column with type '{}' because experimental Object type is not allowed. " "Cannot create column with type '{}' because experimental Object type is not allowed. "
"Set setting allow_experimental_object_type = 1 in order to allow it", type->getName()); "Set setting allow_experimental_object_type = 1 in order to allow it",
data_type.getName());
} }
} }
if (!settings.allow_suspicious_fixed_string_types) if (!settings.allow_suspicious_fixed_string_types)
{ {
if (const auto * fixed_string = typeid_cast<const DataTypeFixedString *>(type.get())) if (const auto * fixed_string = typeid_cast<const DataTypeFixedString *>(&data_type))
{ {
if (fixed_string->getN() > MAX_FIXEDSTRING_SIZE_WITHOUT_SUSPICIOUS) if (fixed_string->getN() > MAX_FIXEDSTRING_SIZE_WITHOUT_SUSPICIOUS)
throw Exception( throw Exception(
ErrorCodes::ILLEGAL_COLUMN, ErrorCodes::ILLEGAL_COLUMN,
"Cannot create column with type '{}' because fixed string with size > {} is suspicious. " "Cannot create column with type '{}' because fixed string with size > {} is suspicious. "
"Set setting allow_suspicious_fixed_string_types = 1 in order to allow it", "Set setting allow_suspicious_fixed_string_types = 1 in order to allow it",
type->getName(), data_type.getName(),
MAX_FIXEDSTRING_SIZE_WITHOUT_SUSPICIOUS); MAX_FIXEDSTRING_SIZE_WITHOUT_SUSPICIOUS);
} }
} }
if (!settings.allow_experimental_variant_type) if (!settings.allow_experimental_variant_type)
{ {
if (isVariant(type)) if (isVariant(data_type))
{ {
throw Exception( throw Exception(
ErrorCodes::ILLEGAL_COLUMN, ErrorCodes::ILLEGAL_COLUMN,
"Cannot create column with type '{}' because experimental Variant type is not allowed. " "Cannot create column with type '{}' because experimental Variant type is not allowed. "
"Set setting allow_experimental_variant_type = 1 in order to allow it", type->getName()); "Set setting allow_experimental_variant_type = 1 in order to allow it",
data_type.getName());
} }
} }
};
if (const auto * nullable_type = typeid_cast<const DataTypeNullable *>(type.get())) validate_callback(*type_to_check);
{ type_to_check->forEachChild(validate_callback);
validateDataType(nullable_type->getNestedType(), settings);
}
else if (const auto * lc_type = typeid_cast<const DataTypeLowCardinality *>(type.get()))
{
validateDataType(lc_type->getDictionaryType(), settings);
}
else if (const auto * array_type = typeid_cast<const DataTypeArray *>(type.get()))
{
validateDataType(array_type->getNestedType(), settings);
}
else if (const auto * tuple_type = typeid_cast<const DataTypeTuple *>(type.get()))
{
for (const auto & element : tuple_type->getElements())
validateDataType(element, settings);
}
else if (const auto * map_type = typeid_cast<const DataTypeMap *>(type.get()))
{
validateDataType(map_type->getKeyType(), settings);
validateDataType(map_type->getValueType(), settings);
}
else if (const auto * variant_type = typeid_cast<const DataTypeVariant *>(type.get()))
{
for (const auto & variant : variant_type->getVariants())
validateDataType(variant, settings);
}
} }
ColumnsDescription parseColumnsListFromString(const std::string & structure, const ContextPtr & context) ColumnsDescription parseColumnsListFromString(const std::string & structure, const ContextPtr & context)