support for SimpleAggregateFunction data type

This commit is contained in:
bgranvea 2019-03-08 17:49:10 +01:00
parent 80a235fdf9
commit 9d9d16e1ea
13 changed files with 454 additions and 74 deletions

View File

@ -1,6 +1,8 @@
#include <DataStreams/AggregatingSortedBlockInputStream.h>
#include <Common/typeid_cast.h>
#include <Common/StringUtils/StringUtils.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeDomainSimpleAggregateFunction.h>
namespace DB
@ -22,7 +24,7 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream(
ColumnWithTypeAndName & column = header.safeGetByPosition(i);
/// We leave only states of aggregate functions.
if (!startsWith(column.type->getName(), "AggregateFunction"))
if (!dynamic_cast<const DataTypeAggregateFunction *>(column.type.get()) && !findSimpleAggregateFunction(column.type))
{
column_numbers_not_to_aggregate.push_back(i);
continue;
@ -40,7 +42,14 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream(
continue;
}
column_numbers_to_aggregate.push_back(i);
if (auto simple_aggr = findSimpleAggregateFunction(column.type)) {
// simple aggregate function
SimpleAggregateDescription desc{simple_aggr->getFunction(), i};
columns_to_simple_aggregate.emplace_back(std::move(desc));
} else {
// standard aggregate function
column_numbers_to_aggregate.push_back(i);
}
}
}
@ -90,8 +99,11 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
key_differs = next_key != current_key;
/// if there are enough rows accumulated and the last one is calculated completely
if (key_differs && merged_rows >= max_block_size)
if (key_differs && merged_rows >= max_block_size) {
/// Write the simple aggregation result for the previous group.
insertSimpleAggregationResult(merged_columns);
return;
}
queue.pop();
@ -110,6 +122,14 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
for (auto & column_to_aggregate : columns_to_aggregate)
column_to_aggregate->insertDefault();
/// Write the simple aggregation result for the previous group.
if (merged_rows > 0)
insertSimpleAggregationResult(merged_columns);
/// Reset simple aggregation states for next row
for (auto & desc : columns_to_simple_aggregate)
desc.createState();
++merged_rows;
}
@ -127,6 +147,9 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
}
}
/// Write the simple aggregation result for the previous group.
insertSimpleAggregationResult(merged_columns);
finished = true;
}
@ -138,6 +161,21 @@ void AggregatingSortedBlockInputStream::addRow(SortCursor & cursor)
size_t j = column_numbers_to_aggregate[i];
columns_to_aggregate[i]->insertMergeFrom(*cursor->all_columns[j], cursor->pos);
}
for (auto & desc : columns_to_simple_aggregate)
{
auto & col = cursor->all_columns[desc.column_number];
desc.add_function(desc.function.get(), desc.state.data(), &col, cursor->pos, nullptr);
}
}
void AggregatingSortedBlockInputStream::insertSimpleAggregationResult(MutableColumns & merged_columns)
{
for (auto & desc : columns_to_simple_aggregate)
{
desc.function->insertResultInto(desc.state.data(), *merged_columns[desc.column_number]);
desc.destroyState();
}
}
}

View File

@ -7,6 +7,7 @@
#include <DataStreams/MergingSortedBlockInputStream.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Common/AlignedBuffer.h>
namespace DB
@ -38,10 +39,13 @@ private:
/// Read finished.
bool finished = false;
struct SimpleAggregateDescription;
/// Columns with which numbers should be aggregated.
ColumnNumbers column_numbers_to_aggregate;
ColumnNumbers column_numbers_not_to_aggregate;
std::vector<ColumnAggregateFunction *> columns_to_aggregate;
std::vector<SimpleAggregateDescription> columns_to_simple_aggregate;
RowRef current_key; /// The current primary key.
RowRef next_key; /// The primary key of the next row.
@ -54,6 +58,53 @@ private:
/** Extract all states of aggregate functions and merge them with the current group.
*/
void addRow(SortCursor & cursor);
/** Insert all values of current row for simple aggregate functions
*/
void insertSimpleAggregationResult(MutableColumns & merged_columns);
/// Stores information for aggregation of SimpleAggregateFunction columns
struct SimpleAggregateDescription
{
/// An aggregate function 'anyLast', 'sum'...
AggregateFunctionPtr function;
IAggregateFunction::AddFunc add_function;
size_t column_number;
AlignedBuffer state;
bool created = false;
SimpleAggregateDescription(const AggregateFunctionPtr & function_, const size_t column_number_) : function(function_), column_number(column_number_)
{
add_function = function->getAddressOfAddFunction();
state.reset(function->sizeOfData(), function->alignOfData());
}
void createState()
{
if (created)
return;
function->create(state.data());
created = true;
}
void destroyState()
{
if (!created)
return;
function->destroy(state.data());
created = false;
}
/// Explicitly destroy aggregation state if the stream is terminated
~SimpleAggregateDescription()
{
destroyState();
}
SimpleAggregateDescription() = default;
SimpleAggregateDescription(SimpleAggregateDescription &&) = default;
SimpleAggregateDescription(const SimpleAggregateDescription &) = delete;
};
};
}

