From 9d9d16e1ea74652557d29513469dfe9c8bee131b Mon Sep 17 00:00:00 2001 From: bgranvea Date: Fri, 8 Mar 2019 17:49:10 +0100 Subject: [PATCH 1/8] support for SimpleAggregateFunction data type --- .../AggregatingSortedBlockInputStream.cpp | 44 +++++- .../AggregatingSortedBlockInputStream.h | 51 ++++++ .../DataTypes/DataTypeDomainIPv4AndIPv6.cpp | 18 +-- .../DataTypeDomainSimpleAggregateFunction.cpp | 149 ++++++++++++++++++ .../DataTypeDomainSimpleAggregateFunction.h | 45 ++++++ .../DataTypeDomainWithSimpleSerialization.h | 2 +- dbms/src/DataTypes/DataTypeFactory.cpp | 31 ++-- dbms/src/DataTypes/DataTypeFactory.h | 13 +- dbms/src/DataTypes/IDataType.cpp | 53 +++---- dbms/src/DataTypes/IDataType.h | 11 +- dbms/src/DataTypes/IDataTypeDomain.h | 41 +++-- .../00915_simple_aggregate_function.reference | 43 +++++ .../00915_simple_aggregate_function.sql | 27 ++++ 13 files changed, 454 insertions(+), 74 deletions(-) create mode 100644 dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp create mode 100644 dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h create mode 100644 dbms/tests/queries/0_stateless/00915_simple_aggregate_function.reference create mode 100644 dbms/tests/queries/0_stateless/00915_simple_aggregate_function.sql diff --git a/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp b/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp index 0697ec8167c..34fb19b2688 100644 --- a/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp +++ b/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp @@ -1,6 +1,8 @@ #include #include #include +#include +#include namespace DB @@ -22,7 +24,7 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream( ColumnWithTypeAndName & column = header.safeGetByPosition(i); /// We leave only states of aggregate functions. - if (!startsWith(column.type->getName(), "AggregateFunction")) + if (!dynamic_cast(column.type.get()) && !findSimpleAggregateFunction(column.type)) { column_numbers_not_to_aggregate.push_back(i); continue; @@ -40,7 +42,14 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream( continue; } - column_numbers_to_aggregate.push_back(i); + if (auto simple_aggr = findSimpleAggregateFunction(column.type)) { + // simple aggregate function + SimpleAggregateDescription desc{simple_aggr->getFunction(), i}; + columns_to_simple_aggregate.emplace_back(std::move(desc)); + } else { + // standard aggregate function + column_numbers_to_aggregate.push_back(i); + } } } @@ -90,8 +99,11 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s key_differs = next_key != current_key; /// if there are enough rows accumulated and the last one is calculated completely - if (key_differs && merged_rows >= max_block_size) + if (key_differs && merged_rows >= max_block_size) { + /// Write the simple aggregation result for the previous group. + insertSimpleAggregationResult(merged_columns); return; + } queue.pop(); @@ -110,6 +122,14 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s for (auto & column_to_aggregate : columns_to_aggregate) column_to_aggregate->insertDefault(); + /// Write the simple aggregation result for the previous group. + if (merged_rows > 0) + insertSimpleAggregationResult(merged_columns); + + /// Reset simple aggregation states for next row + for (auto & desc : columns_to_simple_aggregate) + desc.createState(); + ++merged_rows; } @@ -127,6 +147,9 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s } } + /// Write the simple aggregation result for the previous group. + insertSimpleAggregationResult(merged_columns); + finished = true; } @@ -138,6 +161,21 @@ void AggregatingSortedBlockInputStream::addRow(SortCursor & cursor) size_t j = column_numbers_to_aggregate[i]; columns_to_aggregate[i]->insertMergeFrom(*cursor->all_columns[j], cursor->pos); } + + for (auto & desc : columns_to_simple_aggregate) + { + auto & col = cursor->all_columns[desc.column_number]; + desc.add_function(desc.function.get(), desc.state.data(), &col, cursor->pos, nullptr); + } +} + +void AggregatingSortedBlockInputStream::insertSimpleAggregationResult(MutableColumns & merged_columns) +{ + for (auto & desc : columns_to_simple_aggregate) + { + desc.function->insertResultInto(desc.state.data(), *merged_columns[desc.column_number]); + desc.destroyState(); + } } } diff --git a/dbms/src/DataStreams/AggregatingSortedBlockInputStream.h b/dbms/src/DataStreams/AggregatingSortedBlockInputStream.h index 522b54aeaec..97a579e31a6 100644 --- a/dbms/src/DataStreams/AggregatingSortedBlockInputStream.h +++ b/dbms/src/DataStreams/AggregatingSortedBlockInputStream.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace DB @@ -38,10 +39,13 @@ private: /// Read finished. bool finished = false; + struct SimpleAggregateDescription; + /// Columns with which numbers should be aggregated. ColumnNumbers column_numbers_to_aggregate; ColumnNumbers column_numbers_not_to_aggregate; std::vector columns_to_aggregate; + std::vector columns_to_simple_aggregate; RowRef current_key; /// The current primary key. RowRef next_key; /// The primary key of the next row. @@ -54,6 +58,53 @@ private: /** Extract all states of aggregate functions and merge them with the current group. */ void addRow(SortCursor & cursor); + + /** Insert all values of current row for simple aggregate functions + */ + void insertSimpleAggregationResult(MutableColumns & merged_columns); + + /// Stores information for aggregation of SimpleAggregateFunction columns + struct SimpleAggregateDescription + { + /// An aggregate function 'anyLast', 'sum'... + AggregateFunctionPtr function; + IAggregateFunction::AddFunc add_function; + size_t column_number; + AlignedBuffer state; + bool created = false; + + SimpleAggregateDescription(const AggregateFunctionPtr & function_, const size_t column_number_) : function(function_), column_number(column_number_) + { + add_function = function->getAddressOfAddFunction(); + state.reset(function->sizeOfData(), function->alignOfData()); + } + + void createState() + { + if (created) + return; + function->create(state.data()); + created = true; + } + + void destroyState() + { + if (!created) + return; + function->destroy(state.data()); + created = false; + } + + /// Explicitly destroy aggregation state if the stream is terminated + ~SimpleAggregateDescription() + { + destroyState(); + } + + SimpleAggregateDescription() = default; + SimpleAggregateDescription(SimpleAggregateDescription &&) = default; + SimpleAggregateDescription(const SimpleAggregateDescription &) = delete; + }; }; } diff --git a/dbms/src/DataTypes/DataTypeDomainIPv4AndIPv6.cpp b/dbms/src/DataTypes/DataTypeDomainIPv4AndIPv6.cpp index 873dbde506b..f57a6167d3d 100644 --- a/dbms/src/DataTypes/DataTypeDomainIPv4AndIPv6.cpp +++ b/dbms/src/DataTypes/DataTypeDomainIPv4AndIPv6.cpp @@ -23,7 +23,7 @@ namespace class DataTypeDomainIPv4 : public DataTypeDomainWithSimpleSerialization { public: - const char * getName() const override + String doGetName() const override { return "IPv4"; } @@ -33,7 +33,7 @@ public: const auto col = checkAndGetColumn(&column); if (!col) { - throw Exception(String(getName()) + " domain can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception(getName() + " domain can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); } char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'}; @@ -48,7 +48,7 @@ public: ColumnUInt32 * col = typeid_cast(&column); if (!col) { - throw Exception(String(getName()) + " domain can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception(getName() + " domain can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); } char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'}; @@ -66,7 +66,7 @@ public: class DataTypeDomainIPv6 : public DataTypeDomainWithSimpleSerialization { public: - const char * getName() const override + String doGetName() const override { return "IPv6"; } @@ -76,7 +76,7 @@ public: const auto col = checkAndGetColumn(&column); if (!col) { - throw Exception(String(getName()) + " domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception(getName() + " domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); } char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'}; @@ -91,7 +91,7 @@ public: ColumnFixedString * col = typeid_cast(&column); if (!col) { - throw Exception(String(getName()) + " domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception(getName() + " domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); } char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'}; @@ -100,7 +100,7 @@ public: std::string ipv6_value(IPV6_BINARY_LENGTH, '\0'); if (!parseIPv6(buffer, reinterpret_cast(ipv6_value.data()))) { - throw Exception(String("Invalid ") + getName() + " value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING); + throw Exception("Invalid " + getName() + " value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING); } col->insertString(ipv6_value); @@ -111,8 +111,8 @@ public: void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory) { - factory.registerDataTypeDomain("UInt32", std::make_unique()); - factory.registerDataTypeDomain("FixedString(16)", std::make_unique()); + factory.registerDataTypeDomain("IPv4", [] { return std::make_pair(DataTypeFactory::instance().get("UInt32"), std::make_unique()); }); + factory.registerDataTypeDomain("IPv6", [] { return std::make_pair(DataTypeFactory::instance().get("FixedString(16)"), std::make_unique()); }); } } // namespace DB diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp new file mode 100644 index 00000000000..402ce86ad62 --- /dev/null +++ b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp @@ -0,0 +1,149 @@ +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +namespace DB { + +namespace ErrorCodes +{ + extern const int SYNTAX_ERROR; + extern const int BAD_ARGUMENTS; + extern const int PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int LOGICAL_ERROR; +} + +const std::vector supported_functions = std::vector( + {"any", "anyLast", "min", "max", "sum"}); + + +String DataTypeDomainSimpleAggregateFunction::doGetName() const { + std::stringstream stream; + stream << "SimpleAggregateFunction(" << function->getName(); + + if (!parameters.empty()) { + stream << "("; + for (size_t i = 0; i < parameters.size(); ++i) { + if (i) + stream << ", "; + stream << applyVisitor(DB::FieldVisitorToString(), parameters[i]); + } + stream << ")"; + } + + for (const auto &argument_type : argument_types) + stream << ", " << argument_type->getName(); + + stream << ")"; + return stream.str(); +} + + +static std::pair create(const ASTPtr & arguments) +{ + String function_name; + AggregateFunctionPtr function; + DataTypes argument_types; + Array params_row; + + if (!arguments || arguments->children.empty()) + throw Exception("Data type SimpleAggregateFunction requires parameters: " + "name of aggregate function and list of data types for arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + if (const ASTFunction * parametric = typeid_cast(arguments->children[0].get())) + { + if (parametric->parameters) + throw Exception("Unexpected level of parameters to aggregate function", ErrorCodes::SYNTAX_ERROR); + function_name = parametric->name; + + const ASTs & parameters = typeid_cast(*parametric->arguments).children; + params_row.resize(parameters.size()); + + for (size_t i = 0; i < parameters.size(); ++i) + { + const ASTLiteral * lit = typeid_cast(parameters[i].get()); + if (!lit) + throw Exception("Parameters to aggregate functions must be literals", + ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS); + + params_row[i] = lit->value; + } + } + else if (auto opt_name = getIdentifierName(arguments->children[0])) + { + function_name = *opt_name; + } + else if (typeid_cast(arguments->children[0].get())) + { + throw Exception("Aggregate function name for data type SimpleAggregateFunction must be passed as identifier (without quotes) or function", + ErrorCodes::BAD_ARGUMENTS); + } + else + throw Exception("Unexpected AST element passed as aggregate function name for data type SimpleAggregateFunction. Must be identifier or function.", + ErrorCodes::BAD_ARGUMENTS); + + for (size_t i = 1; i < arguments->children.size(); ++i) + argument_types.push_back(DataTypeFactory::instance().get(arguments->children[i])); + + if (function_name.empty()) + throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR); + + function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row); + + // check function + if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions)) { + throw Exception("Unsupported aggregate function " + function->getName() + ", supported functions are " + boost::algorithm::join(supported_functions, ","), + ErrorCodes::BAD_ARGUMENTS); + } + + DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName()); + DataTypeDomainPtr domain = std::make_unique(storage_type, function, argument_types, params_row); + + if (!function->getReturnType()->equals(*removeLowCardinality(storage_type))) { + throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getReturnType()->getName() + " and column storage type " + storage_type->getName(), + ErrorCodes::BAD_ARGUMENTS); + } + + return std::make_pair(storage_type, std::move(domain)); +} + +static const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(const IDataTypeDomain * domain) { + if (domain == nullptr) + return nullptr; + + if (auto simple_aggr = dynamic_cast(domain)) + return simple_aggr; + + if (domain->getDomain() != nullptr) + return findSimpleAggregateFunction(domain->getDomain()); + + return nullptr; +} + +const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(DataTypePtr dataType) { + return findSimpleAggregateFunction(dataType->getDomain()); +} + + +void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory) +{ + factory.registerDataTypeDomain("SimpleAggregateFunction", create); +} + +} diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h new file mode 100644 index 00000000000..70e94b1a652 --- /dev/null +++ b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include + +#include + +namespace DB +{ + +/** The type SimpleAggregateFunction(fct, type) is meant to be used in an AggregatingMergeTree. It behaves like a standard + * data type but when rows are merged, an aggregation function is applied. + * + * The aggregation function is limited to simple functions whose merge state is the final result: + * any, anyLast, min, max, sum + * + * Examples: + * + * SimpleAggregateFunction(sum, Nullable(Float64)) + * SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String))) + * SimpleAggregateFunction(anyLast, IPv4) + * + * Technically, a standard IDataType is instanciated and a DataTypeDomainSimpleAggregateFunction is added as domain. + */ + +class DataTypeDomainSimpleAggregateFunction : public IDataTypeDomain { +private: + const DataTypePtr storage_type; + const AggregateFunctionPtr function; + const DataTypes argument_types; + const Array parameters; + +public: + DataTypeDomainSimpleAggregateFunction(const DataTypePtr storage_type_, const AggregateFunctionPtr & function_, const DataTypes & argument_types_, const Array & parameters_) + : storage_type(storage_type_), function(function_), argument_types(argument_types_), parameters(parameters_) {} + + const AggregateFunctionPtr getFunction() const { return function; } + String doGetName() const override; +}; + +/// recursively follow data type domain to find a DataTypeDomainSimpleAggregateFunction +const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(DataTypePtr dataType); + +} diff --git a/dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.h b/dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.h index 7834e9235d2..3ccb4091636 100644 --- a/dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.h +++ b/dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.h @@ -12,7 +12,7 @@ class IColumn; /** Simple DataTypeDomain that uses serializeText/deserializeText * for all serialization and deserialization. */ -class DataTypeDomainWithSimpleSerialization : public IDataTypeDomain +class DataTypeDomainWithSimpleSerialization : public IDataTypeDomainCustomSerialization { public: virtual ~DataTypeDomainWithSimpleSerialization() override; diff --git a/dbms/src/DataTypes/DataTypeFactory.cpp b/dbms/src/DataTypes/DataTypeFactory.cpp index a0afab890e9..a405075e884 100644 --- a/dbms/src/DataTypes/DataTypeFactory.cpp +++ b/dbms/src/DataTypes/DataTypeFactory.cpp @@ -115,19 +115,23 @@ void DataTypeFactory::registerSimpleDataType(const String & name, SimpleCreator }, case_sensitiveness); } -void DataTypeFactory::registerDataTypeDomain(const String & type_name, DataTypeDomainPtr domain, CaseSensitiveness case_sensitiveness) +void DataTypeFactory::registerDataTypeDomain(const String & family_name, CreatorWithDomain creator, CaseSensitiveness case_sensitiveness) { - all_domains.reserve(all_domains.size() + 1); - - auto data_type = get(type_name); - setDataTypeDomain(*data_type, *domain); - - registerDataType(domain->getName(), [data_type](const ASTPtr & /*ast*/) + registerDataType(family_name, [creator](const ASTPtr & ast) { - return data_type; - }, case_sensitiveness); + auto res = creator(ast); + res.first->appendDomain(std::move(res.second)); - all_domains.emplace_back(std::move(domain)); + return res.first; + }, case_sensitiveness); +} + +void DataTypeFactory::registerDataTypeDomain(const String & name, SimpleCreatorWithDomain creator, CaseSensitiveness case_sensitiveness) +{ + registerDataTypeDomain(name, [creator](const ASTPtr & /*ast*/) + { + return creator(); + }, case_sensitiveness); } const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String & family_name) const @@ -153,11 +157,6 @@ const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String throw Exception("Unknown data type family: " + family_name, ErrorCodes::UNKNOWN_TYPE); } -void DataTypeFactory::setDataTypeDomain(const IDataType & data_type, const IDataTypeDomain & domain) -{ - data_type.setDomain(&domain); -} - void registerDataTypeNumbers(DataTypeFactory & factory); void registerDataTypeDecimal(DataTypeFactory & factory); void registerDataTypeDate(DataTypeFactory & factory); @@ -175,6 +174,7 @@ void registerDataTypeNested(DataTypeFactory & factory); void registerDataTypeInterval(DataTypeFactory & factory); void registerDataTypeLowCardinality(DataTypeFactory & factory); void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory); +void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory); DataTypeFactory::DataTypeFactory() @@ -196,6 +196,7 @@ DataTypeFactory::DataTypeFactory() registerDataTypeInterval(*this); registerDataTypeLowCardinality(*this); registerDataTypeDomainIPv4AndIPv6(*this); + registerDataTypeDomainSimpleAggregateFunction(*this); } DataTypeFactory::~DataTypeFactory() diff --git a/dbms/src/DataTypes/DataTypeFactory.h b/dbms/src/DataTypes/DataTypeFactory.h index c6ef100bbb7..e4a82b342d1 100644 --- a/dbms/src/DataTypes/DataTypeFactory.h +++ b/dbms/src/DataTypes/DataTypeFactory.h @@ -28,6 +28,8 @@ class DataTypeFactory final : public ext::singleton, public IFa private: using SimpleCreator = std::function; using DataTypesDictionary = std::unordered_map; + using CreatorWithDomain = std::function(const ASTPtr & parameters)>; + using SimpleCreatorWithDomain = std::function()>; public: DataTypePtr get(const String & full_name) const; @@ -40,11 +42,13 @@ public: /// Register a simple data type, that have no parameters. void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); - // Register a domain - a refinement of existing type. - void registerDataTypeDomain(const String & type_name, DataTypeDomainPtr domain, CaseSensitiveness case_sensitiveness = CaseSensitive); + /// Register a type family with a dynamic domain + void registerDataTypeDomain(const String & family_name, CreatorWithDomain creator, CaseSensitiveness case_sensitiveness = CaseSensitive); + + /// Register a simple data type domain + void registerDataTypeDomain(const String & name, SimpleCreatorWithDomain creator, CaseSensitiveness case_sensitiveness = CaseSensitive); private: - static void setDataTypeDomain(const IDataType & data_type, const IDataTypeDomain & domain); const Creator& findCreatorByName(const String & family_name) const; private: @@ -53,9 +57,6 @@ private: /// Case insensitive data types will be additionally added here with lowercased name. DataTypesDictionary case_insensitive_data_types; - // All domains are owned by factory and shared amongst DataType instances. - std::vector all_domains; - DataTypeFactory(); ~DataTypeFactory() override; diff --git a/dbms/src/DataTypes/IDataType.cpp b/dbms/src/DataTypes/IDataType.cpp index 679871dba71..0270f1d7923 100644 --- a/dbms/src/DataTypes/IDataType.cpp +++ b/dbms/src/DataTypes/IDataType.cpp @@ -142,9 +142,9 @@ void IDataType::insertDefaultInto(IColumn & column) const void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->serializeTextEscaped(column, row_num, ostr, settings); + ser_domain->serializeTextEscaped(column, row_num, ostr, settings); } else { @@ -154,9 +154,9 @@ void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, W void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->deserializeTextEscaped(column, istr, settings); + ser_domain->deserializeTextEscaped(column, istr, settings); } else { @@ -166,9 +166,9 @@ void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, co void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->serializeTextQuoted(column, row_num, ostr, settings); + ser_domain->serializeTextQuoted(column, row_num, ostr, settings); } else { @@ -178,9 +178,9 @@ void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, Wr void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->deserializeTextQuoted(column, istr, settings); + ser_domain->deserializeTextQuoted(column, istr, settings); } else { @@ -190,9 +190,9 @@ void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, con void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (domain) - { - domain->serializeTextCSV(column, row_num, ostr, settings); + if (auto ser_domain = dynamic_cast(domain.get())) + { + ser_domain->serializeTextCSV(column, row_num, ostr, settings); } else { @@ -202,9 +202,9 @@ void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, Write void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->deserializeTextCSV(column, istr, settings); + ser_domain->deserializeTextCSV(column, istr, settings); } else { @@ -214,9 +214,9 @@ void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->serializeText(column, row_num, ostr, settings); + ser_domain->serializeText(column, row_num, ostr, settings); } else { @@ -226,9 +226,9 @@ void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuf void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->serializeTextJSON(column, row_num, ostr, settings); + ser_domain->serializeTextJSON(column, row_num, ostr, settings); } else { @@ -238,9 +238,9 @@ void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, Writ void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->deserializeTextJSON(column, istr, settings); + ser_domain->deserializeTextJSON(column, istr, settings); } else { @@ -250,9 +250,9 @@ void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->serializeTextXML(column, row_num, ostr, settings); + ser_domain->serializeTextXML(column, row_num, ostr, settings); } else { @@ -260,13 +260,12 @@ void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, Write } } -void IDataType::setDomain(const IDataTypeDomain* const new_domain) const +void IDataType::appendDomain(DataTypeDomainPtr new_domain) const { - if (domain != nullptr) - { - throw Exception("Type " + getName() + " already has a domain.", ErrorCodes::LOGICAL_ERROR); - } - domain = new_domain; + if (domain == nullptr) + domain = std::move(new_domain); + else + domain->appendDomain(std::move(new_domain)); } } diff --git a/dbms/src/DataTypes/IDataType.h b/dbms/src/DataTypes/IDataType.h index aa253fbdc08..a95402bf20a 100644 --- a/dbms/src/DataTypes/IDataType.h +++ b/dbms/src/DataTypes/IDataType.h @@ -13,6 +13,8 @@ class ReadBuffer; class WriteBuffer; class IDataTypeDomain; +using DataTypeDomainPtr = std::unique_ptr; + class IDataType; struct FormatSettings; @@ -459,18 +461,19 @@ public: private: friend class DataTypeFactory; - /** Sets domain on existing DataType, can be considered as second phase + /** Sets domain on existing DataType or append it to existing domain, can be considered as second phase * of construction explicitly done by DataTypeFactory. - * Will throw an exception if domain is already set. */ - void setDomain(const IDataTypeDomain* newDomain) const; + void appendDomain(DataTypeDomainPtr new_domain) const; private: /** This is mutable to allow setting domain on `const IDataType` post construction, * simplifying creation of domains for all types, without them even knowing * of domain existence. */ - mutable IDataTypeDomain const* domain; + mutable DataTypeDomainPtr domain; +public: + const IDataTypeDomain * getDomain() const { return domain.get(); } }; diff --git a/dbms/src/DataTypes/IDataTypeDomain.h b/dbms/src/DataTypes/IDataTypeDomain.h index ad38e88a213..1eed8afd808 100644 --- a/dbms/src/DataTypes/IDataTypeDomain.h +++ b/dbms/src/DataTypes/IDataTypeDomain.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include namespace DB { @@ -10,21 +12,42 @@ class WriteBuffer; struct FormatSettings; class IColumn; -/** Further refinment of the properties of data type. - * - * Contains methods for serialization/deserialization. - * Implementations of this interface represent a data type domain (example: IPv4) - * which is a refinement of the exsitgin type with a name and specific text - * representation. - * - * IDataTypeDomain is totally immutable object. You can always share them. +/** Allow to customize an existing data type and set a different name. Derived class IDataTypeDomainCustomSerialization allows + * further customization of serialization/deserialization methods. See use in IPv4 and IPv6 data type domains. + * + * IDataTypeDomain can be chained for further delegation (only for getName for the moment). */ class IDataTypeDomain { +private: + mutable DataTypeDomainPtr delegate; + public: virtual ~IDataTypeDomain() {} - virtual const char* getName() const = 0; + String getName() const { + if (delegate) + return delegate->getName(); + else + return doGetName(); + } + + void appendDomain(DataTypeDomainPtr delegate_) const { + if (delegate == nullptr) + delegate = std::move(delegate_); + else + delegate->appendDomain(std::move(delegate_)); + } + + const IDataTypeDomain * getDomain() const { return delegate.get(); } + +protected: + virtual String doGetName() const = 0; +}; + +class IDataTypeDomainCustomSerialization : public IDataTypeDomain { +public: + virtual ~IDataTypeDomainCustomSerialization() {} /** Text serialization for displaying on a terminal or saving into a text file, and the like. * Without escaping or quoting. diff --git a/dbms/tests/queries/0_stateless/00915_simple_aggregate_function.reference b/dbms/tests/queries/0_stateless/00915_simple_aggregate_function.reference new file mode 100644 index 00000000000..fbb3d60638e --- /dev/null +++ b/dbms/tests/queries/0_stateless/00915_simple_aggregate_function.reference @@ -0,0 +1,43 @@ +0 0 +1 1 +2 2 +3 3 +4 4 +5 5 +6 6 +7 7 +8 8 +9 9 +0 0 +1 1 +2 2 +3 3 +4 4 +5 5 +6 6 +7 7 +8 8 +9 9 +SimpleAggregateFunction(sum, Float64) +0 0 +1 2 +2 4 +3 6 +4 8 +5 10 +6 12 +7 14 +8 16 +9 18 +0 0 +1 2 +2 4 +3 6 +4 8 +5 10 +6 12 +7 14 +8 16 +9 18 +1 1 2 2.2.2.2 +SimpleAggregateFunction(anyLast, Nullable(String)) SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String))) SimpleAggregateFunction(anyLast, IPv4) diff --git a/dbms/tests/queries/0_stateless/00915_simple_aggregate_function.sql b/dbms/tests/queries/0_stateless/00915_simple_aggregate_function.sql new file mode 100644 index 00000000000..f4f80033eaa --- /dev/null +++ b/dbms/tests/queries/0_stateless/00915_simple_aggregate_function.sql @@ -0,0 +1,27 @@ +-- basic test +drop table if exists test.simple; + +create table test.simple (id UInt64,val SimpleAggregateFunction(sum,Double)) engine=AggregatingMergeTree order by id; +insert into test.simple select number,number from system.numbers limit 10; + +select * from test.simple; +select * from test.simple final; +select toTypeName(val) from test.simple limit 1; + +-- merge +insert into test.simple select number,number from system.numbers limit 10; + +select * from test.simple final; + +optimize table test.simple final; +select * from test.simple; + +-- complex types +drop table if exists test.simple; + +create table test.simple (id UInt64,nullable_str SimpleAggregateFunction(anyLast,Nullable(String)),low_str SimpleAggregateFunction(anyLast,LowCardinality(Nullable(String))),ip SimpleAggregateFunction(anyLast,IPv4)) engine=AggregatingMergeTree order by id; +insert into test.simple values(1,'1','1','1.1.1.1'); +insert into test.simple values(1,null,'2','2.2.2.2'); + +select * from test.simple final; +select toTypeName(nullable_str),toTypeName(low_str),toTypeName(ip) from test.simple limit 1; From faa7d38cb5b8ba1d5fcb97c8442c02eaadc49bd7 Mon Sep 17 00:00:00 2001 From: bgranvea Date: Mon, 11 Mar 2019 09:24:52 +0100 Subject: [PATCH 2/8] fix for style --- .../AggregatingSortedBlockInputStream.cpp | 10 +++++--- .../DataTypeDomainSimpleAggregateFunction.cpp | 24 ++++++++++++------- .../DataTypeDomainSimpleAggregateFunction.h | 3 ++- dbms/src/DataTypes/IDataTypeDomain.h | 9 ++++--- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp b/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp index 34fb19b2688..1be85f7e1b8 100644 --- a/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp +++ b/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp @@ -42,11 +42,14 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream( continue; } - if (auto simple_aggr = findSimpleAggregateFunction(column.type)) { + if (auto simple_aggr = findSimpleAggregateFunction(column.type)) + { // simple aggregate function SimpleAggregateDescription desc{simple_aggr->getFunction(), i}; columns_to_simple_aggregate.emplace_back(std::move(desc)); - } else { + } + else + { // standard aggregate function column_numbers_to_aggregate.push_back(i); } @@ -99,7 +102,8 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s key_differs = next_key != current_key; /// if there are enough rows accumulated and the last one is calculated completely - if (key_differs && merged_rows >= max_block_size) { + if (key_differs && merged_rows >= max_block_size) + { /// Write the simple aggregation result for the previous group. insertSimpleAggregationResult(merged_columns); return; diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp index 402ce86ad62..65bef22ce28 100644 --- a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp +++ b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp @@ -18,7 +18,8 @@ #include -namespace DB { +namespace DB +{ namespace ErrorCodes { @@ -33,13 +34,16 @@ const std::vector supported_functions = std::vector( {"any", "anyLast", "min", "max", "sum"}); -String DataTypeDomainSimpleAggregateFunction::doGetName() const { +String DataTypeDomainSimpleAggregateFunction::doGetName() const +{ std::stringstream stream; stream << "SimpleAggregateFunction(" << function->getName(); - if (!parameters.empty()) { + if (!parameters.empty()) + { stream << "("; - for (size_t i = 0; i < parameters.size(); ++i) { + for (size_t i = 0; i < parameters.size(); ++i) + { if (i) stream << ", "; stream << applyVisitor(DB::FieldVisitorToString(), parameters[i]); @@ -107,7 +111,8 @@ static std::pair create(const ASTPtr & arguments function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row); // check function - if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions)) { + if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions)) + { throw Exception("Unsupported aggregate function " + function->getName() + ", supported functions are " + boost::algorithm::join(supported_functions, ","), ErrorCodes::BAD_ARGUMENTS); } @@ -115,7 +120,8 @@ static std::pair create(const ASTPtr & arguments DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName()); DataTypeDomainPtr domain = std::make_unique(storage_type, function, argument_types, params_row); - if (!function->getReturnType()->equals(*removeLowCardinality(storage_type))) { + if (!function->getReturnType()->equals(*removeLowCardinality(storage_type))) + { throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getReturnType()->getName() + " and column storage type " + storage_type->getName(), ErrorCodes::BAD_ARGUMENTS); } @@ -123,7 +129,8 @@ static std::pair create(const ASTPtr & arguments return std::make_pair(storage_type, std::move(domain)); } -static const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(const IDataTypeDomain * domain) { +static const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(const IDataTypeDomain * domain) +{ if (domain == nullptr) return nullptr; @@ -136,7 +143,8 @@ static const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction return nullptr; } -const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(DataTypePtr dataType) { +const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(DataTypePtr dataType) +{ return findSimpleAggregateFunction(dataType->getDomain()); } diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h index 70e94b1a652..6573f1ae5d0 100644 --- a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h +++ b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h @@ -24,7 +24,8 @@ namespace DB * Technically, a standard IDataType is instanciated and a DataTypeDomainSimpleAggregateFunction is added as domain. */ -class DataTypeDomainSimpleAggregateFunction : public IDataTypeDomain { +class DataTypeDomainSimpleAggregateFunction : public IDataTypeDomain +{ private: const DataTypePtr storage_type; const AggregateFunctionPtr function; diff --git a/dbms/src/DataTypes/IDataTypeDomain.h b/dbms/src/DataTypes/IDataTypeDomain.h index 1eed8afd808..a840964d28a 100644 --- a/dbms/src/DataTypes/IDataTypeDomain.h +++ b/dbms/src/DataTypes/IDataTypeDomain.h @@ -25,14 +25,16 @@ private: public: virtual ~IDataTypeDomain() {} - String getName() const { + String getName() const + { if (delegate) return delegate->getName(); else return doGetName(); } - void appendDomain(DataTypeDomainPtr delegate_) const { + void appendDomain(DataTypeDomainPtr delegate_) const + { if (delegate == nullptr) delegate = std::move(delegate_); else @@ -45,7 +47,8 @@ protected: virtual String doGetName() const = 0; }; -class IDataTypeDomainCustomSerialization : public IDataTypeDomain { +class IDataTypeDomainCustomSerialization : public IDataTypeDomain +{ public: virtual ~IDataTypeDomainCustomSerialization() {} From ee5a88c15f979dd8782a2a82ce1ad5f434f57ad3 Mon Sep 17 00:00:00 2001 From: bgranvea Date: Mon, 11 Mar 2019 11:46:14 +0100 Subject: [PATCH 3/8] fix memory leak --- dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp | 2 +- dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp index 65bef22ce28..5daa886df43 100644 --- a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp +++ b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp @@ -118,7 +118,7 @@ static std::pair create(const ASTPtr & arguments } DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName()); - DataTypeDomainPtr domain = std::make_unique(storage_type, function, argument_types, params_row); + DataTypeDomainPtr domain = std::make_unique(function, argument_types, params_row); if (!function->getReturnType()->equals(*removeLowCardinality(storage_type))) { diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h index 6573f1ae5d0..98989eeac11 100644 --- a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h +++ b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h @@ -27,14 +27,13 @@ namespace DB class DataTypeDomainSimpleAggregateFunction : public IDataTypeDomain { private: - const DataTypePtr storage_type; const AggregateFunctionPtr function; const DataTypes argument_types; const Array parameters; public: - DataTypeDomainSimpleAggregateFunction(const DataTypePtr storage_type_, const AggregateFunctionPtr & function_, const DataTypes & argument_types_, const Array & parameters_) - : storage_type(storage_type_), function(function_), argument_types(argument_types_), parameters(parameters_) {} + DataTypeDomainSimpleAggregateFunction(const AggregateFunctionPtr & function_, const DataTypes & argument_types_, const Array & parameters_) + : function(function_), argument_types(argument_types_), parameters(parameters_) {} const AggregateFunctionPtr getFunction() const { return function; } String doGetName() const override; From a3020f2d22d7bd8cf180afd2ae79ab250c2f9ccf Mon Sep 17 00:00:00 2001 From: alexey-milovidov Date: Tue, 26 Mar 2019 00:51:54 +0300 Subject: [PATCH 4/8] Update DataTypeDomainSimpleAggregateFunction.cpp --- dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp index 5daa886df43..570310c1312 100644 --- a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp +++ b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp @@ -51,7 +51,7 @@ String DataTypeDomainSimpleAggregateFunction::doGetName() const stream << ")"; } - for (const auto &argument_type : argument_types) + for (const auto & argument_type : argument_types) stream << ", " << argument_type->getName(); stream << ")"; From c1ea15f0bb04c5b722947d6c688e3a0871d07493 Mon Sep 17 00:00:00 2001 From: alexey-milovidov Date: Tue, 26 Mar 2019 00:54:19 +0300 Subject: [PATCH 5/8] Update DataTypeDomainSimpleAggregateFunction.cpp --- .../DataTypes/DataTypeDomainSimpleAggregateFunction.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp index 570310c1312..ee524f76ec9 100644 --- a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp +++ b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp @@ -70,18 +70,18 @@ static std::pair create(const ASTPtr & arguments throw Exception("Data type SimpleAggregateFunction requires parameters: " "name of aggregate function and list of data types for arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - if (const ASTFunction * parametric = typeid_cast(arguments->children[0].get())) + if (const ASTFunction * parametric = arguments->children[0]->as()) { if (parametric->parameters) throw Exception("Unexpected level of parameters to aggregate function", ErrorCodes::SYNTAX_ERROR); function_name = parametric->name; - const ASTs & parameters = typeid_cast(*parametric->arguments).children; + const ASTs & parameters = parametric->arguments->as().children; params_row.resize(parameters.size()); for (size_t i = 0; i < parameters.size(); ++i) { - const ASTLiteral * lit = typeid_cast(parameters[i].get()); + const ASTLiteral * lit = parameters[i]->as(); if (!lit) throw Exception("Parameters to aggregate functions must be literals", ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS); @@ -93,7 +93,7 @@ static std::pair create(const ASTPtr & arguments { function_name = *opt_name; } - else if (typeid_cast(arguments->children[0].get())) + else if (arguments->children[0]->as()) { throw Exception("Aggregate function name for data type SimpleAggregateFunction must be passed as identifier (without quotes) or function", ErrorCodes::BAD_ARGUMENTS); From e4b93f092b6328a097bda6fba9a0d9805d14d0a7 Mon Sep 17 00:00:00 2001 From: alexey-milovidov Date: Tue, 26 Mar 2019 00:57:34 +0300 Subject: [PATCH 6/8] Update DataTypeDomainSimpleAggregateFunction.cpp --- dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp index ee524f76ec9..dfcd7604aeb 100644 --- a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp +++ b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp @@ -30,8 +30,8 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; } -const std::vector supported_functions = std::vector( - {"any", "anyLast", "min", "max", "sum"}); +static const std::initializer_list supported_functions = std::vector( + {"any", "anyLast", "min", "max", "sum"}); String DataTypeDomainSimpleAggregateFunction::doGetName() const From caa096a3d0e7340d5cf0d4cecf33f552c06bbd9d Mon Sep 17 00:00:00 2001 From: alexey-milovidov Date: Tue, 26 Mar 2019 00:57:58 +0300 Subject: [PATCH 7/8] Update DataTypeDomainSimpleAggregateFunction.cpp --- dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp index dfcd7604aeb..82e3b873f5e 100644 --- a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp +++ b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp @@ -30,8 +30,7 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; } -static const std::initializer_list supported_functions = std::vector( - {"any", "anyLast", "min", "max", "sum"}); +static const std::vector supported_functions{"any", "anyLast", "min", "max", "sum"}; String DataTypeDomainSimpleAggregateFunction::doGetName() const From 42b07c5ee980d51b9780aadd30821a174ec9d990 Mon Sep 17 00:00:00 2001 From: bgranvea Date: Fri, 29 Mar 2019 21:04:04 +0100 Subject: [PATCH 8/8] refactor to avoid dynamic_cast in data type serialization --- .../AggregatingSortedBlockInputStream.cpp | 6 +- .../{IDataTypeDomain.h => DataTypeCustom.h} | 67 ++++++++++--------- ...IPv6.cpp => DataTypeCustomIPv4AndIPv6.cpp} | 40 +++++------ ...DataTypeCustomSimpleAggregateFunction.cpp} | 33 ++------- ...> DataTypeCustomSimpleAggregateFunction.h} | 13 ++-- ...DataTypeCustomSimpleTextSerialization.cpp} | 26 +++---- ...> DataTypeCustomSimpleTextSerialization.h} | 8 +-- dbms/src/DataTypes/DataTypeFactory.cpp | 10 +-- dbms/src/DataTypes/DataTypeFactory.h | 15 ++--- dbms/src/DataTypes/IDataType.cpp | 63 ++++++++--------- dbms/src/DataTypes/IDataType.h | 20 +++--- 11 files changed, 139 insertions(+), 162 deletions(-) rename dbms/src/DataTypes/{IDataTypeDomain.h => DataTypeCustom.h} (62%) rename dbms/src/DataTypes/{DataTypeDomainIPv4AndIPv6.cpp => DataTypeCustomIPv4AndIPv6.cpp} (62%) rename dbms/src/DataTypes/{DataTypeDomainSimpleAggregateFunction.cpp => DataTypeCustomSimpleAggregateFunction.cpp} (81%) rename dbms/src/DataTypes/{DataTypeDomainSimpleAggregateFunction.h => DataTypeCustomSimpleAggregateFunction.h} (67%) rename dbms/src/DataTypes/{DataTypeDomainWithSimpleSerialization.cpp => DataTypeCustomSimpleTextSerialization.cpp} (73%) rename dbms/src/DataTypes/{DataTypeDomainWithSimpleSerialization.h => DataTypeCustomSimpleTextSerialization.h} (89%) diff --git a/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp b/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp index 1be85f7e1b8..f093e47e640 100644 --- a/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp +++ b/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp @@ -2,7 +2,7 @@ #include #include #include -#include +#include namespace DB @@ -24,7 +24,7 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream( ColumnWithTypeAndName & column = header.safeGetByPosition(i); /// We leave only states of aggregate functions. - if (!dynamic_cast(column.type.get()) && !findSimpleAggregateFunction(column.type)) + if (!dynamic_cast(column.type.get()) && !dynamic_cast(column.type->getCustomName())) { column_numbers_not_to_aggregate.push_back(i); continue; @@ -42,7 +42,7 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream( continue; } - if (auto simple_aggr = findSimpleAggregateFunction(column.type)) + if (auto simple_aggr = dynamic_cast(column.type->getCustomName())) { // simple aggregate function SimpleAggregateDescription desc{simple_aggr->getFunction(), i}; diff --git a/dbms/src/DataTypes/IDataTypeDomain.h b/dbms/src/DataTypes/DataTypeCustom.h similarity index 62% rename from dbms/src/DataTypes/IDataTypeDomain.h rename to dbms/src/DataTypes/DataTypeCustom.h index a840964d28a..93882361e20 100644 --- a/dbms/src/DataTypes/IDataTypeDomain.h +++ b/dbms/src/DataTypes/DataTypeCustom.h @@ -1,8 +1,8 @@ #pragma once +#include #include #include -#include namespace DB { @@ -12,45 +12,21 @@ class WriteBuffer; struct FormatSettings; class IColumn; -/** Allow to customize an existing data type and set a different name. Derived class IDataTypeDomainCustomSerialization allows - * further customization of serialization/deserialization methods. See use in IPv4 and IPv6 data type domains. - * - * IDataTypeDomain can be chained for further delegation (only for getName for the moment). +/** Allow to customize an existing data type and set a different name and/or text serialization/deserialization methods. + * See use in IPv4 and IPv6 data types, and also in SimpleAggregateFunction. */ -class IDataTypeDomain +class IDataTypeCustomName { -private: - mutable DataTypeDomainPtr delegate; - public: - virtual ~IDataTypeDomain() {} + virtual ~IDataTypeCustomName() {} - String getName() const - { - if (delegate) - return delegate->getName(); - else - return doGetName(); - } - - void appendDomain(DataTypeDomainPtr delegate_) const - { - if (delegate == nullptr) - delegate = std::move(delegate_); - else - delegate->appendDomain(std::move(delegate_)); - } - - const IDataTypeDomain * getDomain() const { return delegate.get(); } - -protected: - virtual String doGetName() const = 0; + virtual String getName() const = 0; }; -class IDataTypeDomainCustomSerialization : public IDataTypeDomain +class IDataTypeCustomTextSerialization { public: - virtual ~IDataTypeDomainCustomSerialization() {} + virtual ~IDataTypeCustomTextSerialization() {} /** Text serialization for displaying on a terminal or saving into a text file, and the like. * Without escaping or quoting. @@ -82,4 +58,31 @@ public: virtual void serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const = 0; }; +using DataTypeCustomNamePtr = std::unique_ptr; +using DataTypeCustomTextSerializationPtr = std::unique_ptr; + +/** Describe a data type customization + */ +struct DataTypeCustomDesc +{ + DataTypeCustomNamePtr name; + DataTypeCustomTextSerializationPtr text_serialization; + + DataTypeCustomDesc(DataTypeCustomNamePtr name_, DataTypeCustomTextSerializationPtr text_serialization_) + : name(std::move(name_)), text_serialization(std::move(text_serialization_)) {} +}; + +using DataTypeCustomDescPtr = std::unique_ptr; + +/** A simple implementation of IDataTypeCustomName + */ +class DataTypeCustomFixedName : public IDataTypeCustomName +{ +private: + String name; +public: + DataTypeCustomFixedName(String name_) : name(name_) {} + String getName() const override { return name; } +}; + } // namespace DB diff --git a/dbms/src/DataTypes/DataTypeDomainIPv4AndIPv6.cpp b/dbms/src/DataTypes/DataTypeCustomIPv4AndIPv6.cpp similarity index 62% rename from dbms/src/DataTypes/DataTypeDomainIPv4AndIPv6.cpp rename to dbms/src/DataTypes/DataTypeCustomIPv4AndIPv6.cpp index f57a6167d3d..8d12a9847db 100644 --- a/dbms/src/DataTypes/DataTypeDomainIPv4AndIPv6.cpp +++ b/dbms/src/DataTypes/DataTypeCustomIPv4AndIPv6.cpp @@ -1,9 +1,9 @@ #include #include #include -#include +#include #include -#include +#include #include #include @@ -20,20 +20,15 @@ namespace ErrorCodes namespace { -class DataTypeDomainIPv4 : public DataTypeDomainWithSimpleSerialization +class DataTypeCustomIPv4Serialization : public DataTypeCustomSimpleTextSerialization { public: - String doGetName() const override - { - return "IPv4"; - } - void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override { const auto col = checkAndGetColumn(&column); if (!col) { - throw Exception(getName() + " domain can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception("IPv4 type can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); } char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'}; @@ -48,7 +43,7 @@ public: ColumnUInt32 * col = typeid_cast(&column); if (!col) { - throw Exception(getName() + " domain can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception("IPv4 type can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); } char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'}; @@ -63,20 +58,16 @@ public: } }; -class DataTypeDomainIPv6 : public DataTypeDomainWithSimpleSerialization +class DataTypeCustomIPv6Serialization : public DataTypeCustomSimpleTextSerialization { public: - String doGetName() const override - { - return "IPv6"; - } void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override { const auto col = checkAndGetColumn(&column); if (!col) { - throw Exception(getName() + " domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception("IPv6 type domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); } char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'}; @@ -91,7 +82,7 @@ public: ColumnFixedString * col = typeid_cast(&column); if (!col) { - throw Exception(getName() + " domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception("IPv6 type domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); } char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'}; @@ -100,7 +91,7 @@ public: std::string ipv6_value(IPV6_BINARY_LENGTH, '\0'); if (!parseIPv6(buffer, reinterpret_cast(ipv6_value.data()))) { - throw Exception("Invalid " + getName() + " value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING); + throw Exception("Invalid IPv6 value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING); } col->insertString(ipv6_value); @@ -111,8 +102,17 @@ public: void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory) { - factory.registerDataTypeDomain("IPv4", [] { return std::make_pair(DataTypeFactory::instance().get("UInt32"), std::make_unique()); }); - factory.registerDataTypeDomain("IPv6", [] { return std::make_pair(DataTypeFactory::instance().get("FixedString(16)"), std::make_unique()); }); + factory.registerSimpleDataTypeCustom("IPv4", [] + { + return std::make_pair(DataTypeFactory::instance().get("UInt32"), + std::make_unique(std::make_unique("IPv4"), std::make_unique())); + }); + + factory.registerSimpleDataTypeCustom("IPv6", [] + { + return std::make_pair(DataTypeFactory::instance().get("FixedString(16)"), + std::make_unique(std::make_unique("IPv6"), std::make_unique())); + }); } } // namespace DB diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp b/dbms/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp similarity index 81% rename from dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp rename to dbms/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp index 82e3b873f5e..2cb0f87facd 100644 --- a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp +++ b/dbms/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp @@ -5,7 +5,7 @@ #include -#include +#include #include #include #include @@ -33,7 +33,7 @@ namespace ErrorCodes static const std::vector supported_functions{"any", "anyLast", "min", "max", "sum"}; -String DataTypeDomainSimpleAggregateFunction::doGetName() const +String DataTypeCustomSimpleAggregateFunction::getName() const { std::stringstream stream; stream << "SimpleAggregateFunction(" << function->getName(); @@ -58,7 +58,7 @@ String DataTypeDomainSimpleAggregateFunction::doGetName() const } -static std::pair create(const ASTPtr & arguments) +static std::pair create(const ASTPtr & arguments) { String function_name; AggregateFunctionPtr function; @@ -117,7 +117,6 @@ static std::pair create(const ASTPtr & arguments } DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName()); - DataTypeDomainPtr domain = std::make_unique(function, argument_types, params_row); if (!function->getReturnType()->equals(*removeLowCardinality(storage_type))) { @@ -125,32 +124,14 @@ static std::pair create(const ASTPtr & arguments ErrorCodes::BAD_ARGUMENTS); } - return std::make_pair(storage_type, std::move(domain)); + DataTypeCustomNamePtr custom_name = std::make_unique(function, argument_types, params_row); + + return std::make_pair(storage_type, std::make_unique(std::move(custom_name), nullptr)); } -static const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(const IDataTypeDomain * domain) -{ - if (domain == nullptr) - return nullptr; - - if (auto simple_aggr = dynamic_cast(domain)) - return simple_aggr; - - if (domain->getDomain() != nullptr) - return findSimpleAggregateFunction(domain->getDomain()); - - return nullptr; -} - -const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(DataTypePtr dataType) -{ - return findSimpleAggregateFunction(dataType->getDomain()); -} - - void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory) { - factory.registerDataTypeDomain("SimpleAggregateFunction", create); + factory.registerDataTypeCustom("SimpleAggregateFunction", create); } } diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h b/dbms/src/DataTypes/DataTypeCustomSimpleAggregateFunction.h similarity index 67% rename from dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h rename to dbms/src/DataTypes/DataTypeCustomSimpleAggregateFunction.h index 98989eeac11..3e82b546903 100644 --- a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h +++ b/dbms/src/DataTypes/DataTypeCustomSimpleAggregateFunction.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -21,10 +21,10 @@ namespace DB * SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String))) * SimpleAggregateFunction(anyLast, IPv4) * - * Technically, a standard IDataType is instanciated and a DataTypeDomainSimpleAggregateFunction is added as domain. + * Technically, a standard IDataType is instanciated and customized with IDataTypeCustomName and DataTypeCustomDesc. */ -class DataTypeDomainSimpleAggregateFunction : public IDataTypeDomain +class DataTypeCustomSimpleAggregateFunction : public IDataTypeCustomName { private: const AggregateFunctionPtr function; @@ -32,14 +32,11 @@ private: const Array parameters; public: - DataTypeDomainSimpleAggregateFunction(const AggregateFunctionPtr & function_, const DataTypes & argument_types_, const Array & parameters_) + DataTypeCustomSimpleAggregateFunction(const AggregateFunctionPtr & function_, const DataTypes & argument_types_, const Array & parameters_) : function(function_), argument_types(argument_types_), parameters(parameters_) {} const AggregateFunctionPtr getFunction() const { return function; } - String doGetName() const override; + String getName() const override; }; -/// recursively follow data type domain to find a DataTypeDomainSimpleAggregateFunction -const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(DataTypePtr dataType); - } diff --git a/dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.cpp b/dbms/src/DataTypes/DataTypeCustomSimpleTextSerialization.cpp similarity index 73% rename from dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.cpp rename to dbms/src/DataTypes/DataTypeCustomSimpleTextSerialization.cpp index 12b1837be1f..44ce27a6e88 100644 --- a/dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.cpp +++ b/dbms/src/DataTypes/DataTypeCustomSimpleTextSerialization.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -9,7 +9,7 @@ namespace { using namespace DB; -static String serializeToString(const DataTypeDomainWithSimpleSerialization & domain, const IColumn & column, size_t row_num, const FormatSettings & settings) +static String serializeToString(const DataTypeCustomSimpleTextSerialization & domain, const IColumn & column, size_t row_num, const FormatSettings & settings) { WriteBufferFromOwnString buffer; domain.serializeText(column, row_num, buffer, settings); @@ -17,7 +17,7 @@ static String serializeToString(const DataTypeDomainWithSimpleSerialization & do return buffer.str(); } -static void deserializeFromString(const DataTypeDomainWithSimpleSerialization & domain, IColumn & column, const String & s, const FormatSettings & settings) +static void deserializeFromString(const DataTypeCustomSimpleTextSerialization & domain, IColumn & column, const String & s, const FormatSettings & settings) { ReadBufferFromString istr(s); domain.deserializeText(column, istr, settings); @@ -28,59 +28,59 @@ static void deserializeFromString(const DataTypeDomainWithSimpleSerialization & namespace DB { -DataTypeDomainWithSimpleSerialization::~DataTypeDomainWithSimpleSerialization() +DataTypeCustomSimpleTextSerialization::~DataTypeCustomSimpleTextSerialization() { } -void DataTypeDomainWithSimpleSerialization::serializeTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const +void DataTypeCustomSimpleTextSerialization::serializeTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { writeEscapedString(serializeToString(*this, column, row_num, settings), ostr); } -void DataTypeDomainWithSimpleSerialization::deserializeTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const +void DataTypeCustomSimpleTextSerialization::deserializeTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { String str; readEscapedString(str, istr); deserializeFromString(*this, column, str, settings); } -void DataTypeDomainWithSimpleSerialization::serializeTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const +void DataTypeCustomSimpleTextSerialization::serializeTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { writeQuotedString(serializeToString(*this, column, row_num, settings), ostr); } -void DataTypeDomainWithSimpleSerialization::deserializeTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const +void DataTypeCustomSimpleTextSerialization::deserializeTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { String str; readQuotedString(str, istr); deserializeFromString(*this, column, str, settings); } -void DataTypeDomainWithSimpleSerialization::serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const +void DataTypeCustomSimpleTextSerialization::serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { writeCSVString(serializeToString(*this, column, row_num, settings), ostr); } -void DataTypeDomainWithSimpleSerialization::deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const +void DataTypeCustomSimpleTextSerialization::deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { String str; readCSVString(str, istr, settings.csv); deserializeFromString(*this, column, str, settings); } -void DataTypeDomainWithSimpleSerialization::serializeTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const +void DataTypeCustomSimpleTextSerialization::serializeTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { writeJSONString(serializeToString(*this, column, row_num, settings), ostr, settings); } -void DataTypeDomainWithSimpleSerialization::deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const +void DataTypeCustomSimpleTextSerialization::deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { String str; readJSONString(str, istr); deserializeFromString(*this, column, str, settings); } -void DataTypeDomainWithSimpleSerialization::serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const +void DataTypeCustomSimpleTextSerialization::serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { writeXMLString(serializeToString(*this, column, row_num, settings), ostr); } diff --git a/dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.h b/dbms/src/DataTypes/DataTypeCustomSimpleTextSerialization.h similarity index 89% rename from dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.h rename to dbms/src/DataTypes/DataTypeCustomSimpleTextSerialization.h index 3ccb4091636..fb9be86d95f 100644 --- a/dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.h +++ b/dbms/src/DataTypes/DataTypeCustomSimpleTextSerialization.h @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace DB { @@ -10,12 +10,12 @@ class WriteBuffer; struct FormatSettings; class IColumn; -/** Simple DataTypeDomain that uses serializeText/deserializeText +/** Simple IDataTypeCustomTextSerialization that uses serializeText/deserializeText * for all serialization and deserialization. */ -class DataTypeDomainWithSimpleSerialization : public IDataTypeDomainCustomSerialization +class DataTypeCustomSimpleTextSerialization : public IDataTypeCustomTextSerialization { public: - virtual ~DataTypeDomainWithSimpleSerialization() override; + virtual ~DataTypeCustomSimpleTextSerialization() override; // Methods that subclasses must override in order to get full serialization/deserialization support. virtual void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override = 0; diff --git a/dbms/src/DataTypes/DataTypeFactory.cpp b/dbms/src/DataTypes/DataTypeFactory.cpp index a405075e884..8c4c899516a 100644 --- a/dbms/src/DataTypes/DataTypeFactory.cpp +++ b/dbms/src/DataTypes/DataTypeFactory.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include @@ -115,20 +115,20 @@ void DataTypeFactory::registerSimpleDataType(const String & name, SimpleCreator }, case_sensitiveness); } -void DataTypeFactory::registerDataTypeDomain(const String & family_name, CreatorWithDomain creator, CaseSensitiveness case_sensitiveness) +void DataTypeFactory::registerDataTypeCustom(const String & family_name, CreatorWithCustom creator, CaseSensitiveness case_sensitiveness) { registerDataType(family_name, [creator](const ASTPtr & ast) { auto res = creator(ast); - res.first->appendDomain(std::move(res.second)); + res.first->setCustomization(std::move(res.second)); return res.first; }, case_sensitiveness); } -void DataTypeFactory::registerDataTypeDomain(const String & name, SimpleCreatorWithDomain creator, CaseSensitiveness case_sensitiveness) +void DataTypeFactory::registerSimpleDataTypeCustom(const String &name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness) { - registerDataTypeDomain(name, [creator](const ASTPtr & /*ast*/) + registerDataTypeCustom(name, [creator](const ASTPtr & /*ast*/) { return creator(); }, case_sensitiveness); diff --git a/dbms/src/DataTypes/DataTypeFactory.h b/dbms/src/DataTypes/DataTypeFactory.h index e4a82b342d1..a6c714c1a0e 100644 --- a/dbms/src/DataTypes/DataTypeFactory.h +++ b/dbms/src/DataTypes/DataTypeFactory.h @@ -17,9 +17,6 @@ namespace DB class IDataType; using DataTypePtr = std::shared_ptr; -class IDataTypeDomain; -using DataTypeDomainPtr = std::unique_ptr; - /** Creates a data type by name of data type family and parameters. */ @@ -28,8 +25,8 @@ class DataTypeFactory final : public ext::singleton, public IFa private: using SimpleCreator = std::function; using DataTypesDictionary = std::unordered_map; - using CreatorWithDomain = std::function(const ASTPtr & parameters)>; - using SimpleCreatorWithDomain = std::function()>; + using CreatorWithCustom = std::function(const ASTPtr & parameters)>; + using SimpleCreatorWithCustom = std::function()>; public: DataTypePtr get(const String & full_name) const; @@ -42,11 +39,11 @@ public: /// Register a simple data type, that have no parameters. void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); - /// Register a type family with a dynamic domain - void registerDataTypeDomain(const String & family_name, CreatorWithDomain creator, CaseSensitiveness case_sensitiveness = CaseSensitive); + /// Register a customized type family + void registerDataTypeCustom(const String & family_name, CreatorWithCustom creator, CaseSensitiveness case_sensitiveness = CaseSensitive); - /// Register a simple data type domain - void registerDataTypeDomain(const String & name, SimpleCreatorWithDomain creator, CaseSensitiveness case_sensitiveness = CaseSensitive); + /// Register a simple customized data type + void registerSimpleDataTypeCustom(const String & name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness = CaseSensitive); private: const Creator& findCreatorByName(const String & family_name) const; diff --git a/dbms/src/DataTypes/IDataType.cpp b/dbms/src/DataTypes/IDataType.cpp index 0270f1d7923..09c080f56cc 100644 --- a/dbms/src/DataTypes/IDataType.cpp +++ b/dbms/src/DataTypes/IDataType.cpp @@ -9,7 +9,7 @@ #include #include -#include +#include #include @@ -23,8 +23,7 @@ namespace ErrorCodes extern const int DATA_TYPE_CANNOT_BE_PROMOTED; } -IDataType::IDataType() - : domain(nullptr) +IDataType::IDataType() : custom_name(nullptr), custom_text_serialization(nullptr) { } @@ -34,9 +33,9 @@ IDataType::~IDataType() String IDataType::getName() const { - if (domain) + if (custom_name) { - return domain->getName(); + return custom_name->getName(); } else { @@ -142,9 +141,9 @@ void IDataType::insertDefaultInto(IColumn & column) const void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (auto ser_domain = dynamic_cast(domain.get())) + if (custom_text_serialization) { - ser_domain->serializeTextEscaped(column, row_num, ostr, settings); + custom_text_serialization->serializeTextEscaped(column, row_num, ostr, settings); } else { @@ -154,9 +153,9 @@ void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, W void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - if (auto ser_domain = dynamic_cast(domain.get())) + if (custom_text_serialization) { - ser_domain->deserializeTextEscaped(column, istr, settings); + custom_text_serialization->deserializeTextEscaped(column, istr, settings); } else { @@ -166,9 +165,9 @@ void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, co void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (auto ser_domain = dynamic_cast(domain.get())) + if (custom_text_serialization) { - ser_domain->serializeTextQuoted(column, row_num, ostr, settings); + custom_text_serialization->serializeTextQuoted(column, row_num, ostr, settings); } else { @@ -178,9 +177,9 @@ void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, Wr void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - if (auto ser_domain = dynamic_cast(domain.get())) + if (custom_text_serialization) { - ser_domain->deserializeTextQuoted(column, istr, settings); + custom_text_serialization->deserializeTextQuoted(column, istr, settings); } else { @@ -190,9 +189,9 @@ void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, con void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (auto ser_domain = dynamic_cast(domain.get())) - { - ser_domain->serializeTextCSV(column, row_num, ostr, settings); + if (custom_text_serialization) + { + custom_text_serialization->serializeTextCSV(column, row_num, ostr, settings); } else { @@ -202,9 +201,9 @@ void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, Write void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - if (auto ser_domain = dynamic_cast(domain.get())) + if (custom_text_serialization) { - ser_domain->deserializeTextCSV(column, istr, settings); + custom_text_serialization->deserializeTextCSV(column, istr, settings); } else { @@ -214,9 +213,9 @@ void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (auto ser_domain = dynamic_cast(domain.get())) + if (custom_text_serialization) { - ser_domain->serializeText(column, row_num, ostr, settings); + custom_text_serialization->serializeText(column, row_num, ostr, settings); } else { @@ -226,9 +225,9 @@ void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuf void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (auto ser_domain = dynamic_cast(domain.get())) + if (custom_text_serialization) { - ser_domain->serializeTextJSON(column, row_num, ostr, settings); + custom_text_serialization->serializeTextJSON(column, row_num, ostr, settings); } else { @@ -238,9 +237,9 @@ void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, Writ void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - if (auto ser_domain = dynamic_cast(domain.get())) + if (custom_text_serialization) { - ser_domain->deserializeTextJSON(column, istr, settings); + custom_text_serialization->deserializeTextJSON(column, istr, settings); } else { @@ -250,9 +249,9 @@ void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (auto ser_domain = dynamic_cast(domain.get())) + if (custom_text_serialization) { - ser_domain->serializeTextXML(column, row_num, ostr, settings); + custom_text_serialization->serializeTextXML(column, row_num, ostr, settings); } else { @@ -260,12 +259,14 @@ void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, Write } } -void IDataType::appendDomain(DataTypeDomainPtr new_domain) const +void IDataType::setCustomization(DataTypeCustomDescPtr custom_desc_) const { - if (domain == nullptr) - domain = std::move(new_domain); - else - domain->appendDomain(std::move(new_domain)); + /// replace only if not null + if (custom_desc_->name) + custom_name = std::move(custom_desc_->name); + + if (custom_desc_->text_serialization) + custom_text_serialization = std::move(custom_desc_->text_serialization); } } diff --git a/dbms/src/DataTypes/IDataType.h b/dbms/src/DataTypes/IDataType.h index a95402bf20a..6446f8ada43 100644 --- a/dbms/src/DataTypes/IDataType.h +++ b/dbms/src/DataTypes/IDataType.h @@ -4,6 +4,7 @@ #include #include #include +#include namespace DB @@ -12,9 +13,6 @@ namespace DB class ReadBuffer; class WriteBuffer; -class IDataTypeDomain; -using DataTypeDomainPtr = std::unique_ptr; - class IDataType; struct FormatSettings; @@ -461,19 +459,19 @@ public: private: friend class DataTypeFactory; - /** Sets domain on existing DataType or append it to existing domain, can be considered as second phase - * of construction explicitly done by DataTypeFactory. + /** Customize this DataType */ - void appendDomain(DataTypeDomainPtr new_domain) const; + void setCustomization(DataTypeCustomDescPtr custom_desc_) const; private: - /** This is mutable to allow setting domain on `const IDataType` post construction, - * simplifying creation of domains for all types, without them even knowing - * of domain existence. + /** This is mutable to allow setting custom name and serialization on `const IDataType` post construction. */ - mutable DataTypeDomainPtr domain; + mutable DataTypeCustomNamePtr custom_name; + mutable DataTypeCustomTextSerializationPtr custom_text_serialization; + public: - const IDataTypeDomain * getDomain() const { return domain.get(); } + const IDataTypeCustomName * getCustomName() const { return custom_name.get(); } + const IDataTypeCustomTextSerialization * getCustomTextSerialization() const { return custom_text_serialization.get(); } };