mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-10 01:25:21 +00:00
dbms: development [#CONV-2944].
This commit is contained in:
parent
b979162b00
commit
13d0b57ca3
@ -21,6 +21,7 @@ public:
|
||||
AggregateFunctionCount() : count(0) {}
|
||||
|
||||
String getName() const { return "count"; }
|
||||
String getTypeID() const { return "count"; }
|
||||
|
||||
AggregateFunctionPtr cloneEmpty() const
|
||||
{
|
||||
|
@ -15,12 +15,9 @@ class AggregateFunctionFactory
|
||||
{
|
||||
public:
|
||||
AggregateFunctionFactory();
|
||||
AggregateFunctionPtr get(const String & name) const;
|
||||
AggregateFunctionPtr tryGet(const String & name) const;
|
||||
|
||||
private:
|
||||
typedef std::map<String, AggregateFunctionPtr> NonParametricAggregateFunctions;
|
||||
NonParametricAggregateFunctions non_parametric_aggregate_functions;
|
||||
AggregateFunctionPtr get(const String & name, const DataTypes & argument_types) const;
|
||||
AggregateFunctionPtr tryGet(const String & name, const DataTypes & argument_types) const;
|
||||
AggregateFunctionPtr getByTypeID(const String & type_id) const;
|
||||
};
|
||||
|
||||
using Poco::SharedPtr;
|
||||
|
97
dbms/include/DB/AggregateFunctions/AggregateFunctionSum.h
Normal file
97
dbms/include/DB/AggregateFunctions/AggregateFunctionSum.h
Normal file
@ -0,0 +1,97 @@
|
||||
#pragma once
|
||||
|
||||
#include <DB/IO/WriteHelpers.h>
|
||||
#include <DB/IO/ReadHelpers.h>
|
||||
|
||||
#include <DB/DataTypes/DataTypesNumberVariable.h>
|
||||
#include <DB/DataTypes/DataTypesNumberFixed.h>
|
||||
|
||||
#include <DB/AggregateFunctions/IUnaryAggregateFunction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
template <typename T> struct AggregateFunctionSumTraits;
|
||||
|
||||
template <> struct AggregateFunctionSumTraits<UInt64>
|
||||
{
|
||||
static DataTypePtr getReturnType() { return new DataTypeVarUInt; }
|
||||
static void write(UInt64 x, WriteBuffer & buf) { writeVarUInt(x, buf); }
|
||||
static void read(UInt64 & x, ReadBuffer & buf) { readVarUInt(x, buf); }
|
||||
};
|
||||
|
||||
template <> struct AggregateFunctionSumTraits<Int64>
|
||||
{
|
||||
static DataTypePtr getReturnType() { return new DataTypeVarInt; }
|
||||
static void write(Int64 x, WriteBuffer & buf) { writeVarInt(x, buf); }
|
||||
static void read(Int64 & x, ReadBuffer & buf) { readVarInt(x, buf); }
|
||||
};
|
||||
|
||||
template <> struct AggregateFunctionSumTraits<Float64>
|
||||
{
|
||||
static DataTypePtr getReturnType() { return new DataTypeFloat64; }
|
||||
static void write(Float64 x, WriteBuffer & buf) { writeFloatBinary(x, buf); }
|
||||
static void read(Float64 & x, ReadBuffer & buf) { readFloatBinary(x, buf); }
|
||||
};
|
||||
|
||||
|
||||
/// Считает сумму чисел. Параметром шаблона может быть UInt64, Int64 или Float64.
|
||||
template <typename T>
|
||||
class AggregateFunctionSum : public IUnaryAggregateFunction
|
||||
{
|
||||
private:
|
||||
T sum;
|
||||
|
||||
public:
|
||||
AggregateFunctionSum() : sum(0) {}
|
||||
|
||||
String getName() const { return "sum"; }
|
||||
String getTypeID() const { return "sum_" + TypeName<T>::get(); }
|
||||
|
||||
AggregateFunctionPtr cloneEmpty() const
|
||||
{
|
||||
return new AggregateFunctionSum<T>;
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const
|
||||
{
|
||||
return AggregateFunctionSumTraits<T>::getReturnType();
|
||||
}
|
||||
|
||||
void setArgument(const DataTypePtr & argument)
|
||||
{
|
||||
if (!argument->isNumeric())
|
||||
throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
void addOne(const Field & value)
|
||||
{
|
||||
sum += boost::get<T>(value);
|
||||
}
|
||||
|
||||
void merge(const IAggregateFunction & rhs)
|
||||
{
|
||||
sum += static_cast<const AggregateFunctionSum<T> &>(rhs).sum;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
AggregateFunctionSumTraits<T>::write(sum, buf);
|
||||
}
|
||||
|
||||
void deserializeMerge(ReadBuffer & buf)
|
||||
{
|
||||
T tmp;
|
||||
AggregateFunctionSumTraits<T>::read(tmp, buf);
|
||||
sum += tmp;
|
||||
}
|
||||
|
||||
Field getResult() const
|
||||
{
|
||||
return sum;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -18,6 +18,9 @@ public:
|
||||
/// Получить основное имя функции.
|
||||
virtual String getName() const = 0;
|
||||
|
||||
/// Получить строку, по которой можно потом будет создать объект того же типа (с помощью AggregateFunctionFactory)
|
||||
virtual String getTypeID() const = 0;
|
||||
|
||||
/// Создать новую агрегатную функцию того же типа.
|
||||
virtual SharedPtr<IAggregateFunction> cloneEmpty() const = 0;
|
||||
|
||||
|
@ -19,6 +19,8 @@ public:
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
}
|
||||
|
||||
virtual void setArgument(const DataTypePtr & argument) = 0;
|
||||
|
||||
/// Добавить значение.
|
||||
void add(const Row & row)
|
||||
{
|
||||
|
@ -21,7 +21,6 @@ public:
|
||||
Expression(ASTPtr ast_, const Context & context_) : ast(ast_), context(context_)
|
||||
{
|
||||
addSemantic(ast);
|
||||
checkTypes(ast);
|
||||
glueTree(ast);
|
||||
}
|
||||
|
||||
@ -72,15 +71,12 @@ private:
|
||||
|
||||
/** Для узлов - звёздочек - раскрыть их в список всех столбцов.
|
||||
* Для узлов - литералов - прописать их типы данных.
|
||||
* Для узлов - функций - прописать ссылки на функции и заменить имена на канонические.
|
||||
* Для узлов - функций - прописать ссылки на функции, заменить имена на канонические, прописать и проверить типы.
|
||||
* Для узлов - идентификаторов - прописать ссылки на их типы.
|
||||
* Проверить, что все функции применимы для типов их аргументов.
|
||||
*/
|
||||
void addSemantic(ASTPtr & ast);
|
||||
|
||||
/** Проверить, что все функции применимы.
|
||||
*/
|
||||
void checkTypes(ASTPtr ast);
|
||||
|
||||
/** Склеить одинаковые узлы в синтаксическом дереве (превращая его в направленный ациклический граф).
|
||||
* Это означает, в том числе то, что функции с одними и теми же аргументами, будут выполняться только один раз.
|
||||
* Например, выражение rand(), rand() вернёт два идентичных столбца.
|
||||
|
@ -1,6 +1,5 @@
|
||||
#include <boost/assign/list_inserter.hpp>
|
||||
|
||||
#include <DB/AggregateFunctions/AggregateFunctionCount.h>
|
||||
#include <DB/AggregateFunctions/AggregateFunctionSum.h>
|
||||
|
||||
#include <DB/AggregateFunctions/AggregateFunctionFactory.h>
|
||||
|
||||
@ -11,28 +10,75 @@ namespace DB
|
||||
|
||||
AggregateFunctionFactory::AggregateFunctionFactory()
|
||||
{
|
||||
boost::assign::insert(non_parametric_aggregate_functions)
|
||||
("count", new AggregateFunctionCount)
|
||||
;
|
||||
}
|
||||
|
||||
|
||||
AggregateFunctionPtr AggregateFunctionFactory::get(const String & name) const
|
||||
AggregateFunctionPtr AggregateFunctionFactory::get(const String & name, const DataTypes & argument_types) const
|
||||
{
|
||||
NonParametricAggregateFunctions::const_iterator it = non_parametric_aggregate_functions.find(name);
|
||||
if (it != non_parametric_aggregate_functions.end())
|
||||
return it->second->cloneEmpty();
|
||||
if (name == "count")
|
||||
{
|
||||
return new AggregateFunctionCount;
|
||||
}
|
||||
else if (name == "sum")
|
||||
{
|
||||
if (argument_types.size() != 1)
|
||||
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
|
||||
String argument_type_name = argument_types[0]->getName();
|
||||
|
||||
if (argument_type_name == "UInt8" || argument_type_name == "UInt16"
|
||||
|| argument_type_name == "UInt32" || argument_type_name == "UInt64"
|
||||
|| argument_type_name == "VarUInt")
|
||||
return new AggregateFunctionSum<UInt64>;
|
||||
else if (argument_type_name == "Int8" || argument_type_name == "Int16"
|
||||
|| argument_type_name == "Int32" || argument_type_name == "Int64"
|
||||
|| argument_type_name == "VarInt")
|
||||
return new AggregateFunctionSum<Int64>;
|
||||
else if (argument_type_name == "Float32" || argument_type_name == "Float64")
|
||||
return new AggregateFunctionSum<Float64>;
|
||||
else
|
||||
throw Exception("Illegal type " + argument_type_name + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
else
|
||||
throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
|
||||
}
|
||||
|
||||
|
||||
AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name) const
|
||||
AggregateFunctionPtr AggregateFunctionFactory::getByTypeID(const String & type_id) const
|
||||
{
|
||||
NonParametricAggregateFunctions::const_iterator it = non_parametric_aggregate_functions.find(name);
|
||||
if (it != non_parametric_aggregate_functions.end())
|
||||
return it->second->cloneEmpty();
|
||||
return NULL;
|
||||
if (type_id == "count")
|
||||
return new AggregateFunctionCount;
|
||||
else if (0 == type_id.compare(0, strlen("sum_"), "sum_"))
|
||||
{
|
||||
if (0 == type_id.compare(strlen("sum_"), strlen("UInt64"), "UInt64"))
|
||||
return new AggregateFunctionSum<UInt64>;
|
||||
else if (0 == type_id.compare(strlen("sum_"), strlen("Int64"), "Int64"))
|
||||
return new AggregateFunctionSum<Int64>;
|
||||
else if (0 == type_id.compare(strlen("sum_"), strlen("Float64"), "Float64"))
|
||||
return new AggregateFunctionSum<Float64>;
|
||||
else
|
||||
throw Exception("Unknown type id of aggregate function " + type_id, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
|
||||
}
|
||||
else
|
||||
throw Exception("Unknown type id of aggregate function " + type_id, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
|
||||
}
|
||||
|
||||
|
||||
/** Non-throwing lookup variant of get().
  * Returns NULL when get() reports ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION for this name.
  * Any other DB::Exception thrown by get() is rethrown unchanged.
  */
AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types) const
{
	try
	{
		return get(name, argument_types);
	}
	catch (const DB::Exception & e)
	{
		/// Only "no such function" is turned into a null result.
		if (e.code() == ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION)
			return NULL;
		else
			throw;
	}
}
|
||||
|
||||
|
||||
|
@ -44,12 +44,12 @@ Block IProfilingBlockInputStream::read()
|
||||
if (res)
|
||||
info.update(res);
|
||||
|
||||
if (res)
|
||||
/* if (res)
|
||||
{
|
||||
std::cerr << std::endl;
|
||||
std::cerr << getName() << std::endl;
|
||||
getInfo().print(std::cerr);
|
||||
}
|
||||
}*/
|
||||
|
||||
return res;
|
||||
}
|
||||
|
@ -96,7 +96,9 @@ int main(int argc, char ** argv)
|
||||
DB::AggregateFunctionFactory factory;
|
||||
|
||||
DB::AggregateDescriptions aggregate_descriptions(1);
|
||||
aggregate_descriptions[0].function = factory.get("count");
|
||||
|
||||
DB::DataTypes empty_list_of_types;
|
||||
aggregate_descriptions[0].function = factory.get("count", empty_list_of_types);
|
||||
|
||||
Poco::SharedPtr<DB::DataTypes> result_types = new DB::DataTypes;
|
||||
boost::assign::push_back(*result_types)
|
||||
|
@ -16,7 +16,7 @@ using Poco::SharedPtr;
|
||||
void DataTypeAggregateFunction::serializeBinary(const Field & field, WriteBuffer & ostr) const
|
||||
{
|
||||
const AggregateFunctionPtr & value = boost::get<const AggregateFunctionPtr &>(field);
|
||||
writeStringBinary(value->getName(), ostr);
|
||||
writeStringBinary(value->getTypeID(), ostr);
|
||||
value->serialize(ostr);
|
||||
}
|
||||
|
||||
@ -24,7 +24,7 @@ void DataTypeAggregateFunction::deserializeBinary(Field & field, ReadBuffer & is
|
||||
{
|
||||
String name;
|
||||
readStringBinary(name, istr);
|
||||
AggregateFunctionPtr value = factory->get(name);
|
||||
AggregateFunctionPtr value = factory->getByTypeID(name);
|
||||
value->deserializeMerge(istr);
|
||||
field = value;
|
||||
}
|
||||
@ -36,7 +36,7 @@ void DataTypeAggregateFunction::serializeBinary(const IColumn & column, WriteBuf
|
||||
|
||||
String name;
|
||||
if (!vec.empty())
|
||||
name = vec[0]->getName();
|
||||
name = vec[0]->getTypeID();
|
||||
|
||||
for (ColumnAggregateFunction::Container_t::const_iterator it = vec.begin(); it != vec.end(); ++it)
|
||||
{
|
||||
@ -59,7 +59,7 @@ void DataTypeAggregateFunction::deserializeBinary(IColumn & column, ReadBuffer &
|
||||
|
||||
String name;
|
||||
readStringBinary(name, istr);
|
||||
AggregateFunctionPtr value = factory->get(name);
|
||||
AggregateFunctionPtr value = factory->getByTypeID(name);
|
||||
value->deserializeMerge(istr);
|
||||
vec.push_back(value);
|
||||
}
|
||||
|
@ -28,12 +28,6 @@ AggregatedData Aggregator::execute(BlockInputStreamPtr stream)
|
||||
typedef std::vector<Row> Rows;
|
||||
Rows aggregate_arguments(aggregates_size);
|
||||
|
||||
for (size_t i = 0; i < aggregates_size; ++i)
|
||||
{
|
||||
aggregate_arguments[i].resize(aggregates[i].arguments.size());
|
||||
aggregate_columns[i].resize(aggregates[i].arguments.size());
|
||||
}
|
||||
|
||||
/// Читаем все данные
|
||||
while (Block block = stream->read())
|
||||
{
|
||||
@ -46,6 +40,12 @@ AggregatedData Aggregator::execute(BlockInputStreamPtr stream)
|
||||
if (it->arguments.empty() && !it->argument_names.empty())
|
||||
for (Names::const_iterator jt = it->argument_names.begin(); jt != it->argument_names.end(); ++jt)
|
||||
it->arguments.push_back(block.getPositionByName(*jt));
|
||||
|
||||
for (size_t i = 0; i < aggregates_size; ++i)
|
||||
{
|
||||
aggregate_arguments[i].resize(aggregates[i].arguments.size());
|
||||
aggregate_columns[i].resize(aggregates[i].arguments.size());
|
||||
}
|
||||
|
||||
/// Запоминаем столбцы, с которыми будем работать
|
||||
for (size_t i = 0, size = keys_size; i < size; ++i)
|
||||
|
@ -16,6 +16,11 @@ namespace DB
|
||||
|
||||
void Expression::addSemantic(ASTPtr & ast)
|
||||
{
|
||||
/// Обход в глубину
|
||||
|
||||
for (ASTs::iterator it = ast->children.begin(); it != ast->children.end(); ++it)
|
||||
addSemantic(*it);
|
||||
|
||||
if (dynamic_cast<ASTAsterisk *>(&*ast))
|
||||
{
|
||||
ASTExpressionList * all_columns = new ASTExpressionList(ast->range);
|
||||
@ -26,11 +31,35 @@ void Expression::addSemantic(ASTPtr & ast)
|
||||
else if (ASTFunction * node = dynamic_cast<ASTFunction *>(&*ast))
|
||||
{
|
||||
Functions::const_iterator it = context.functions->find(node->name);
|
||||
node->aggregate_function = context.aggregate_function_factory->tryGet(node->name);
|
||||
|
||||
/// Типы аргументов
|
||||
DataTypes argument_types;
|
||||
ASTs & arguments = dynamic_cast<ASTExpressionList &>(*node->arguments).children;
|
||||
|
||||
for (ASTs::iterator it = arguments.begin(); it != arguments.end(); ++it)
|
||||
{
|
||||
if (ASTFunction * arg = dynamic_cast<ASTFunction *>(&**it))
|
||||
argument_types.push_back(arg->return_type);
|
||||
else if (ASTIdentifier * arg = dynamic_cast<ASTIdentifier *>(&**it))
|
||||
argument_types.push_back(arg->type);
|
||||
else if (ASTLiteral * arg = dynamic_cast<ASTLiteral *>(&**it))
|
||||
argument_types.push_back(arg->type);
|
||||
}
|
||||
|
||||
node->aggregate_function = context.aggregate_function_factory->tryGet(node->name, argument_types);
|
||||
if (it == context.functions->end() && node->aggregate_function.isNull())
|
||||
throw Exception("Unknown function " + node->name, ErrorCodes::UNKNOWN_FUNCTION);
|
||||
if (it != context.functions->end())
|
||||
node->function = it->second;
|
||||
|
||||
/// Получаем типы результата
|
||||
if (node->aggregate_function)
|
||||
{
|
||||
node->aggregate_function->setArguments(argument_types);
|
||||
node->return_type = node->aggregate_function->getReturnType();
|
||||
}
|
||||
else
|
||||
node->return_type = node->function->getReturnType(argument_types);
|
||||
}
|
||||
else if (ASTIdentifier * node = dynamic_cast<ASTIdentifier *>(&*ast))
|
||||
{
|
||||
@ -48,44 +77,6 @@ void Expression::addSemantic(ASTPtr & ast)
|
||||
{
|
||||
node->type = boost::apply_visitor(FieldToDataType(), node->value);
|
||||
}
|
||||
|
||||
for (ASTs::iterator it = ast->children.begin(); it != ast->children.end(); ++it)
|
||||
addSemantic(*it);
|
||||
}
|
||||
|
||||
|
||||
void Expression::checkTypes(ASTPtr ast)
|
||||
{
|
||||
/// Обход в глубину
|
||||
|
||||
for (ASTs::iterator it = ast->children.begin(); it != ast->children.end(); ++it)
|
||||
checkTypes(*it);
|
||||
|
||||
if (ASTFunction * node = dynamic_cast<ASTFunction *>(&*ast))
|
||||
{
|
||||
/// Типы аргументов
|
||||
DataTypes argument_types;
|
||||
ASTs & arguments = dynamic_cast<ASTExpressionList &>(*node->arguments).children;
|
||||
|
||||
for (ASTs::iterator it = arguments.begin(); it != arguments.end(); ++it)
|
||||
{
|
||||
if (ASTFunction * arg = dynamic_cast<ASTFunction *>(&**it))
|
||||
argument_types.push_back(arg->return_type);
|
||||
else if (ASTIdentifier * arg = dynamic_cast<ASTIdentifier *>(&**it))
|
||||
argument_types.push_back(arg->type);
|
||||
else if (ASTLiteral * arg = dynamic_cast<ASTLiteral *>(&**it))
|
||||
argument_types.push_back(arg->type);
|
||||
}
|
||||
|
||||
/// Получаем типы результата
|
||||
if (node->aggregate_function)
|
||||
{
|
||||
node->aggregate_function->setArguments(argument_types);
|
||||
node->return_type = node->aggregate_function->getReturnType();
|
||||
}
|
||||
else
|
||||
node->return_type = node->function->getReturnType(argument_types);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -92,7 +92,9 @@ int main(int argc, char ** argv)
|
||||
DB::AggregateFunctionFactory factory;
|
||||
|
||||
DB::AggregateDescriptions aggregate_descriptions(1);
|
||||
aggregate_descriptions[0].function = factory.get("count");
|
||||
|
||||
DB::DataTypes empty_list_of_types;
|
||||
aggregate_descriptions[0].function = factory.get("count", empty_list_of_types);
|
||||
|
||||
DB::Aggregator aggregator(key_column_numbers, aggregate_descriptions);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user