View File

@ -23,7 +23,7 @@ namespace
class DataTypeDomainIPv4 : public DataTypeDomainWithSimpleSerialization
{
public:
const char * getName() const override
String doGetName() const override
{
return "IPv4";
}
@ -33,7 +33,7 @@ public:
const auto col = checkAndGetColumn<ColumnUInt32>(&column);
if (!col)
{
throw Exception(String(getName()) + " domain can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception(getName() + " domain can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
}
char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'};
@ -48,7 +48,7 @@ public:
ColumnUInt32 * col = typeid_cast<ColumnUInt32 *>(&column);
if (!col)
{
throw Exception(String(getName()) + " domain can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception(getName() + " domain can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
}
char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'};
@ -66,7 +66,7 @@ public:
class DataTypeDomainIPv6 : public DataTypeDomainWithSimpleSerialization
{
public:
const char * getName() const override
String doGetName() const override
{
return "IPv6";
}
@ -76,7 +76,7 @@ public:
const auto col = checkAndGetColumn<ColumnFixedString>(&column);
if (!col)
{
throw Exception(String(getName()) + " domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception(getName() + " domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
}
char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'};
@ -91,7 +91,7 @@ public:
ColumnFixedString * col = typeid_cast<ColumnFixedString *>(&column);
if (!col)
{
throw Exception(String(getName()) + " domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception(getName() + " domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
}
char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'};
@ -100,7 +100,7 @@ public:
std::string ipv6_value(IPV6_BINARY_LENGTH, '\0');
if (!parseIPv6(buffer, reinterpret_cast<unsigned char *>(ipv6_value.data())))
{
throw Exception(String("Invalid ") + getName() + " value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING);
throw Exception("Invalid " + getName() + " value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING);
}
col->insertString(ipv6_value);
@ -111,8 +111,8 @@ public:
void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory)
{
factory.registerDataTypeDomain("UInt32", std::make_unique<DataTypeDomainIPv4>());
factory.registerDataTypeDomain("FixedString(16)", std::make_unique<DataTypeDomainIPv6>());
factory.registerDataTypeDomain("IPv4", [] { return std::make_pair(DataTypeFactory::instance().get("UInt32"), std::make_unique<DataTypeDomainIPv4>()); });
factory.registerDataTypeDomain("IPv6", [] { return std::make_pair(DataTypeFactory::instance().get("FixedString(16)"), std::make_unique<DataTypeDomainIPv6>()); });
}
} // namespace DB

View File

@ -0,0 +1,149 @@
#include <Common/FieldVisitors.h>
#include <Common/typeid_cast.h>
#include <IO/ReadHelpers.h>
#include <Columns/ColumnAggregateFunction.h>
#include <DataTypes/DataTypeDomainSimpleAggregateFunction.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeFactory.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTIdentifier.h>
#include <boost/algorithm/string/join.hpp>
namespace DB {
namespace ErrorCodes
{
extern const int SYNTAX_ERROR;
extern const int BAD_ARGUMENTS;
extern const int PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int LOGICAL_ERROR;
}
const std::vector<String> supported_functions = std::vector<String>(
{"any", "anyLast", "min", "max", "sum"});
String DataTypeDomainSimpleAggregateFunction::doGetName() const {
std::stringstream stream;
stream << "SimpleAggregateFunction(" << function->getName();
if (!parameters.empty()) {
stream << "(";
for (size_t i = 0; i < parameters.size(); ++i) {
if (i)
stream << ", ";
stream << applyVisitor(DB::FieldVisitorToString(), parameters[i]);
}
stream << ")";
}
for (const auto &argument_type : argument_types)
stream << ", " << argument_type->getName();
stream << ")";
return stream.str();
}
static std::pair<DataTypePtr, DataTypeDomainPtr> create(const ASTPtr & arguments)
{
String function_name;
AggregateFunctionPtr function;
DataTypes argument_types;
Array params_row;
if (!arguments || arguments->children.empty())
throw Exception("Data type SimpleAggregateFunction requires parameters: "
"name of aggregate function and list of data types for arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (const ASTFunction * parametric = typeid_cast<const ASTFunction *>(arguments->children[0].get()))
{
if (parametric->parameters)
throw Exception("Unexpected level of parameters to aggregate function", ErrorCodes::SYNTAX_ERROR);
function_name = parametric->name;
const ASTs & parameters = typeid_cast<const ASTExpressionList &>(*parametric->arguments).children;
params_row.resize(parameters.size());
for (size_t i = 0; i < parameters.size(); ++i)
{
const ASTLiteral * lit = typeid_cast<const ASTLiteral *>(parameters[i].get());
if (!lit)
throw Exception("Parameters to aggregate functions must be literals",
ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS);
params_row[i] = lit->value;
}
}
else if (auto opt_name = getIdentifierName(arguments->children[0]))
{
function_name = *opt_name;
}
else if (typeid_cast<ASTLiteral *>(arguments->children[0].get()))
{
throw Exception("Aggregate function name for data type SimpleAggregateFunction must be passed as identifier (without quotes) or function",
ErrorCodes::BAD_ARGUMENTS);
}
else
throw Exception("Unexpected AST element passed as aggregate function name for data type SimpleAggregateFunction. Must be identifier or function.",
ErrorCodes::BAD_ARGUMENTS);
for (size_t i = 1; i < arguments->children.size(); ++i)
argument_types.push_back(DataTypeFactory::instance().get(arguments->children[i]));
if (function_name.empty())
throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR);
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row);
// check function
if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions)) {
throw Exception("Unsupported aggregate function " + function->getName() + ", supported functions are " + boost::algorithm::join(supported_functions, ","),
ErrorCodes::BAD_ARGUMENTS);
}
DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName());
DataTypeDomainPtr domain = std::make_unique<DataTypeDomainSimpleAggregateFunction>(storage_type, function, argument_types, params_row);
if (!function->getReturnType()->equals(*removeLowCardinality(storage_type))) {
throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getReturnType()->getName() + " and column storage type " + storage_type->getName(),
ErrorCodes::BAD_ARGUMENTS);
}
return std::make_pair(storage_type, std::move(domain));
}
static const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(const IDataTypeDomain * domain) {
if (domain == nullptr)
return nullptr;
if (auto simple_aggr = dynamic_cast<const DataTypeDomainSimpleAggregateFunction *>(domain))
return simple_aggr;
if (domain->getDomain() != nullptr)
return findSimpleAggregateFunction(domain->getDomain());
return nullptr;
}
const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(DataTypePtr dataType) {
return findSimpleAggregateFunction(dataType->getDomain());
}
void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory)
{
factory.registerDataTypeDomain("SimpleAggregateFunction", create);
}
}

