From 9d9d16e1ea74652557d29513469dfe9c8bee131b Mon Sep 17 00:00:00 2001 From: bgranvea Date: Fri, 8 Mar 2019 17:49:10 +0100 Subject: [PATCH] support for SimpleAggregateFunction data type --- .../AggregatingSortedBlockInputStream.cpp | 44 +++++- .../AggregatingSortedBlockInputStream.h | 51 ++++++ .../DataTypes/DataTypeDomainIPv4AndIPv6.cpp | 18 +-- .../DataTypeDomainSimpleAggregateFunction.cpp | 149 ++++++++++++++++++ .../DataTypeDomainSimpleAggregateFunction.h | 45 ++++++ .../DataTypeDomainWithSimpleSerialization.h | 2 +- dbms/src/DataTypes/DataTypeFactory.cpp | 31 ++-- dbms/src/DataTypes/DataTypeFactory.h | 13 +- dbms/src/DataTypes/IDataType.cpp | 53 +++---- dbms/src/DataTypes/IDataType.h | 11 +- dbms/src/DataTypes/IDataTypeDomain.h | 41 +++-- .../00915_simple_aggregate_function.reference | 43 +++++ .../00915_simple_aggregate_function.sql | 27 ++++ 13 files changed, 454 insertions(+), 74 deletions(-) create mode 100644 dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp create mode 100644 dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h create mode 100644 dbms/tests/queries/0_stateless/00915_simple_aggregate_function.reference create mode 100644 dbms/tests/queries/0_stateless/00915_simple_aggregate_function.sql diff --git a/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp b/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp index 0697ec8167c..34fb19b2688 100644 --- a/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp +++ b/dbms/src/DataStreams/AggregatingSortedBlockInputStream.cpp @@ -1,6 +1,8 @@ #include #include #include +#include +#include namespace DB @@ -22,7 +24,7 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream( ColumnWithTypeAndName & column = header.safeGetByPosition(i); /// We leave only states of aggregate functions. 
- if (!startsWith(column.type->getName(), "AggregateFunction")) + if (!dynamic_cast(column.type.get()) && !findSimpleAggregateFunction(column.type)) { column_numbers_not_to_aggregate.push_back(i); continue; @@ -40,7 +42,14 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream( continue; } - column_numbers_to_aggregate.push_back(i); + if (auto simple_aggr = findSimpleAggregateFunction(column.type)) { + // simple aggregate function + SimpleAggregateDescription desc{simple_aggr->getFunction(), i}; + columns_to_simple_aggregate.emplace_back(std::move(desc)); + } else { + // standard aggregate function + column_numbers_to_aggregate.push_back(i); + } } } @@ -90,8 +99,11 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s key_differs = next_key != current_key; /// if there are enough rows accumulated and the last one is calculated completely - if (key_differs && merged_rows >= max_block_size) + if (key_differs && merged_rows >= max_block_size) { + /// Write the simple aggregation result for the previous group. + insertSimpleAggregationResult(merged_columns); return; + } queue.pop(); @@ -110,6 +122,14 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s for (auto & column_to_aggregate : columns_to_aggregate) column_to_aggregate->insertDefault(); + /// Write the simple aggregation result for the previous group. + if (merged_rows > 0) + insertSimpleAggregationResult(merged_columns); + + /// Reset simple aggregation states for next row + for (auto & desc : columns_to_simple_aggregate) + desc.createState(); + ++merged_rows; } @@ -127,6 +147,9 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s } } + /// Write the simple aggregation result for the previous group. 
+ insertSimpleAggregationResult(merged_columns); + finished = true; } @@ -138,6 +161,21 @@ void AggregatingSortedBlockInputStream::addRow(SortCursor & cursor) size_t j = column_numbers_to_aggregate[i]; columns_to_aggregate[i]->insertMergeFrom(*cursor->all_columns[j], cursor->pos); } + + for (auto & desc : columns_to_simple_aggregate) + { + auto & col = cursor->all_columns[desc.column_number]; + desc.add_function(desc.function.get(), desc.state.data(), &col, cursor->pos, nullptr); + } +} + +void AggregatingSortedBlockInputStream::insertSimpleAggregationResult(MutableColumns & merged_columns) +{ + for (auto & desc : columns_to_simple_aggregate) + { + desc.function->insertResultInto(desc.state.data(), *merged_columns[desc.column_number]); + desc.destroyState(); + } } } diff --git a/dbms/src/DataStreams/AggregatingSortedBlockInputStream.h b/dbms/src/DataStreams/AggregatingSortedBlockInputStream.h index 522b54aeaec..97a579e31a6 100644 --- a/dbms/src/DataStreams/AggregatingSortedBlockInputStream.h +++ b/dbms/src/DataStreams/AggregatingSortedBlockInputStream.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace DB @@ -38,10 +39,13 @@ private: /// Read finished. bool finished = false; + struct SimpleAggregateDescription; + /// Columns with which numbers should be aggregated. ColumnNumbers column_numbers_to_aggregate; ColumnNumbers column_numbers_not_to_aggregate; std::vector columns_to_aggregate; + std::vector columns_to_simple_aggregate; RowRef current_key; /// The current primary key. RowRef next_key; /// The primary key of the next row. @@ -54,6 +58,53 @@ private: /** Extract all states of aggregate functions and merge them with the current group. 
*/ void addRow(SortCursor & cursor); + + /** Insert all values of current row for simple aggregate functions + */ + void insertSimpleAggregationResult(MutableColumns & merged_columns); + + /// Stores information for aggregation of SimpleAggregateFunction columns + struct SimpleAggregateDescription + { + /// An aggregate function 'anyLast', 'sum'... + AggregateFunctionPtr function; + IAggregateFunction::AddFunc add_function; + size_t column_number; + AlignedBuffer state; + bool created = false; + + SimpleAggregateDescription(const AggregateFunctionPtr & function_, const size_t column_number_) : function(function_), column_number(column_number_) + { + add_function = function->getAddressOfAddFunction(); + state.reset(function->sizeOfData(), function->alignOfData()); + } + + void createState() + { + if (created) + return; + function->create(state.data()); + created = true; + } + + void destroyState() + { + if (!created) + return; + function->destroy(state.data()); + created = false; + } + + /// Explicitly destroy aggregation state if the stream is terminated + ~SimpleAggregateDescription() + { + destroyState(); + } + + SimpleAggregateDescription() = default; + SimpleAggregateDescription(SimpleAggregateDescription &&) = default; + SimpleAggregateDescription(const SimpleAggregateDescription &) = delete; + }; }; } diff --git a/dbms/src/DataTypes/DataTypeDomainIPv4AndIPv6.cpp b/dbms/src/DataTypes/DataTypeDomainIPv4AndIPv6.cpp index 873dbde506b..f57a6167d3d 100644 --- a/dbms/src/DataTypes/DataTypeDomainIPv4AndIPv6.cpp +++ b/dbms/src/DataTypes/DataTypeDomainIPv4AndIPv6.cpp @@ -23,7 +23,7 @@ namespace class DataTypeDomainIPv4 : public DataTypeDomainWithSimpleSerialization { public: - const char * getName() const override + String doGetName() const override { return "IPv4"; } @@ -33,7 +33,7 @@ public: const auto col = checkAndGetColumn(&column); if (!col) { - throw Exception(String(getName()) + " domain can only serialize columns of type UInt32." 
+ column.getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception(getName() + " domain can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); } char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'}; @@ -48,7 +48,7 @@ public: ColumnUInt32 * col = typeid_cast(&column); if (!col) { - throw Exception(String(getName()) + " domain can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception(getName() + " domain can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); } char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'}; @@ -66,7 +66,7 @@ public: class DataTypeDomainIPv6 : public DataTypeDomainWithSimpleSerialization { public: - const char * getName() const override + String doGetName() const override { return "IPv6"; } @@ -76,7 +76,7 @@ public: const auto col = checkAndGetColumn(&column); if (!col) { - throw Exception(String(getName()) + " domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception(getName() + " domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); } char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'}; @@ -91,7 +91,7 @@ public: ColumnFixedString * col = typeid_cast(&column); if (!col) { - throw Exception(String(getName()) + " domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception(getName() + " domain can only deserialize columns of type FixedString(16)." 
+ column.getName(), ErrorCodes::ILLEGAL_COLUMN); } char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'}; @@ -100,7 +100,7 @@ public: std::string ipv6_value(IPV6_BINARY_LENGTH, '\0'); if (!parseIPv6(buffer, reinterpret_cast(ipv6_value.data()))) { - throw Exception(String("Invalid ") + getName() + " value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING); + throw Exception("Invalid " + getName() + " value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING); } col->insertString(ipv6_value); @@ -111,8 +111,8 @@ public: void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory) { - factory.registerDataTypeDomain("UInt32", std::make_unique()); - factory.registerDataTypeDomain("FixedString(16)", std::make_unique()); + factory.registerDataTypeDomain("IPv4", [] { return std::make_pair(DataTypeFactory::instance().get("UInt32"), std::make_unique()); }); + factory.registerDataTypeDomain("IPv6", [] { return std::make_pair(DataTypeFactory::instance().get("FixedString(16)"), std::make_unique()); }); } } // namespace DB diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp new file mode 100644 index 00000000000..402ce86ad62 --- /dev/null +++ b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp @@ -0,0 +1,149 @@ +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +namespace DB { + +namespace ErrorCodes +{ + extern const int SYNTAX_ERROR; + extern const int BAD_ARGUMENTS; + extern const int PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int LOGICAL_ERROR; +} + +const std::vector supported_functions = std::vector( + {"any", "anyLast", "min", "max", "sum"}); + + +String DataTypeDomainSimpleAggregateFunction::doGetName() const { + std::stringstream stream; + stream << "SimpleAggregateFunction(" << function->getName(); 
+ + if (!parameters.empty()) { + stream << "("; + for (size_t i = 0; i < parameters.size(); ++i) { + if (i) + stream << ", "; + stream << applyVisitor(DB::FieldVisitorToString(), parameters[i]); + } + stream << ")"; + } + + for (const auto &argument_type : argument_types) + stream << ", " << argument_type->getName(); + + stream << ")"; + return stream.str(); +} + + +static std::pair create(const ASTPtr & arguments) +{ + String function_name; + AggregateFunctionPtr function; + DataTypes argument_types; + Array params_row; + + if (!arguments || arguments->children.empty()) + throw Exception("Data type SimpleAggregateFunction requires parameters: " + "name of aggregate function and list of data types for arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + if (const ASTFunction * parametric = typeid_cast(arguments->children[0].get())) + { + if (parametric->parameters) + throw Exception("Unexpected level of parameters to aggregate function", ErrorCodes::SYNTAX_ERROR); + function_name = parametric->name; + + const ASTs & parameters = typeid_cast(*parametric->arguments).children; + params_row.resize(parameters.size()); + + for (size_t i = 0; i < parameters.size(); ++i) + { + const ASTLiteral * lit = typeid_cast(parameters[i].get()); + if (!lit) + throw Exception("Parameters to aggregate functions must be literals", + ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS); + + params_row[i] = lit->value; + } + } + else if (auto opt_name = getIdentifierName(arguments->children[0])) + { + function_name = *opt_name; + } + else if (typeid_cast(arguments->children[0].get())) + { + throw Exception("Aggregate function name for data type SimpleAggregateFunction must be passed as identifier (without quotes) or function", + ErrorCodes::BAD_ARGUMENTS); + } + else + throw Exception("Unexpected AST element passed as aggregate function name for data type SimpleAggregateFunction. 
Must be identifier or function.", + ErrorCodes::BAD_ARGUMENTS); + + for (size_t i = 1; i < arguments->children.size(); ++i) + argument_types.push_back(DataTypeFactory::instance().get(arguments->children[i])); + + if (function_name.empty()) + throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR); + + function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row); + + // check function + if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions)) { + throw Exception("Unsupported aggregate function " + function->getName() + ", supported functions are " + boost::algorithm::join(supported_functions, ","), + ErrorCodes::BAD_ARGUMENTS); + } + + DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName()); + DataTypeDomainPtr domain = std::make_unique(storage_type, function, argument_types, params_row); + + if (!function->getReturnType()->equals(*removeLowCardinality(storage_type))) { + throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getReturnType()->getName() + " and column storage type " + storage_type->getName(), + ErrorCodes::BAD_ARGUMENTS); + } + + return std::make_pair(storage_type, std::move(domain)); +} + +static const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(const IDataTypeDomain * domain) { + if (domain == nullptr) + return nullptr; + + if (auto simple_aggr = dynamic_cast(domain)) + return simple_aggr; + + if (domain->getDomain() != nullptr) + return findSimpleAggregateFunction(domain->getDomain()); + + return nullptr; +} + +const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(DataTypePtr dataType) { + return findSimpleAggregateFunction(dataType->getDomain()); +} + + +void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory) +{ + 
factory.registerDataTypeDomain("SimpleAggregateFunction", create); +} + +} diff --git a/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h new file mode 100644 index 00000000000..70e94b1a652 --- /dev/null +++ b/dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include + +#include + +namespace DB +{ + +/** The type SimpleAggregateFunction(fct, type) is meant to be used in an AggregatingMergeTree. It behaves like a standard + * data type but when rows are merged, an aggregation function is applied. + * + * The aggregation function is limited to simple functions whose merge state is the final result: + * any, anyLast, min, max, sum + * + * Examples: + * + * SimpleAggregateFunction(sum, Nullable(Float64)) + * SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String))) + * SimpleAggregateFunction(anyLast, IPv4) + * + * Technically, a standard IDataType is instantiated and a DataTypeDomainSimpleAggregateFunction is added as domain. 
+ */ + +class DataTypeDomainSimpleAggregateFunction : public IDataTypeDomain { +private: + const DataTypePtr storage_type; + const AggregateFunctionPtr function; + const DataTypes argument_types; + const Array parameters; + +public: + DataTypeDomainSimpleAggregateFunction(const DataTypePtr storage_type_, const AggregateFunctionPtr & function_, const DataTypes & argument_types_, const Array & parameters_) + : storage_type(storage_type_), function(function_), argument_types(argument_types_), parameters(parameters_) {} + + const AggregateFunctionPtr getFunction() const { return function; } + String doGetName() const override; +}; + +/// recursively follow data type domain to find a DataTypeDomainSimpleAggregateFunction +const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(DataTypePtr dataType); + +} diff --git a/dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.h b/dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.h index 7834e9235d2..3ccb4091636 100644 --- a/dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.h +++ b/dbms/src/DataTypes/DataTypeDomainWithSimpleSerialization.h @@ -12,7 +12,7 @@ class IColumn; /** Simple DataTypeDomain that uses serializeText/deserializeText * for all serialization and deserialization. 
*/ -class DataTypeDomainWithSimpleSerialization : public IDataTypeDomain +class DataTypeDomainWithSimpleSerialization : public IDataTypeDomainCustomSerialization { public: virtual ~DataTypeDomainWithSimpleSerialization() override; diff --git a/dbms/src/DataTypes/DataTypeFactory.cpp b/dbms/src/DataTypes/DataTypeFactory.cpp index a0afab890e9..a405075e884 100644 --- a/dbms/src/DataTypes/DataTypeFactory.cpp +++ b/dbms/src/DataTypes/DataTypeFactory.cpp @@ -115,19 +115,23 @@ void DataTypeFactory::registerSimpleDataType(const String & name, SimpleCreator }, case_sensitiveness); } -void DataTypeFactory::registerDataTypeDomain(const String & type_name, DataTypeDomainPtr domain, CaseSensitiveness case_sensitiveness) +void DataTypeFactory::registerDataTypeDomain(const String & family_name, CreatorWithDomain creator, CaseSensitiveness case_sensitiveness) { - all_domains.reserve(all_domains.size() + 1); - - auto data_type = get(type_name); - setDataTypeDomain(*data_type, *domain); - - registerDataType(domain->getName(), [data_type](const ASTPtr & /*ast*/) + registerDataType(family_name, [creator](const ASTPtr & ast) { - return data_type; - }, case_sensitiveness); + auto res = creator(ast); + res.first->appendDomain(std::move(res.second)); - all_domains.emplace_back(std::move(domain)); + return res.first; + }, case_sensitiveness); +} + +void DataTypeFactory::registerDataTypeDomain(const String & name, SimpleCreatorWithDomain creator, CaseSensitiveness case_sensitiveness) +{ + registerDataTypeDomain(name, [creator](const ASTPtr & /*ast*/) + { + return creator(); + }, case_sensitiveness); } const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String & family_name) const @@ -153,11 +157,6 @@ const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String throw Exception("Unknown data type family: " + family_name, ErrorCodes::UNKNOWN_TYPE); } -void DataTypeFactory::setDataTypeDomain(const IDataType & data_type, const IDataTypeDomain & domain) -{ - 
data_type.setDomain(&domain); -} - void registerDataTypeNumbers(DataTypeFactory & factory); void registerDataTypeDecimal(DataTypeFactory & factory); void registerDataTypeDate(DataTypeFactory & factory); @@ -175,6 +174,7 @@ void registerDataTypeNested(DataTypeFactory & factory); void registerDataTypeInterval(DataTypeFactory & factory); void registerDataTypeLowCardinality(DataTypeFactory & factory); void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory); +void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory); DataTypeFactory::DataTypeFactory() @@ -196,6 +196,7 @@ DataTypeFactory::DataTypeFactory() registerDataTypeInterval(*this); registerDataTypeLowCardinality(*this); registerDataTypeDomainIPv4AndIPv6(*this); + registerDataTypeDomainSimpleAggregateFunction(*this); } DataTypeFactory::~DataTypeFactory() diff --git a/dbms/src/DataTypes/DataTypeFactory.h b/dbms/src/DataTypes/DataTypeFactory.h index c6ef100bbb7..e4a82b342d1 100644 --- a/dbms/src/DataTypes/DataTypeFactory.h +++ b/dbms/src/DataTypes/DataTypeFactory.h @@ -28,6 +28,8 @@ class DataTypeFactory final : public ext::singleton, public IFa private: using SimpleCreator = std::function; using DataTypesDictionary = std::unordered_map; + using CreatorWithDomain = std::function(const ASTPtr & parameters)>; + using SimpleCreatorWithDomain = std::function()>; public: DataTypePtr get(const String & full_name) const; @@ -40,11 +42,13 @@ public: /// Register a simple data type, that have no parameters. void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); - // Register a domain - a refinement of existing type. 
- void registerDataTypeDomain(const String & type_name, DataTypeDomainPtr domain, CaseSensitiveness case_sensitiveness = CaseSensitive); + /// Register a type family with a dynamic domain + void registerDataTypeDomain(const String & family_name, CreatorWithDomain creator, CaseSensitiveness case_sensitiveness = CaseSensitive); + + /// Register a simple data type domain + void registerDataTypeDomain(const String & name, SimpleCreatorWithDomain creator, CaseSensitiveness case_sensitiveness = CaseSensitive); private: - static void setDataTypeDomain(const IDataType & data_type, const IDataTypeDomain & domain); const Creator& findCreatorByName(const String & family_name) const; private: @@ -53,9 +57,6 @@ private: /// Case insensitive data types will be additionally added here with lowercased name. DataTypesDictionary case_insensitive_data_types; - // All domains are owned by factory and shared amongst DataType instances. - std::vector all_domains; - DataTypeFactory(); ~DataTypeFactory() override; diff --git a/dbms/src/DataTypes/IDataType.cpp b/dbms/src/DataTypes/IDataType.cpp index 679871dba71..0270f1d7923 100644 --- a/dbms/src/DataTypes/IDataType.cpp +++ b/dbms/src/DataTypes/IDataType.cpp @@ -142,9 +142,9 @@ void IDataType::insertDefaultInto(IColumn & column) const void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->serializeTextEscaped(column, row_num, ostr, settings); + ser_domain->serializeTextEscaped(column, row_num, ostr, settings); } else { @@ -154,9 +154,9 @@ void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, W void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->deserializeTextEscaped(column, istr, settings); + 
ser_domain->deserializeTextEscaped(column, istr, settings); } else { @@ -166,9 +166,9 @@ void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, co void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->serializeTextQuoted(column, row_num, ostr, settings); + ser_domain->serializeTextQuoted(column, row_num, ostr, settings); } else { @@ -178,9 +178,9 @@ void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, Wr void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->deserializeTextQuoted(column, istr, settings); + ser_domain->deserializeTextQuoted(column, istr, settings); } else { @@ -190,9 +190,9 @@ void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, con void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (domain) - { - domain->serializeTextCSV(column, row_num, ostr, settings); + if (auto ser_domain = dynamic_cast(domain.get())) + { + ser_domain->serializeTextCSV(column, row_num, ostr, settings); } else { @@ -202,9 +202,9 @@ void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, Write void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->deserializeTextCSV(column, istr, settings); + ser_domain->deserializeTextCSV(column, istr, settings); } else { @@ -214,9 +214,9 @@ void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & 
settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->serializeText(column, row_num, ostr, settings); + ser_domain->serializeText(column, row_num, ostr, settings); } else { @@ -226,9 +226,9 @@ void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuf void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->serializeTextJSON(column, row_num, ostr, settings); + ser_domain->serializeTextJSON(column, row_num, ostr, settings); } else { @@ -238,9 +238,9 @@ void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, Writ void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->deserializeTextJSON(column, istr, settings); + ser_domain->deserializeTextJSON(column, istr, settings); } else { @@ -250,9 +250,9 @@ void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - if (domain) + if (auto ser_domain = dynamic_cast(domain.get())) { - domain->serializeTextXML(column, row_num, ostr, settings); + ser_domain->serializeTextXML(column, row_num, ostr, settings); } else { @@ -260,13 +260,12 @@ void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, Write } } -void IDataType::setDomain(const IDataTypeDomain* const new_domain) const +void IDataType::appendDomain(DataTypeDomainPtr new_domain) const { - if (domain != nullptr) - { - throw Exception("Type " + getName() + " already has a domain.", ErrorCodes::LOGICAL_ERROR); - } - domain = new_domain; + if (domain == nullptr) + domain = std::move(new_domain); + else + 
domain->appendDomain(std::move(new_domain)); } } diff --git a/dbms/src/DataTypes/IDataType.h b/dbms/src/DataTypes/IDataType.h index aa253fbdc08..a95402bf20a 100644 --- a/dbms/src/DataTypes/IDataType.h +++ b/dbms/src/DataTypes/IDataType.h @@ -13,6 +13,8 @@ class ReadBuffer; class WriteBuffer; class IDataTypeDomain; +using DataTypeDomainPtr = std::unique_ptr; + class IDataType; struct FormatSettings; @@ -459,18 +461,19 @@ public: private: friend class DataTypeFactory; - /** Sets domain on existing DataType, can be considered as second phase + /** Sets domain on existing DataType or append it to existing domain, can be considered as second phase * of construction explicitly done by DataTypeFactory. - * Will throw an exception if domain is already set. */ - void setDomain(const IDataTypeDomain* newDomain) const; + void appendDomain(DataTypeDomainPtr new_domain) const; private: /** This is mutable to allow setting domain on `const IDataType` post construction, * simplifying creation of domains for all types, without them even knowing * of domain existence. */ - mutable IDataTypeDomain const* domain; + mutable DataTypeDomainPtr domain; +public: + const IDataTypeDomain * getDomain() const { return domain.get(); } }; diff --git a/dbms/src/DataTypes/IDataTypeDomain.h b/dbms/src/DataTypes/IDataTypeDomain.h index ad38e88a213..1eed8afd808 100644 --- a/dbms/src/DataTypes/IDataTypeDomain.h +++ b/dbms/src/DataTypes/IDataTypeDomain.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include namespace DB { @@ -10,21 +12,42 @@ class WriteBuffer; struct FormatSettings; class IColumn; -/** Further refinment of the properties of data type. - * - * Contains methods for serialization/deserialization. - * Implementations of this interface represent a data type domain (example: IPv4) - * which is a refinement of the exsitgin type with a name and specific text - * representation. - * - * IDataTypeDomain is totally immutable object. You can always share them. 
+/** Allows customizing an existing data type and giving it a different name. Derived class IDataTypeDomainCustomSerialization allows + * further customization of serialization/deserialization methods. See use in IPv4 and IPv6 data type domains. + * + * IDataTypeDomain can be chained for further delegation (only for getName for the moment). */ class IDataTypeDomain { +private: + mutable DataTypeDomainPtr delegate; + public: virtual ~IDataTypeDomain() {} - virtual const char* getName() const = 0; + String getName() const { + if (delegate) + return delegate->getName(); + else + return doGetName(); + } + + void appendDomain(DataTypeDomainPtr delegate_) const { + if (delegate == nullptr) + delegate = std::move(delegate_); + else + delegate->appendDomain(std::move(delegate_)); + } + + const IDataTypeDomain * getDomain() const { return delegate.get(); } + +protected: + virtual String doGetName() const = 0; +}; + +class IDataTypeDomainCustomSerialization : public IDataTypeDomain { +public: + virtual ~IDataTypeDomainCustomSerialization() {} /** Text serialization for displaying on a terminal or saving into a text file, and the like. * Without escaping or quoting. 
diff --git a/dbms/tests/queries/0_stateless/00915_simple_aggregate_function.reference b/dbms/tests/queries/0_stateless/00915_simple_aggregate_function.reference new file mode 100644 index 00000000000..fbb3d60638e --- /dev/null +++ b/dbms/tests/queries/0_stateless/00915_simple_aggregate_function.reference @@ -0,0 +1,43 @@ +0 0 +1 1 +2 2 +3 3 +4 4 +5 5 +6 6 +7 7 +8 8 +9 9 +0 0 +1 1 +2 2 +3 3 +4 4 +5 5 +6 6 +7 7 +8 8 +9 9 +SimpleAggregateFunction(sum, Float64) +0 0 +1 2 +2 4 +3 6 +4 8 +5 10 +6 12 +7 14 +8 16 +9 18 +0 0 +1 2 +2 4 +3 6 +4 8 +5 10 +6 12 +7 14 +8 16 +9 18 +1 1 2 2.2.2.2 +SimpleAggregateFunction(anyLast, Nullable(String)) SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String))) SimpleAggregateFunction(anyLast, IPv4) diff --git a/dbms/tests/queries/0_stateless/00915_simple_aggregate_function.sql b/dbms/tests/queries/0_stateless/00915_simple_aggregate_function.sql new file mode 100644 index 00000000000..f4f80033eaa --- /dev/null +++ b/dbms/tests/queries/0_stateless/00915_simple_aggregate_function.sql @@ -0,0 +1,27 @@ +-- basic test +drop table if exists test.simple; + +create table test.simple (id UInt64,val SimpleAggregateFunction(sum,Double)) engine=AggregatingMergeTree order by id; +insert into test.simple select number,number from system.numbers limit 10; + +select * from test.simple; +select * from test.simple final; +select toTypeName(val) from test.simple limit 1; + +-- merge +insert into test.simple select number,number from system.numbers limit 10; + +select * from test.simple final; + +optimize table test.simple final; +select * from test.simple; + +-- complex types +drop table if exists test.simple; + +create table test.simple (id UInt64,nullable_str SimpleAggregateFunction(anyLast,Nullable(String)),low_str SimpleAggregateFunction(anyLast,LowCardinality(Nullable(String))),ip SimpleAggregateFunction(anyLast,IPv4)) engine=AggregatingMergeTree order by id; +insert into test.simple values(1,'1','1','1.1.1.1'); +insert into test.simple 
values(1,null,'2','2.2.2.2'); + +select * from test.simple final; +select toTypeName(nullable_str),toTypeName(low_str),toTypeName(ip) from test.simple limit 1;