mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
support for SimpleAggregateFunction data type
This commit is contained in:
parent
80a235fdf9
commit
9d9d16e1ea
@ -1,6 +1,8 @@
|
||||
#include <DataStreams/AggregatingSortedBlockInputStream.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Common/StringUtils/StringUtils.h>
|
||||
#include <DataTypes/DataTypeAggregateFunction.h>
|
||||
#include <DataTypes/DataTypeDomainSimpleAggregateFunction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -22,7 +24,7 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream(
|
||||
ColumnWithTypeAndName & column = header.safeGetByPosition(i);
|
||||
|
||||
/// We leave only states of aggregate functions.
|
||||
if (!startsWith(column.type->getName(), "AggregateFunction"))
|
||||
if (!dynamic_cast<const DataTypeAggregateFunction *>(column.type.get()) && !findSimpleAggregateFunction(column.type))
|
||||
{
|
||||
column_numbers_not_to_aggregate.push_back(i);
|
||||
continue;
|
||||
@ -40,7 +42,14 @@ AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream(
|
||||
continue;
|
||||
}
|
||||
|
||||
column_numbers_to_aggregate.push_back(i);
|
||||
if (auto simple_aggr = findSimpleAggregateFunction(column.type)) {
|
||||
// simple aggregate function
|
||||
SimpleAggregateDescription desc{simple_aggr->getFunction(), i};
|
||||
columns_to_simple_aggregate.emplace_back(std::move(desc));
|
||||
} else {
|
||||
// standard aggregate function
|
||||
column_numbers_to_aggregate.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -90,8 +99,11 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
|
||||
key_differs = next_key != current_key;
|
||||
|
||||
/// if there are enough rows accumulated and the last one is calculated completely
|
||||
if (key_differs && merged_rows >= max_block_size)
|
||||
if (key_differs && merged_rows >= max_block_size) {
|
||||
/// Write the simple aggregation result for the previous group.
|
||||
insertSimpleAggregationResult(merged_columns);
|
||||
return;
|
||||
}
|
||||
|
||||
queue.pop();
|
||||
|
||||
@ -110,6 +122,14 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
|
||||
for (auto & column_to_aggregate : columns_to_aggregate)
|
||||
column_to_aggregate->insertDefault();
|
||||
|
||||
/// Write the simple aggregation result for the previous group.
|
||||
if (merged_rows > 0)
|
||||
insertSimpleAggregationResult(merged_columns);
|
||||
|
||||
/// Reset simple aggregation states for next row
|
||||
for (auto & desc : columns_to_simple_aggregate)
|
||||
desc.createState();
|
||||
|
||||
++merged_rows;
|
||||
}
|
||||
|
||||
@ -127,6 +147,9 @@ void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, s
|
||||
}
|
||||
}
|
||||
|
||||
/// Write the simple aggregation result for the previous group.
|
||||
insertSimpleAggregationResult(merged_columns);
|
||||
|
||||
finished = true;
|
||||
}
|
||||
|
||||
@ -138,6 +161,21 @@ void AggregatingSortedBlockInputStream::addRow(SortCursor & cursor)
|
||||
size_t j = column_numbers_to_aggregate[i];
|
||||
columns_to_aggregate[i]->insertMergeFrom(*cursor->all_columns[j], cursor->pos);
|
||||
}
|
||||
|
||||
for (auto & desc : columns_to_simple_aggregate)
|
||||
{
|
||||
auto & col = cursor->all_columns[desc.column_number];
|
||||
desc.add_function(desc.function.get(), desc.state.data(), &col, cursor->pos, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
void AggregatingSortedBlockInputStream::insertSimpleAggregationResult(MutableColumns & merged_columns)
|
||||
{
|
||||
for (auto & desc : columns_to_simple_aggregate)
|
||||
{
|
||||
desc.function->insertResultInto(desc.state.data(), *merged_columns[desc.column_number]);
|
||||
desc.destroyState();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <DataStreams/MergingSortedBlockInputStream.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnAggregateFunction.h>
|
||||
#include <Common/AlignedBuffer.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -38,10 +39,13 @@ private:
|
||||
/// Read finished.
|
||||
bool finished = false;
|
||||
|
||||
struct SimpleAggregateDescription;
|
||||
|
||||
/// Columns with which numbers should be aggregated.
|
||||
ColumnNumbers column_numbers_to_aggregate;
|
||||
ColumnNumbers column_numbers_not_to_aggregate;
|
||||
std::vector<ColumnAggregateFunction *> columns_to_aggregate;
|
||||
std::vector<SimpleAggregateDescription> columns_to_simple_aggregate;
|
||||
|
||||
RowRef current_key; /// The current primary key.
|
||||
RowRef next_key; /// The primary key of the next row.
|
||||
@ -54,6 +58,53 @@ private:
|
||||
/** Extract all states of aggregate functions and merge them with the current group.
|
||||
*/
|
||||
void addRow(SortCursor & cursor);
|
||||
|
||||
/** Insert all values of current row for simple aggregate functions
|
||||
*/
|
||||
void insertSimpleAggregationResult(MutableColumns & merged_columns);
|
||||
|
||||
/// Stores information for aggregation of SimpleAggregateFunction columns
|
||||
struct SimpleAggregateDescription
|
||||
{
|
||||
/// An aggregate function 'anyLast', 'sum'...
|
||||
AggregateFunctionPtr function;
|
||||
IAggregateFunction::AddFunc add_function;
|
||||
size_t column_number;
|
||||
AlignedBuffer state;
|
||||
bool created = false;
|
||||
|
||||
SimpleAggregateDescription(const AggregateFunctionPtr & function_, const size_t column_number_) : function(function_), column_number(column_number_)
|
||||
{
|
||||
add_function = function->getAddressOfAddFunction();
|
||||
state.reset(function->sizeOfData(), function->alignOfData());
|
||||
}
|
||||
|
||||
void createState()
|
||||
{
|
||||
if (created)
|
||||
return;
|
||||
function->create(state.data());
|
||||
created = true;
|
||||
}
|
||||
|
||||
void destroyState()
|
||||
{
|
||||
if (!created)
|
||||
return;
|
||||
function->destroy(state.data());
|
||||
created = false;
|
||||
}
|
||||
|
||||
/// Explicitly destroy aggregation state if the stream is terminated
|
||||
~SimpleAggregateDescription()
|
||||
{
|
||||
destroyState();
|
||||
}
|
||||
|
||||
SimpleAggregateDescription() = default;
|
||||
SimpleAggregateDescription(SimpleAggregateDescription &&) = default;
|
||||
SimpleAggregateDescription(const SimpleAggregateDescription &) = delete;
|
||||
};
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ namespace
|
||||
class DataTypeDomainIPv4 : public DataTypeDomainWithSimpleSerialization
|
||||
{
|
||||
public:
|
||||
const char * getName() const override
|
||||
String doGetName() const override
|
||||
{
|
||||
return "IPv4";
|
||||
}
|
||||
@ -33,7 +33,7 @@ public:
|
||||
const auto col = checkAndGetColumn<ColumnUInt32>(&column);
|
||||
if (!col)
|
||||
{
|
||||
throw Exception(String(getName()) + " domain can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
throw Exception(getName() + " domain can only serialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
|
||||
char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'};
|
||||
@ -48,7 +48,7 @@ public:
|
||||
ColumnUInt32 * col = typeid_cast<ColumnUInt32 *>(&column);
|
||||
if (!col)
|
||||
{
|
||||
throw Exception(String(getName()) + " domain can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
throw Exception(getName() + " domain can only deserialize columns of type UInt32." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
|
||||
char buffer[IPV4_MAX_TEXT_LENGTH + 1] = {'\0'};
|
||||
@ -66,7 +66,7 @@ public:
|
||||
class DataTypeDomainIPv6 : public DataTypeDomainWithSimpleSerialization
|
||||
{
|
||||
public:
|
||||
const char * getName() const override
|
||||
String doGetName() const override
|
||||
{
|
||||
return "IPv6";
|
||||
}
|
||||
@ -76,7 +76,7 @@ public:
|
||||
const auto col = checkAndGetColumn<ColumnFixedString>(&column);
|
||||
if (!col)
|
||||
{
|
||||
throw Exception(String(getName()) + " domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
throw Exception(getName() + " domain can only serialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
|
||||
char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'};
|
||||
@ -91,7 +91,7 @@ public:
|
||||
ColumnFixedString * col = typeid_cast<ColumnFixedString *>(&column);
|
||||
if (!col)
|
||||
{
|
||||
throw Exception(String(getName()) + " domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
throw Exception(getName() + " domain can only deserialize columns of type FixedString(16)." + column.getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
|
||||
char buffer[IPV6_MAX_TEXT_LENGTH + 1] = {'\0'};
|
||||
@ -100,7 +100,7 @@ public:
|
||||
std::string ipv6_value(IPV6_BINARY_LENGTH, '\0');
|
||||
if (!parseIPv6(buffer, reinterpret_cast<unsigned char *>(ipv6_value.data())))
|
||||
{
|
||||
throw Exception(String("Invalid ") + getName() + " value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING);
|
||||
throw Exception("Invalid " + getName() + " value.", ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING);
|
||||
}
|
||||
|
||||
col->insertString(ipv6_value);
|
||||
@ -111,8 +111,8 @@ public:
|
||||
|
||||
void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory)
|
||||
{
|
||||
factory.registerDataTypeDomain("UInt32", std::make_unique<DataTypeDomainIPv4>());
|
||||
factory.registerDataTypeDomain("FixedString(16)", std::make_unique<DataTypeDomainIPv6>());
|
||||
factory.registerDataTypeDomain("IPv4", [] { return std::make_pair(DataTypeFactory::instance().get("UInt32"), std::make_unique<DataTypeDomainIPv4>()); });
|
||||
factory.registerDataTypeDomain("IPv6", [] { return std::make_pair(DataTypeFactory::instance().get("FixedString(16)"), std::make_unique<DataTypeDomainIPv6>()); });
|
||||
}
|
||||
|
||||
} // namespace DB
|
||||
|
149
dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp
Normal file
149
dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.cpp
Normal file
@ -0,0 +1,149 @@
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
#include <Columns/ColumnAggregateFunction.h>
|
||||
|
||||
#include <DataTypes/DataTypeDomainSimpleAggregateFunction.h>
|
||||
#include <DataTypes/DataTypeLowCardinality.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeFactory.h>
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <Parsers/ASTFunction.h>
|
||||
#include <Parsers/ASTLiteral.h>
|
||||
#include <Parsers/ASTIdentifier.h>
|
||||
|
||||
#include <boost/algorithm/string/join.hpp>
|
||||
|
||||
namespace DB {
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int SYNTAX_ERROR;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
const std::vector<String> supported_functions = std::vector<String>(
|
||||
{"any", "anyLast", "min", "max", "sum"});
|
||||
|
||||
|
||||
String DataTypeDomainSimpleAggregateFunction::doGetName() const {
|
||||
std::stringstream stream;
|
||||
stream << "SimpleAggregateFunction(" << function->getName();
|
||||
|
||||
if (!parameters.empty()) {
|
||||
stream << "(";
|
||||
for (size_t i = 0; i < parameters.size(); ++i) {
|
||||
if (i)
|
||||
stream << ", ";
|
||||
stream << applyVisitor(DB::FieldVisitorToString(), parameters[i]);
|
||||
}
|
||||
stream << ")";
|
||||
}
|
||||
|
||||
for (const auto &argument_type : argument_types)
|
||||
stream << ", " << argument_type->getName();
|
||||
|
||||
stream << ")";
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
|
||||
static std::pair<DataTypePtr, DataTypeDomainPtr> create(const ASTPtr & arguments)
|
||||
{
|
||||
String function_name;
|
||||
AggregateFunctionPtr function;
|
||||
DataTypes argument_types;
|
||||
Array params_row;
|
||||
|
||||
if (!arguments || arguments->children.empty())
|
||||
throw Exception("Data type SimpleAggregateFunction requires parameters: "
|
||||
"name of aggregate function and list of data types for arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
if (const ASTFunction * parametric = typeid_cast<const ASTFunction *>(arguments->children[0].get()))
|
||||
{
|
||||
if (parametric->parameters)
|
||||
throw Exception("Unexpected level of parameters to aggregate function", ErrorCodes::SYNTAX_ERROR);
|
||||
function_name = parametric->name;
|
||||
|
||||
const ASTs & parameters = typeid_cast<const ASTExpressionList &>(*parametric->arguments).children;
|
||||
params_row.resize(parameters.size());
|
||||
|
||||
for (size_t i = 0; i < parameters.size(); ++i)
|
||||
{
|
||||
const ASTLiteral * lit = typeid_cast<const ASTLiteral *>(parameters[i].get());
|
||||
if (!lit)
|
||||
throw Exception("Parameters to aggregate functions must be literals",
|
||||
ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS);
|
||||
|
||||
params_row[i] = lit->value;
|
||||
}
|
||||
}
|
||||
else if (auto opt_name = getIdentifierName(arguments->children[0]))
|
||||
{
|
||||
function_name = *opt_name;
|
||||
}
|
||||
else if (typeid_cast<ASTLiteral *>(arguments->children[0].get()))
|
||||
{
|
||||
throw Exception("Aggregate function name for data type SimpleAggregateFunction must be passed as identifier (without quotes) or function",
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
else
|
||||
throw Exception("Unexpected AST element passed as aggregate function name for data type SimpleAggregateFunction. Must be identifier or function.",
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
for (size_t i = 1; i < arguments->children.size(); ++i)
|
||||
argument_types.push_back(DataTypeFactory::instance().get(arguments->children[i]));
|
||||
|
||||
if (function_name.empty())
|
||||
throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row);
|
||||
|
||||
// check function
|
||||
if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions)) {
|
||||
throw Exception("Unsupported aggregate function " + function->getName() + ", supported functions are " + boost::algorithm::join(supported_functions, ","),
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName());
|
||||
DataTypeDomainPtr domain = std::make_unique<DataTypeDomainSimpleAggregateFunction>(storage_type, function, argument_types, params_row);
|
||||
|
||||
if (!function->getReturnType()->equals(*removeLowCardinality(storage_type))) {
|
||||
throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getReturnType()->getName() + " and column storage type " + storage_type->getName(),
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
return std::make_pair(storage_type, std::move(domain));
|
||||
}
|
||||
|
||||
static const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(const IDataTypeDomain * domain) {
|
||||
if (domain == nullptr)
|
||||
return nullptr;
|
||||
|
||||
if (auto simple_aggr = dynamic_cast<const DataTypeDomainSimpleAggregateFunction *>(domain))
|
||||
return simple_aggr;
|
||||
|
||||
if (domain->getDomain() != nullptr)
|
||||
return findSimpleAggregateFunction(domain->getDomain());
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(DataTypePtr dataType) {
|
||||
return findSimpleAggregateFunction(dataType->getDomain());
|
||||
}
|
||||
|
||||
|
||||
void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory)
|
||||
{
|
||||
factory.registerDataTypeDomain("SimpleAggregateFunction", create);
|
||||
}
|
||||
|
||||
}
|
45
dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h
Normal file
45
dbms/src/DataTypes/DataTypeDomainSimpleAggregateFunction.h
Normal file
@ -0,0 +1,45 @@
|
||||
#pragma once
|
||||
|
||||
#include <DataTypes/IDataTypeDomain.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/** The type SimpleAggregateFunction(fct, type) is meant to be used in an AggregatingMergeTree. It behaves like a standard
|
||||
* data type but when rows are merged, an aggregation function is applied.
|
||||
*
|
||||
* The aggregation function is limited to simple functions whose merge state is the final result:
|
||||
* any, anyLast, min, max, sum
|
||||
*
|
||||
* Examples:
|
||||
*
|
||||
* SimpleAggregateFunction(sum, Nullable(Float64))
|
||||
* SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String)))
|
||||
* SimpleAggregateFunction(anyLast, IPv4)
|
||||
*
|
||||
* Technically, a standard IDataType is instanciated and a DataTypeDomainSimpleAggregateFunction is added as domain.
|
||||
*/
|
||||
|
||||
class DataTypeDomainSimpleAggregateFunction : public IDataTypeDomain {
|
||||
private:
|
||||
const DataTypePtr storage_type;
|
||||
const AggregateFunctionPtr function;
|
||||
const DataTypes argument_types;
|
||||
const Array parameters;
|
||||
|
||||
public:
|
||||
DataTypeDomainSimpleAggregateFunction(const DataTypePtr storage_type_, const AggregateFunctionPtr & function_, const DataTypes & argument_types_, const Array & parameters_)
|
||||
: storage_type(storage_type_), function(function_), argument_types(argument_types_), parameters(parameters_) {}
|
||||
|
||||
const AggregateFunctionPtr getFunction() const { return function; }
|
||||
String doGetName() const override;
|
||||
};
|
||||
|
||||
/// recursively follow data type domain to find a DataTypeDomainSimpleAggregateFunction
|
||||
const DataTypeDomainSimpleAggregateFunction * findSimpleAggregateFunction(DataTypePtr dataType);
|
||||
|
||||
}
|
@ -12,7 +12,7 @@ class IColumn;
|
||||
|
||||
/** Simple DataTypeDomain that uses serializeText/deserializeText
|
||||
* for all serialization and deserialization. */
|
||||
class DataTypeDomainWithSimpleSerialization : public IDataTypeDomain
|
||||
class DataTypeDomainWithSimpleSerialization : public IDataTypeDomainCustomSerialization
|
||||
{
|
||||
public:
|
||||
virtual ~DataTypeDomainWithSimpleSerialization() override;
|
||||
|
@ -115,19 +115,23 @@ void DataTypeFactory::registerSimpleDataType(const String & name, SimpleCreator
|
||||
}, case_sensitiveness);
|
||||
}
|
||||
|
||||
void DataTypeFactory::registerDataTypeDomain(const String & type_name, DataTypeDomainPtr domain, CaseSensitiveness case_sensitiveness)
|
||||
void DataTypeFactory::registerDataTypeDomain(const String & family_name, CreatorWithDomain creator, CaseSensitiveness case_sensitiveness)
|
||||
{
|
||||
all_domains.reserve(all_domains.size() + 1);
|
||||
|
||||
auto data_type = get(type_name);
|
||||
setDataTypeDomain(*data_type, *domain);
|
||||
|
||||
registerDataType(domain->getName(), [data_type](const ASTPtr & /*ast*/)
|
||||
registerDataType(family_name, [creator](const ASTPtr & ast)
|
||||
{
|
||||
return data_type;
|
||||
}, case_sensitiveness);
|
||||
auto res = creator(ast);
|
||||
res.first->appendDomain(std::move(res.second));
|
||||
|
||||
all_domains.emplace_back(std::move(domain));
|
||||
return res.first;
|
||||
}, case_sensitiveness);
|
||||
}
|
||||
|
||||
void DataTypeFactory::registerDataTypeDomain(const String & name, SimpleCreatorWithDomain creator, CaseSensitiveness case_sensitiveness)
|
||||
{
|
||||
registerDataTypeDomain(name, [creator](const ASTPtr & /*ast*/)
|
||||
{
|
||||
return creator();
|
||||
}, case_sensitiveness);
|
||||
}
|
||||
|
||||
const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String & family_name) const
|
||||
@ -153,11 +157,6 @@ const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String
|
||||
throw Exception("Unknown data type family: " + family_name, ErrorCodes::UNKNOWN_TYPE);
|
||||
}
|
||||
|
||||
void DataTypeFactory::setDataTypeDomain(const IDataType & data_type, const IDataTypeDomain & domain)
|
||||
{
|
||||
data_type.setDomain(&domain);
|
||||
}
|
||||
|
||||
void registerDataTypeNumbers(DataTypeFactory & factory);
|
||||
void registerDataTypeDecimal(DataTypeFactory & factory);
|
||||
void registerDataTypeDate(DataTypeFactory & factory);
|
||||
@ -175,6 +174,7 @@ void registerDataTypeNested(DataTypeFactory & factory);
|
||||
void registerDataTypeInterval(DataTypeFactory & factory);
|
||||
void registerDataTypeLowCardinality(DataTypeFactory & factory);
|
||||
void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory);
|
||||
void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory);
|
||||
|
||||
|
||||
DataTypeFactory::DataTypeFactory()
|
||||
@ -196,6 +196,7 @@ DataTypeFactory::DataTypeFactory()
|
||||
registerDataTypeInterval(*this);
|
||||
registerDataTypeLowCardinality(*this);
|
||||
registerDataTypeDomainIPv4AndIPv6(*this);
|
||||
registerDataTypeDomainSimpleAggregateFunction(*this);
|
||||
}
|
||||
|
||||
DataTypeFactory::~DataTypeFactory()
|
||||
|
@ -28,6 +28,8 @@ class DataTypeFactory final : public ext::singleton<DataTypeFactory>, public IFa
|
||||
private:
|
||||
using SimpleCreator = std::function<DataTypePtr()>;
|
||||
using DataTypesDictionary = std::unordered_map<String, Creator>;
|
||||
using CreatorWithDomain = std::function<std::pair<DataTypePtr,DataTypeDomainPtr>(const ASTPtr & parameters)>;
|
||||
using SimpleCreatorWithDomain = std::function<std::pair<DataTypePtr,DataTypeDomainPtr>()>;
|
||||
|
||||
public:
|
||||
DataTypePtr get(const String & full_name) const;
|
||||
@ -40,11 +42,13 @@ public:
|
||||
/// Register a simple data type, that have no parameters.
|
||||
void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
|
||||
|
||||
// Register a domain - a refinement of existing type.
|
||||
void registerDataTypeDomain(const String & type_name, DataTypeDomainPtr domain, CaseSensitiveness case_sensitiveness = CaseSensitive);
|
||||
/// Register a type family with a dynamic domain
|
||||
void registerDataTypeDomain(const String & family_name, CreatorWithDomain creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
|
||||
|
||||
/// Register a simple data type domain
|
||||
void registerDataTypeDomain(const String & name, SimpleCreatorWithDomain creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
|
||||
|
||||
private:
|
||||
static void setDataTypeDomain(const IDataType & data_type, const IDataTypeDomain & domain);
|
||||
const Creator& findCreatorByName(const String & family_name) const;
|
||||
|
||||
private:
|
||||
@ -53,9 +57,6 @@ private:
|
||||
/// Case insensitive data types will be additionally added here with lowercased name.
|
||||
DataTypesDictionary case_insensitive_data_types;
|
||||
|
||||
// All domains are owned by factory and shared amongst DataType instances.
|
||||
std::vector<DataTypeDomainPtr> all_domains;
|
||||
|
||||
DataTypeFactory();
|
||||
~DataTypeFactory() override;
|
||||
|
||||
|
@ -142,9 +142,9 @@ void IDataType::insertDefaultInto(IColumn & column) const
|
||||
|
||||
void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
|
||||
{
|
||||
domain->serializeTextEscaped(column, row_num, ostr, settings);
|
||||
ser_domain->serializeTextEscaped(column, row_num, ostr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -154,9 +154,9 @@ void IDataType::serializeAsTextEscaped(const IColumn & column, size_t row_num, W
|
||||
|
||||
void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
|
||||
{
|
||||
domain->deserializeTextEscaped(column, istr, settings);
|
||||
ser_domain->deserializeTextEscaped(column, istr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -166,9 +166,9 @@ void IDataType::deserializeAsTextEscaped(IColumn & column, ReadBuffer & istr, co
|
||||
|
||||
void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
|
||||
{
|
||||
domain->serializeTextQuoted(column, row_num, ostr, settings);
|
||||
ser_domain->serializeTextQuoted(column, row_num, ostr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -178,9 +178,9 @@ void IDataType::serializeAsTextQuoted(const IColumn & column, size_t row_num, Wr
|
||||
|
||||
void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
|
||||
{
|
||||
domain->deserializeTextQuoted(column, istr, settings);
|
||||
ser_domain->deserializeTextQuoted(column, istr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -190,9 +190,9 @@ void IDataType::deserializeAsTextQuoted(IColumn & column, ReadBuffer & istr, con
|
||||
|
||||
void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
{
|
||||
domain->serializeTextCSV(column, row_num, ostr, settings);
|
||||
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
|
||||
{
|
||||
ser_domain->serializeTextCSV(column, row_num, ostr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -202,9 +202,9 @@ void IDataType::serializeAsTextCSV(const IColumn & column, size_t row_num, Write
|
||||
|
||||
void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
|
||||
{
|
||||
domain->deserializeTextCSV(column, istr, settings);
|
||||
ser_domain->deserializeTextCSV(column, istr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -214,9 +214,9 @@ void IDataType::deserializeAsTextCSV(IColumn & column, ReadBuffer & istr, const
|
||||
|
||||
void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
|
||||
{
|
||||
domain->serializeText(column, row_num, ostr, settings);
|
||||
ser_domain->serializeText(column, row_num, ostr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -226,9 +226,9 @@ void IDataType::serializeAsText(const IColumn & column, size_t row_num, WriteBuf
|
||||
|
||||
void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
|
||||
{
|
||||
domain->serializeTextJSON(column, row_num, ostr, settings);
|
||||
ser_domain->serializeTextJSON(column, row_num, ostr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -238,9 +238,9 @@ void IDataType::serializeAsTextJSON(const IColumn & column, size_t row_num, Writ
|
||||
|
||||
void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
|
||||
{
|
||||
domain->deserializeTextJSON(column, istr, settings);
|
||||
ser_domain->deserializeTextJSON(column, istr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -250,9 +250,9 @@ void IDataType::deserializeAsTextJSON(IColumn & column, ReadBuffer & istr, const
|
||||
|
||||
void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
|
||||
{
|
||||
if (domain)
|
||||
if (auto ser_domain = dynamic_cast<const IDataTypeDomainCustomSerialization *>(domain.get()))
|
||||
{
|
||||
domain->serializeTextXML(column, row_num, ostr, settings);
|
||||
ser_domain->serializeTextXML(column, row_num, ostr, settings);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -260,13 +260,12 @@ void IDataType::serializeAsTextXML(const IColumn & column, size_t row_num, Write
|
||||
}
|
||||
}
|
||||
|
||||
void IDataType::setDomain(const IDataTypeDomain* const new_domain) const
|
||||
void IDataType::appendDomain(DataTypeDomainPtr new_domain) const
|
||||
{
|
||||
if (domain != nullptr)
|
||||
{
|
||||
throw Exception("Type " + getName() + " already has a domain.", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
domain = new_domain;
|
||||
if (domain == nullptr)
|
||||
domain = std::move(new_domain);
|
||||
else
|
||||
domain->appendDomain(std::move(new_domain));
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -13,6 +13,8 @@ class ReadBuffer;
|
||||
class WriteBuffer;
|
||||
|
||||
class IDataTypeDomain;
|
||||
using DataTypeDomainPtr = std::unique_ptr<const IDataTypeDomain>;
|
||||
|
||||
class IDataType;
|
||||
struct FormatSettings;
|
||||
|
||||
@ -459,18 +461,19 @@ public:
|
||||
|
||||
private:
|
||||
friend class DataTypeFactory;
|
||||
/** Sets domain on existing DataType, can be considered as second phase
|
||||
/** Sets domain on existing DataType or append it to existing domain, can be considered as second phase
|
||||
* of construction explicitly done by DataTypeFactory.
|
||||
* Will throw an exception if domain is already set.
|
||||
*/
|
||||
void setDomain(const IDataTypeDomain* newDomain) const;
|
||||
void appendDomain(DataTypeDomainPtr new_domain) const;
|
||||
|
||||
private:
|
||||
/** This is mutable to allow setting domain on `const IDataType` post construction,
|
||||
* simplifying creation of domains for all types, without them even knowing
|
||||
* of domain existence.
|
||||
*/
|
||||
mutable IDataTypeDomain const* domain;
|
||||
mutable DataTypeDomainPtr domain;
|
||||
public:
|
||||
const IDataTypeDomain * getDomain() const { return domain.get(); }
|
||||
};
|
||||
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <Core/Types.h>
|
||||
#include <DataTypes/IDataType.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -10,21 +12,42 @@ class WriteBuffer;
|
||||
struct FormatSettings;
|
||||
class IColumn;
|
||||
|
||||
/** Further refinment of the properties of data type.
|
||||
*
|
||||
* Contains methods for serialization/deserialization.
|
||||
* Implementations of this interface represent a data type domain (example: IPv4)
|
||||
* which is a refinement of the exsitgin type with a name and specific text
|
||||
* representation.
|
||||
*
|
||||
* IDataTypeDomain is totally immutable object. You can always share them.
|
||||
/** Allow to customize an existing data type and set a different name. Derived class IDataTypeDomainCustomSerialization allows
|
||||
* further customization of serialization/deserialization methods. See use in IPv4 and IPv6 data type domains.
|
||||
*
|
||||
* IDataTypeDomain can be chained for further delegation (only for getName for the moment).
|
||||
*/
|
||||
class IDataTypeDomain
|
||||
{
|
||||
private:
|
||||
mutable DataTypeDomainPtr delegate;
|
||||
|
||||
public:
|
||||
virtual ~IDataTypeDomain() {}
|
||||
|
||||
virtual const char* getName() const = 0;
|
||||
String getName() const {
|
||||
if (delegate)
|
||||
return delegate->getName();
|
||||
else
|
||||
return doGetName();
|
||||
}
|
||||
|
||||
void appendDomain(DataTypeDomainPtr delegate_) const {
|
||||
if (delegate == nullptr)
|
||||
delegate = std::move(delegate_);
|
||||
else
|
||||
delegate->appendDomain(std::move(delegate_));
|
||||
}
|
||||
|
||||
const IDataTypeDomain * getDomain() const { return delegate.get(); }
|
||||
|
||||
protected:
|
||||
virtual String doGetName() const = 0;
|
||||
};
|
||||
|
||||
class IDataTypeDomainCustomSerialization : public IDataTypeDomain {
|
||||
public:
|
||||
virtual ~IDataTypeDomainCustomSerialization() {}
|
||||
|
||||
/** Text serialization for displaying on a terminal or saving into a text file, and the like.
|
||||
* Without escaping or quoting.
|
||||
|
@ -0,0 +1,43 @@
|
||||
0 0
|
||||
1 1
|
||||
2 2
|
||||
3 3
|
||||
4 4
|
||||
5 5
|
||||
6 6
|
||||
7 7
|
||||
8 8
|
||||
9 9
|
||||
0 0
|
||||
1 1
|
||||
2 2
|
||||
3 3
|
||||
4 4
|
||||
5 5
|
||||
6 6
|
||||
7 7
|
||||
8 8
|
||||
9 9
|
||||
SimpleAggregateFunction(sum, Float64)
|
||||
0 0
|
||||
1 2
|
||||
2 4
|
||||
3 6
|
||||
4 8
|
||||
5 10
|
||||
6 12
|
||||
7 14
|
||||
8 16
|
||||
9 18
|
||||
0 0
|
||||
1 2
|
||||
2 4
|
||||
3 6
|
||||
4 8
|
||||
5 10
|
||||
6 12
|
||||
7 14
|
||||
8 16
|
||||
9 18
|
||||
1 1 2 2.2.2.2
|
||||
SimpleAggregateFunction(anyLast, Nullable(String)) SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String))) SimpleAggregateFunction(anyLast, IPv4)
|
@ -0,0 +1,27 @@
|
||||
-- basic test
|
||||
drop table if exists test.simple;
|
||||
|
||||
create table test.simple (id UInt64,val SimpleAggregateFunction(sum,Double)) engine=AggregatingMergeTree order by id;
|
||||
insert into test.simple select number,number from system.numbers limit 10;
|
||||
|
||||
select * from test.simple;
|
||||
select * from test.simple final;
|
||||
select toTypeName(val) from test.simple limit 1;
|
||||
|
||||
-- merge
|
||||
insert into test.simple select number,number from system.numbers limit 10;
|
||||
|
||||
select * from test.simple final;
|
||||
|
||||
optimize table test.simple final;
|
||||
select * from test.simple;
|
||||
|
||||
-- complex types
|
||||
drop table if exists test.simple;
|
||||
|
||||
create table test.simple (id UInt64,nullable_str SimpleAggregateFunction(anyLast,Nullable(String)),low_str SimpleAggregateFunction(anyLast,LowCardinality(Nullable(String))),ip SimpleAggregateFunction(anyLast,IPv4)) engine=AggregatingMergeTree order by id;
|
||||
insert into test.simple values(1,'1','1','1.1.1.1');
|
||||
insert into test.simple values(1,null,'2','2.2.2.2');
|
||||
|
||||
select * from test.simple final;
|
||||
select toTypeName(nullable_str),toTypeName(low_str),toTypeName(ip) from test.simple limit 1;
|
Loading…
Reference in New Issue
Block a user