View File

@ -0,0 +1,45 @@
#pragma once
#include <DataTypes/IDataTypeDomain.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Common/FieldVisitors.h>
#include <IO/ReadHelpers.h>
namespace DB
{
/** The type SimpleAggregateFunction(fct, type) is meant to be used in an AggregatingMergeTree. It behaves like a standard
* data type but when rows are merged, an aggregation function is applied.
*
* The aggregation function is limited to simple functions whose merge state is the final result:
* any, anyLast, min, max, sum
*
* Examples:
*
* SimpleAggregateFunction(sum, Nullable(Float64))
* SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String)))
* SimpleAggregateFunction(anyLast, IPv4)
*
* Technically, a standard IDataType is instanciated and a DataTypeDomainSimpleAggregateFunction is added as domain.
*/
class DataTypeDomainSimpleAggregateFunction : public IDataTypeDomain {
private:
const DataTypePtr storage_type;
const AggregateFunctionPtr function;
const DataTypes argument_types;
const Array parameters;
public:
DataTypeDomainSimpleAggregateFunction(const DataTypePtr storage_type_, const AggregateFunctionPtr & function_, const DataTypes & argument_types_, const Array & parameters_)
: storage_type(storage_type_), function(function_), argument_types(argument_types_), parameters(parameters_) {}
const AggregateFunctionPtr getFunction() const { return function; }
String doGetName() const override;
};
/// recursively follow data type domain to find a DataTypeDomainSimpleAggregateFunction
const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(DataTypePtr dataType);
}

