dbms: development [#CONV-2944].

This commit is contained in:
Alexey Milovidov 2011-09-25 05:07:47 +00:00
parent b979162b00
commit 13d0b57ca3
13 changed files with 217 additions and 80 deletions

View File

@ -21,6 +21,7 @@ public:
AggregateFunctionCount() : count(0) {}
String getName() const { return "count"; }
String getTypeID() const { return "count"; }
AggregateFunctionPtr cloneEmpty() const
{

View File

@ -15,12 +15,9 @@ class AggregateFunctionFactory
{
public:
AggregateFunctionFactory();
AggregateFunctionPtr get(const String & name) const;
AggregateFunctionPtr tryGet(const String & name) const;
private:
typedef std::map<String, AggregateFunctionPtr> NonParametricAggregateFunctions;
NonParametricAggregateFunctions non_parametric_aggregate_functions;
AggregateFunctionPtr get(const String & name, const DataTypes & argument_types) const;
AggregateFunctionPtr tryGet(const String & name, const DataTypes & argument_types) const;
AggregateFunctionPtr getByTypeID(const String & type_id) const;
};
using Poco::SharedPtr;

View File

@ -0,0 +1,97 @@
#pragma once
#include <DB/IO/WriteHelpers.h>
#include <DB/IO/ReadHelpers.h>
#include <DB/DataTypes/DataTypesNumberVariable.h>
#include <DB/DataTypes/DataTypesNumberFixed.h>
#include <DB/AggregateFunctions/IUnaryAggregateFunction.h>
namespace DB
{
template <typename T> struct AggregateFunctionSumTraits;
template <> struct AggregateFunctionSumTraits<UInt64>
{
static DataTypePtr getReturnType() { return new DataTypeVarUInt; }
static void write(UInt64 x, WriteBuffer & buf) { writeVarUInt(x, buf); }
static void read(UInt64 & x, ReadBuffer & buf) { readVarUInt(x, buf); }
};
template <> struct AggregateFunctionSumTraits<Int64>
{
static DataTypePtr getReturnType() { return new DataTypeVarInt; }
static void write(Int64 x, WriteBuffer & buf) { writeVarInt(x, buf); }
static void read(Int64 & x, ReadBuffer & buf) { readVarInt(x, buf); }
};
template <> struct AggregateFunctionSumTraits<Float64>
{
static DataTypePtr getReturnType() { return new DataTypeFloat64; }
static void write(Float64 x, WriteBuffer & buf) { writeFloatBinary(x, buf); }
static void read(Float64 & x, ReadBuffer & buf) { readFloatBinary(x, buf); }
};
/// Считает сумму чисел. Параметром шаблона может быть UInt64, Int64 или Float64.
template <typename T>
class AggregateFunctionSum : public IUnaryAggregateFunction
{
private:
T sum;
public:
AggregateFunctionSum() : sum(0) {}
String getName() const { return "sum"; }
String getTypeID() const { return "sum_" + TypeName<T>::get(); }
AggregateFunctionPtr cloneEmpty() const
{
return new AggregateFunctionSum<T>;
}
DataTypePtr getReturnType() const
{
return AggregateFunctionSumTraits<T>::getReturnType();
}
void setArgument(const DataTypePtr & argument)
{
if (!argument->isNumeric())
throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
void addOne(const Field & value)
{
sum += boost::get<T>(value);
}
void merge(const IAggregateFunction & rhs)
{
sum += static_cast<const AggregateFunctionSum<T> &>(rhs).sum;
}
void serialize(WriteBuffer & buf) const
{
AggregateFunctionSumTraits<T>::write(sum, buf);
}
void deserializeMerge(ReadBuffer & buf)
{
T tmp;
AggregateFunctionSumTraits<T>::read(tmp, buf);
sum += tmp;
}
Field getResult() const
{
return sum;
}
};
}

View File

@ -18,6 +18,9 @@ public:
/// Получить основное имя функции.
virtual String getName() const = 0;
/// Получить строку, по которой можно потом будет создать объект того же типа (с помощью AggregateFunctionFactory)
virtual String getTypeID() const = 0;
/// Создать новую агрегатную функцию того же типа.
virtual SharedPtr<IAggregateFunction> cloneEmpty() const = 0;

View File

@ -19,6 +19,8 @@ public:
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
virtual void setArgument(const DataTypePtr & argument) = 0;
/// Добавить значение.
void add(const Row & row)
{

View File

@ -21,7 +21,6 @@ public:
Expression(ASTPtr ast_, const Context & context_) : ast(ast_), context(context_)
{
addSemantic(ast);
checkTypes(ast);
glueTree(ast);
}
@ -72,15 +71,12 @@ private:
/** Для узлов - звёздочек - раскрыть их в список всех столбцов.
* Для узлов - литералов - прописать их типы данных.
* Для узлов - функций - прописать ссылки на функции и заменить имена на канонические.
* Для узлов - функций - прописать ссылки на функции, заменить имена на канонические, прописать и проверить типы.
* Для узлов - идентификаторов - прописать ссылки на их типы.
* Проверить, что все функции применимы для типов их аргументов.
*/
void addSemantic(ASTPtr & ast);
/** Проверить, что все функции применимы.
*/
void checkTypes(ASTPtr ast);
/** Склеить одинаковые узлы в синтаксическом дереве (превращая его в направленный ациклический граф).
* Это означает, в том числе то, что функции с одними и теми же аргументами, будут выполняться только один раз.
* Например, выражение rand(), rand() вернёт два идентичных столбца.

View File

@ -1,6 +1,5 @@
#include <boost/assign/list_inserter.hpp>
#include <DB/AggregateFunctions/AggregateFunctionCount.h>
#include <DB/AggregateFunctions/AggregateFunctionSum.h>
#include <DB/AggregateFunctions/AggregateFunctionFactory.h>
@ -11,28 +10,75 @@ namespace DB
AggregateFunctionFactory::AggregateFunctionFactory()
{
boost::assign::insert(non_parametric_aggregate_functions)
("count", new AggregateFunctionCount)
;
}
AggregateFunctionPtr AggregateFunctionFactory::get(const String & name) const
AggregateFunctionPtr AggregateFunctionFactory::get(const String & name, const DataTypes & argument_types) const
{
NonParametricAggregateFunctions::const_iterator it = non_parametric_aggregate_functions.find(name);
if (it != non_parametric_aggregate_functions.end())
return it->second->cloneEmpty();
if (name == "count")
{
return new AggregateFunctionCount;
}
else if (name == "sum")
{
if (argument_types.size() != 1)
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
String argument_type_name = argument_types[0]->getName();
if (argument_type_name == "UInt8" || argument_type_name == "UInt16"
|| argument_type_name == "UInt32" || argument_type_name == "UInt64"
|| argument_type_name == "VarUInt")
return new AggregateFunctionSum<UInt64>;
else if (argument_type_name == "Int8" || argument_type_name == "Int16"
|| argument_type_name == "Int32" || argument_type_name == "Int64"
|| argument_type_name == "VarInt")
return new AggregateFunctionSum<Int64>;
else if (argument_type_name == "Float32" || argument_type_name == "Float64")
return new AggregateFunctionSum<Float64>;
else
throw Exception("Illegal type " + argument_type_name + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
else
throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
}
AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name) const
AggregateFunctionPtr AggregateFunctionFactory::getByTypeID(const String & type_id) const
{
NonParametricAggregateFunctions::const_iterator it = non_parametric_aggregate_functions.find(name);
if (it != non_parametric_aggregate_functions.end())
return it->second->cloneEmpty();
return NULL;
if (type_id == "count")
return new AggregateFunctionCount;
else if (0 == type_id.compare(0, strlen("sum_"), "sum_"))
{
if (0 == type_id.compare(strlen("sum_"), strlen("UInt64"), "UInt64"))
return new AggregateFunctionSum<UInt64>;
else if (0 == type_id.compare(strlen("sum_"), strlen("Int64"), "Int64"))
return new AggregateFunctionSum<Int64>;
else if (0 == type_id.compare(strlen("sum_"), strlen("Float64"), "Float64"))
return new AggregateFunctionSum<Float64>;
else
throw Exception("Unknown type id of aggregate function " + type_id, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
}
else
throw Exception("Unknown type id of aggregate function " + type_id, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
}
AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types) const
{
AggregateFunctionPtr res;
try
{
return get(name, argument_types);
}
catch (const DB::Exception & e)
{
if (e.code() == ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION)
return NULL;
else
throw;
}
}

View File

@ -44,12 +44,12 @@ Block IProfilingBlockInputStream::read()
if (res)
info.update(res);
if (res)
/* if (res)
{
std::cerr << std::endl;
std::cerr << getName() << std::endl;
getInfo().print(std::cerr);
}
}*/
return res;
}

View File

@ -96,7 +96,9 @@ int main(int argc, char ** argv)
DB::AggregateFunctionFactory factory;
DB::AggregateDescriptions aggregate_descriptions(1);
aggregate_descriptions[0].function = factory.get("count");
DB::DataTypes empty_list_of_types;
aggregate_descriptions[0].function = factory.get("count", empty_list_of_types);
Poco::SharedPtr<DB::DataTypes> result_types = new DB::DataTypes;
boost::assign::push_back(*result_types)

View File

@ -16,7 +16,7 @@ using Poco::SharedPtr;
void DataTypeAggregateFunction::serializeBinary(const Field & field, WriteBuffer & ostr) const
{
const AggregateFunctionPtr & value = boost::get<const AggregateFunctionPtr &>(field);
writeStringBinary(value->getName(), ostr);
writeStringBinary(value->getTypeID(), ostr);
value->serialize(ostr);
}
@ -24,7 +24,7 @@ void DataTypeAggregateFunction::deserializeBinary(Field & field, ReadBuffer & is
{
String name;
readStringBinary(name, istr);
AggregateFunctionPtr value = factory->get(name);
AggregateFunctionPtr value = factory->getByTypeID(name);
value->deserializeMerge(istr);
field = value;
}
@ -36,7 +36,7 @@ void DataTypeAggregateFunction::serializeBinary(const IColumn & column, WriteBuf
String name;
if (!vec.empty())
name = vec[0]->getName();
name = vec[0]->getTypeID();
for (ColumnAggregateFunction::Container_t::const_iterator it = vec.begin(); it != vec.end(); ++it)
{
@ -59,7 +59,7 @@ void DataTypeAggregateFunction::deserializeBinary(IColumn & column, ReadBuffer &
String name;
readStringBinary(name, istr);
AggregateFunctionPtr value = factory->get(name);
AggregateFunctionPtr value = factory->getByTypeID(name);
value->deserializeMerge(istr);
vec.push_back(value);
}

View File

@ -28,12 +28,6 @@ AggregatedData Aggregator::execute(BlockInputStreamPtr stream)
typedef std::vector<Row> Rows;
Rows aggregate_arguments(aggregates_size);
for (size_t i = 0; i < aggregates_size; ++i)
{
aggregate_arguments[i].resize(aggregates[i].arguments.size());
aggregate_columns[i].resize(aggregates[i].arguments.size());
}
/// Читаем все данные
while (Block block = stream->read())
{
@ -46,6 +40,12 @@ AggregatedData Aggregator::execute(BlockInputStreamPtr stream)
if (it->arguments.empty() && !it->argument_names.empty())
for (Names::const_iterator jt = it->argument_names.begin(); jt != it->argument_names.end(); ++jt)
it->arguments.push_back(block.getPositionByName(*jt));
for (size_t i = 0; i < aggregates_size; ++i)
{
aggregate_arguments[i].resize(aggregates[i].arguments.size());
aggregate_columns[i].resize(aggregates[i].arguments.size());
}
/// Запоминаем столбцы, с которыми будем работать
for (size_t i = 0, size = keys_size; i < size; ++i)

View File

@ -16,6 +16,11 @@ namespace DB
void Expression::addSemantic(ASTPtr & ast)
{
/// Обход в глубину
for (ASTs::iterator it = ast->children.begin(); it != ast->children.end(); ++it)
addSemantic(*it);
if (dynamic_cast<ASTAsterisk *>(&*ast))
{
ASTExpressionList * all_columns = new ASTExpressionList(ast->range);
@ -26,11 +31,35 @@ void Expression::addSemantic(ASTPtr & ast)
else if (ASTFunction * node = dynamic_cast<ASTFunction *>(&*ast))
{
Functions::const_iterator it = context.functions->find(node->name);
node->aggregate_function = context.aggregate_function_factory->tryGet(node->name);
/// Типы аргументов
DataTypes argument_types;
ASTs & arguments = dynamic_cast<ASTExpressionList &>(*node->arguments).children;
for (ASTs::iterator it = arguments.begin(); it != arguments.end(); ++it)
{
if (ASTFunction * arg = dynamic_cast<ASTFunction *>(&**it))
argument_types.push_back(arg->return_type);
else if (ASTIdentifier * arg = dynamic_cast<ASTIdentifier *>(&**it))
argument_types.push_back(arg->type);
else if (ASTLiteral * arg = dynamic_cast<ASTLiteral *>(&**it))
argument_types.push_back(arg->type);
}
node->aggregate_function = context.aggregate_function_factory->tryGet(node->name, argument_types);
if (it == context.functions->end() && node->aggregate_function.isNull())
throw Exception("Unknown function " + node->name, ErrorCodes::UNKNOWN_FUNCTION);
if (it != context.functions->end())
node->function = it->second;
/// Получаем типы результата
if (node->aggregate_function)
{
node->aggregate_function->setArguments(argument_types);
node->return_type = node->aggregate_function->getReturnType();
}
else
node->return_type = node->function->getReturnType(argument_types);
}
else if (ASTIdentifier * node = dynamic_cast<ASTIdentifier *>(&*ast))
{
@ -48,44 +77,6 @@ void Expression::addSemantic(ASTPtr & ast)
{
node->type = boost::apply_visitor(FieldToDataType(), node->value);
}
for (ASTs::iterator it = ast->children.begin(); it != ast->children.end(); ++it)
addSemantic(*it);
}
void Expression::checkTypes(ASTPtr ast)
{
/// Обход в глубину
for (ASTs::iterator it = ast->children.begin(); it != ast->children.end(); ++it)
checkTypes(*it);
if (ASTFunction * node = dynamic_cast<ASTFunction *>(&*ast))
{
/// Типы аргументов
DataTypes argument_types;
ASTs & arguments = dynamic_cast<ASTExpressionList &>(*node->arguments).children;
for (ASTs::iterator it = arguments.begin(); it != arguments.end(); ++it)
{
if (ASTFunction * arg = dynamic_cast<ASTFunction *>(&**it))
argument_types.push_back(arg->return_type);
else if (ASTIdentifier * arg = dynamic_cast<ASTIdentifier *>(&**it))
argument_types.push_back(arg->type);
else if (ASTLiteral * arg = dynamic_cast<ASTLiteral *>(&**it))
argument_types.push_back(arg->type);
}
/// Получаем типы результата
if (node->aggregate_function)
{
node->aggregate_function->setArguments(argument_types);
node->return_type = node->aggregate_function->getReturnType();
}
else
node->return_type = node->function->getReturnType(argument_types);
}
}

View File

@ -92,7 +92,9 @@ int main(int argc, char ** argv)
DB::AggregateFunctionFactory factory;
DB::AggregateDescriptions aggregate_descriptions(1);
aggregate_descriptions[0].function = factory.get("count");
DB::DataTypes empty_list_of_types;
aggregate_descriptions[0].function = factory.get("count", empty_list_of_types);
DB::Aggregator aggregator(key_column_numbers, aggregate_descriptions);