View File

@ -12,7 +12,7 @@ class IColumn;
/** Simple DataTypeDomain that uses serializeText/deserializeText
* for all serialization and deserialization. */
class DataTypeDomainWithSimpleSerialization : public IDataTypeDomain
class DataTypeDomainWithSimpleSerialization : public IDataTypeDomainCustomSerialization
{
public:
virtual ~DataTypeDomainWithSimpleSerialization() override;

View File

@ -115,19 +115,23 @@ void DataTypeFactory::registerSimpleDataType(const String & name, SimpleCreator
}, case_sensitiveness);
}
void DataTypeFactory::registerDataTypeDomain(const String & type_name, DataTypeDomainPtr domain, CaseSensitiveness case_sensitiveness)
void DataTypeFactory::registerDataTypeDomain(const String & family_name, CreatorWithDomain creator, CaseSensitiveness case_sensitiveness)
{
all_domains.reserve(all_domains.size() + 1);
auto data_type = get(type_name);
setDataTypeDomain(*data_type, *domain);
registerDataType(domain->getName(), [data_type](const ASTPtr & /*ast*/)
registerDataType(family_name, [creator](const ASTPtr & ast)
{
return data_type;
}, case_sensitiveness);
auto res = creator(ast);
res.first->appendDomain(std::move(res.second));
all_domains.emplace_back(std::move(domain));
return res.first;
}, case_sensitiveness);
}
void DataTypeFactory::registerDataTypeDomain(const String & name, SimpleCreatorWithDomain creator, CaseSensitiveness case_sensitiveness)
{
registerDataTypeDomain(name, [creator](const ASTPtr & /*ast*/)
{
return creator();
}, case_sensitiveness);
}
const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String & family_name) const
@ -153,11 +157,6 @@ const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String
throw Exception("Unknown data type family: " + family_name, ErrorCodes::UNKNOWN_TYPE);
}
void DataTypeFactory::setDataTypeDomain(const IDataType & data_type, const IDataTypeDomain & domain)
{
data_type.setDomain(&domain);
}
void registerDataTypeNumbers(DataTypeFactory & factory);
void registerDataTypeDecimal(DataTypeFactory & factory);
void registerDataTypeDate(DataTypeFactory & factory);
@ -175,6 +174,7 @@ void registerDataTypeNested(DataTypeFactory & factory);
void registerDataTypeInterval(DataTypeFactory & factory);
void registerDataTypeLowCardinality(DataTypeFactory & factory);
void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory);
void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory);
DataTypeFactory::DataTypeFactory()
@ -196,6 +196,7 @@ DataTypeFactory::DataTypeFactory()
registerDataTypeInterval(*this);
registerDataTypeLowCardinality(*this);
registerDataTypeDomainIPv4AndIPv6(*this);
registerDataTypeDomainSimpleAggregateFunction(*this);
}
DataTypeFactory::~DataTypeFactory()

View File

@ -28,6 +28,8 @@ class DataTypeFactory final : public ext::singleton<DataTypeFactory>, public IFa
private:
using SimpleCreator = std::function<DataTypePtr()>;
using DataTypesDictionary = std::unordered_map<String, Creator>;
using CreatorWithDomain = std::function<std::pair<DataTypePtr,DataTypeDomainPtr>(const ASTPtr & parameters)>;
using SimpleCreatorWithDomain = std::function<std::pair<DataTypePtr,DataTypeDomainPtr>()>;
public:
DataTypePtr get(const String & full_name) const;
@ -40,11 +42,13 @@ public:
/// Register a simple data type, that have no parameters.
void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
// Register a domain - a refinement of existing type.
void registerDataTypeDomain(const String & type_name, DataTypeDomainPtr domain, CaseSensitiveness case_sensitiveness = CaseSensitive);
/// Register a type family with a dynamic domain
void registerDataTypeDomain(const String & family_name, CreatorWithDomain creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
/// Register a simple data type domain
void registerDataTypeDomain(const String & name, SimpleCreatorWithDomain creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
private:
static void setDataTypeDomain(const IDataType & data_type, const IDataTypeDomain & domain);
const Creator& findCreatorByName(const String & family_name) const;
private:
@ -53,9 +57,6 @@ private:
/// Case insensitive data types will be additionally added here with lowercased name.
DataTypesDictionary case_insensitive_data_types;
// All domains are owned by factory and shared amongst DataType instances.
std::vector<DataTypeDomainPtr> all_domains;
DataTypeFactory();
~DataTypeFactory() override;

View File

@ -142,9 +142,9 @@ void IDataType::insertDefaultInto(IColumn & column) const
void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
if (domain)
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
{
domain->serializeTextEscaped(column, row_num, ostr, settings);
ser_domain->serializeTextEscaped(column, row_num, ostr, settings);
}
else
{
@ -154,9 +154,9 @@ void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, W
void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
if (domain)
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
{
domain->deserializeTextEscaped(column, istr, settings);
ser_domain->deserializeTextEscaped(column, istr, settings);
}
else
{
@ -166,9 +166,9 @@ void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, co
void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
if (domain)
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
{
domain->serializeTextQuoted(column, row_num, ostr, settings);
ser_domain->serializeTextQuoted(column, row_num, ostr, settings);
}
else
{
@ -178,9 +178,9 @@ void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, Wr
void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
if (domain)
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
{
domain->deserializeTextQuoted(column, istr, settings);
ser_domain->deserializeTextQuoted(column, istr, settings);
}
else
{
@ -190,9 +190,9 @@ void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, con
void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
if (domain)
{
domain->serializeTextCSV(column, row_num, ostr, settings);
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
{
ser_domain->serializeTextCSV(column, row_num, ostr, settings);
}
else
{
@ -202,9 +202,9 @@ void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, Write
void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
if (domain)
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
{
domain->deserializeTextCSV(column, istr, settings);
ser_domain->deserializeTextCSV(column, istr, settings);
}
else
{
@ -214,9 +214,9 @@ void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const
void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
if (domain)
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
{
domain->serializeText(column, row_num, ostr, settings);
ser_domain->serializeText(column, row_num, ostr, settings);
}
else
{
@ -226,9 +226,9 @@ void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuf
void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
if (domain)
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
{
domain->serializeTextJSON(column, row_num, ostr, settings);
ser_domain->serializeTextJSON(column, row_num, ostr, settings);
}
else
{
@ -238,9 +238,9 @@ void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, Writ
void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
if (domain)
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
{
domain->deserializeTextJSON(column, istr, settings);
ser_domain->deserializeTextJSON(column, istr, settings);
}
else
{
@ -250,9 +250,9 @@ void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const
void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
if (domain)
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
{
domain->serializeTextXML(column, row_num, ostr, settings);
ser_domain->serializeTextXML(column, row_num, ostr, settings);
}
else
{
@ -260,13 +260,12 @@ void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, Write
}
}
void IDataType::setDomain(const IDataTypeDomain* const new_domain) const
void IDataType::appendDomain(DataTypeDomainPtr new_domain) const
{
if (domain != nullptr)
{
throw Exception("Type " + getName() + " already has a domain.", ErrorCodes::LOGICAL_ERROR);
}
domain = new_domain;
if (domain == nullptr)
domain = std::move(new_domain);
else
domain->appendDomain(std::move(new_domain));
}
}

View File

@ -13,6 +13,8 @@ class ReadBuffer;
class WriteBuffer;
class IDataTypeDomain;
using DataTypeDomainPtr = std::unique_ptr<const IDataTypeDomain>;
class IDataType;
struct FormatSettings;
@ -459,18 +461,19 @@ public:
private:
friend class DataTypeFactory;
/** Sets domain on existing DataType, can be considered as second phase
/** Sets domain on existing DataType or append it to existing domain, can be considered as second phase
* of construction explicitly done by DataTypeFactory.
* Will throw an exception if domain is already set.
*/
void setDomain(const IDataTypeDomain* newDomain) const;
void appendDomain(DataTypeDomainPtr new_domain) const;
private:
/** This is mutable to allow setting domain on `const IDataType` post construction,
* simplifying creation of domains for all types, without them even knowing
* of domain existence.
*/
mutable IDataTypeDomain const* domain;
mutable DataTypeDomainPtr domain;
public:
const IDataTypeDomain * getDomain() const { return domain.get(); }
};

View File

@ -1,6 +1,8 @@
#pragma once
#include <cstddef>
#include <Core/Types.h>
#include <DataTypes/IDataType.h>
namespace DB
{
@ -10,21 +12,42 @@ class WriteBuffer;
struct FormatSettings;
class IColumn;
/** Further refinment of the properties of data type.
*
* Contains methods for serialization/deserialization.
* Implementations of this interface represent a data type domain (example: IPv4)
* which is a refinement of the exsitgin type with a name and specific text
* representation.
*
* IDataTypeDomain is totally immutable object. You can always share them.
/** Allow to customize an existing data type and set a different name. Derived class IDataTypeDomainCustomSerialization allows
* further customization of serialization/deserialization methods. See use in IPv4 and IPv6 data type domains.
*
* IDataTypeDomain can be chained for further delegation (only for getName for the moment).
*/
class IDataTypeDomain
{
private:
mutable DataTypeDomainPtr delegate;
public:
virtual ~IDataTypeDomain() {}
virtual const char* getName() const = 0;
String getName() const {
if (delegate)
return delegate->getName();
else
return doGetName();
}
void appendDomain(DataTypeDomainPtr delegate_) const {
if (delegate == nullptr)
delegate = std::move(delegate_);
else
delegate->appendDomain(std::move(delegate_));
}
const IDataTypeDomain * getDomain() const { return delegate.get(); }
protected:
virtual String doGetName() const = 0;
};
class IDataTypeDomainCustomSerialization : public IDataTypeDomain {
public:
virtual ~IDataTypeDomainCustomSerialization() {}
/** Text serialization for displaying on a terminal or saving into a text file, and the like.
* Without escaping or quoting.

View File

@ -0,0 +1,43 @@
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9
SimpleAggregateFunction(sum, Float64)
0 0
1 2
2 4
3 6
4 8
5 10
6 12
7 14
8 16
9 18
0 0
1 2
2 4
3 6
4 8
5 10
6 12
7 14
8 16
9 18
1 1 2 2.2.2.2
SimpleAggregateFunction(anyLast, Nullable(String)) SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String))) SimpleAggregateFunction(anyLast, IPv4)

View File

@ -0,0 +1,27 @@
-- basic test
drop table if exists test.simple;
create table test.simple (id UInt64,val SimpleAggregateFunction(sum,Double)) engine=AggregatingMergeTree order by id;
insert into test.simple select number,number from system.numbers limit 10;
select * from test.simple;
select * from test.simple final;
select toTypeName(val) from test.simple limit 1;
-- merge
insert into test.simple select number,number from system.numbers limit 10;
select * from test.simple final;
optimize table test.simple final;
select * from test.simple;
-- complex types
drop table if exists test.simple;
create table test.simple (id UInt64,nullable_str SimpleAggregateFunction(anyLast,Nullable(String)),low_str SimpleAggregateFunction(anyLast,LowCardinality(Nullable(String))),ip SimpleAggregateFunction(anyLast,IPv4)) engine=AggregatingMergeTree order by id;
insert into test.simple values(1,'1','1','1.1.1.1');
insert into test.simple values(1,null,'2','2.2.2.2');
select * from test.simple final;
select toTypeName(nullable_str),toTypeName(low_str),toTypeName(ip) from test.simple limit 1;