diff --git a/dbms/include/DB/Columns/ColumnAggregateFunction.h b/dbms/include/DB/Columns/ColumnAggregateFunction.h index 683f09e65f4..be3ba5d8ed3 100644 --- a/dbms/include/DB/Columns/ColumnAggregateFunction.h +++ b/dbms/include/DB/Columns/ColumnAggregateFunction.h @@ -42,16 +42,13 @@ public: void insert(const Field & x) { - data.push_back(boost::get(x)); + data.push_back(boost::get(x)); } int compareAt(size_t n, size_t m, const IColumn & rhs_) const { return 0; } - -private: - Container_t data; }; diff --git a/dbms/include/DB/Columns/ColumnVector.h b/dbms/include/DB/Columns/ColumnVector.h index 2db73d88ceb..b7ff5df2555 100644 --- a/dbms/include/DB/Columns/ColumnVector.h +++ b/dbms/include/DB/Columns/ColumnVector.h @@ -135,7 +135,7 @@ public: return data; } -private: +protected: Container_t data; }; diff --git a/dbms/include/DB/Core/Block.h b/dbms/include/DB/Core/Block.h index efe9cdf4d4e..a9ee07e2b1b 100644 --- a/dbms/include/DB/Core/Block.h +++ b/dbms/include/DB/Core/Block.h @@ -67,6 +67,9 @@ public: /** Получить список имён столбцов через запятую. */ std::string dumpNames() const; + + /** Get the same block, but empty.
*/ + Block cloneEmpty() const; }; } diff --git a/dbms/include/DB/Core/ColumnWithNameAndType.h b/dbms/include/DB/Core/ColumnWithNameAndType.h index af56cda6614..efcb20d9e75 100644 --- a/dbms/include/DB/Core/ColumnWithNameAndType.h +++ b/dbms/include/DB/Core/ColumnWithNameAndType.h @@ -1,5 +1,4 @@ -#ifndef DBMS_CORE_COLUMN_WITH_NAME_AND_TYPE_H -#define DBMS_CORE_COLUMN_WITH_NAME_AND_TYPE_H +#pragma once #include @@ -20,8 +19,17 @@ struct ColumnWithNameAndType ColumnPtr column; DataTypePtr type; String name; + + ColumnWithNameAndType cloneEmpty() const /* Returns a copy whose name and type are assigned as-is and whose column is column->cloneEmpty(). */ + { + ColumnWithNameAndType res; + + res.name = name; + res.type = type; + res.column = column->cloneEmpty(); + + return res; + } }; } - -#endif diff --git a/dbms/include/DB/DataStreams/AggregatingBlockInputStream.h b/dbms/include/DB/DataStreams/AggregatingBlockInputStream.h new file mode 100644 index 00000000000..dc13c37763e --- /dev/null +++ b/dbms/include/DB/DataStreams/AggregatingBlockInputStream.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include + + +namespace DB +{ + +using Poco::SharedPtr; + + +/** Aggregates a stream of blocks using the given key columns and aggregate functions. + * The columns holding aggregate functions are appended to the end of the block. + * The aggregate functions are not finalized, i.e. they are not replaced by their value but keep the intermediate state of the computation. + * This is needed so that aggregation can be continued (for example, when merging streams of partially aggregated data).
+ */ +class AggregatingBlockInputStream : public IProfilingBlockInputStream +{ +public: + AggregatingBlockInputStream(BlockInputStreamPtr input_, const ColumnNumbers & keys_, AggregateDescriptions & aggregates_) + : input(input_), keys(keys_), aggregates(aggregates_), aggregator(keys_, aggregates_), has_been_read(false) + { + children.push_back(input); + } + + Block readImpl(); + + String getName() const { return "AggregatingBlockInputStream"; } + +private: + BlockInputStreamPtr input; + const ColumnNumbers keys; + AggregateDescriptions aggregates; + Aggregator aggregator; + bool has_been_read; +}; + +} diff --git a/dbms/include/DB/DataStreams/BlockInputStreamFromRowInputStream.h b/dbms/include/DB/DataStreams/BlockInputStreamFromRowInputStream.h index d3a184e828d..03dcbc173cd 100644 --- a/dbms/include/DB/DataStreams/BlockInputStreamFromRowInputStream.h +++ b/dbms/include/DB/DataStreams/BlockInputStreamFromRowInputStream.h @@ -30,12 +30,9 @@ public: String getName() const { return "BlockInputStreamFromRowInputStream"; } private: - IRowInputStream & row_input; const Block & sample; size_t max_block_size; - - void initBlock(Block & res); }; } diff --git a/dbms/include/DB/DataStreams/FinalizingAggregatedBlockInputStream.h b/dbms/include/DB/DataStreams/FinalizingAggregatedBlockInputStream.h new file mode 100644 index 00000000000..2165e1cc659 --- /dev/null +++ b/dbms/include/DB/DataStreams/FinalizingAggregatedBlockInputStream.h @@ -0,0 +1,59 @@ +#pragma once + +#include + +#include + + +namespace DB +{ + +using Poco::SharedPtr; + + +/** Converts aggregate functions (holding intermediate state) in a stream of blocks into their final values.
+ */ +class FinalizingAggregatedBlockInputStream : public IProfilingBlockInputStream +{ +public: + FinalizingAggregatedBlockInputStream(BlockInputStreamPtr input_) + : input(input_) + { + children.push_back(input); + } + + Block readImpl() /* Reads one block from input and replaces each ColumnAggregateFunction column with a column of the functions' getResult() values; other columns are left as they are. */ + { + Block res = input->read(); + + if (!res) + return res; + + size_t rows = res.rows(); + size_t columns = res.columns(); + for (size_t i = 0; i < columns; ++i) + { + ColumnWithNameAndType & column = res.getByPosition(i); + if (ColumnAggregateFunction * col = dynamic_cast(&*column.column)) + { + ColumnAggregateFunction::Container_t & data = col->getData(); + column.type = data[0]->getReturnType(); /* NOTE(review): reads data[0] without checking rows; presumably unsafe for a block that has columns but zero rows - confirm */ + ColumnPtr finalized_column = column.type->createColumn(); + + for (size_t j = 0; j < rows; ++j) + finalized_column->insert(data[j]->getResult()); + + column.column = finalized_column; + } + } + + return res; + } + + String getName() const { return "FinalizingAggregatedBlockInputStream"; } + +private: + BlockInputStreamPtr input; +}; + +} diff --git a/dbms/include/DB/Interpreters/Aggregate.h b/dbms/include/DB/Interpreters/Aggregator.h similarity index 61% rename from dbms/include/DB/Interpreters/Aggregate.h rename to dbms/include/DB/Interpreters/Aggregator.h index 057bf3c2368..949b03eea50 100644 --- a/dbms/include/DB/Interpreters/Aggregate.h +++ b/dbms/include/DB/Interpreters/Aggregator.h @@ -22,16 +22,21 @@ typedef std::map AggregatedData; /** Агрегирует поток блоков. */ -class Aggregate +class Aggregator { public: - Aggregate(const ColumnNumbers & keys_, AggregateDescriptions & aggregates_) : keys(keys_), aggregates(aggregates_) {}; + Aggregator(const ColumnNumbers & keys_, AggregateDescriptions & aggregates_) : keys(keys_), aggregates(aggregates_) {}; AggregatedData execute(BlockInputStreamPtr stream); + /// Get a sample block describing the result. Should only be called after execute.
+ Block getSampleBlock() { return sample; } + private: ColumnNumbers keys; AggregateDescriptions aggregates; + + Block sample; }; diff --git a/dbms/src/Core/Block.cpp b/dbms/src/Core/Block.cpp index 149521c0e91..2fee00395ee 100644 --- a/dbms/src/Core/Block.cpp +++ b/dbms/src/Core/Block.cpp @@ -181,4 +181,15 @@ std::string Block::dumpNames() const } +Block Block::cloneEmpty() const /* Builds a new Block by inserting cloneEmpty() of every element of data, in iteration order. */ +{ + Block res; + + for (Container_t::const_iterator it = data.begin(); it != data.end(); ++it) + res.insert(it->cloneEmpty()); + + return res; +} + + } diff --git a/dbms/src/DataStreams/AggregatingBlockInputStream.cpp b/dbms/src/DataStreams/AggregatingBlockInputStream.cpp new file mode 100644 index 00000000000..ada83164327 --- /dev/null +++ b/dbms/src/DataStreams/AggregatingBlockInputStream.cpp @@ -0,0 +1,32 @@ +#include + + +namespace DB +{ + + +Block AggregatingBlockInputStream::readImpl() /* Runs the aggregation on the first call and returns it as one block; every later call returns an empty Block. */ +{ + if (has_been_read) + return Block(); + + has_been_read = true; + + AggregatedData data = aggregator.execute(input); + Block res = aggregator.getSampleBlock(); /* NOTE(review): getSampleBlock() returns sample by value; if copying a Block shares the column pointers, the inserts below would also fill the aggregator's sample - confirm */ + + for (AggregatedData::const_iterator it = data.begin(); it != data.end(); ++it) + { + size_t i = 0; /* column position in res: key values are inserted first, the aggregate functions continue from there */ + for (Row::const_iterator jt = it->first.begin(); jt != it->first.end(); ++jt, ++i) + res.getByPosition(i).column->insert(*jt); + + for (AggregateFunctions::const_iterator jt = it->second.begin(); jt != it->second.end(); ++jt, ++i) + res.getByPosition(i).column->insert(*jt); + } + + return res; +} + + +} diff --git a/dbms/src/DataStreams/BlockInputStreamFromRowInputStream.cpp b/dbms/src/DataStreams/BlockInputStreamFromRowInputStream.cpp index 5c1441d1d47..b27d43d3c95 100644 --- a/dbms/src/DataStreams/BlockInputStreamFromRowInputStream.cpp +++ b/dbms/src/DataStreams/BlockInputStreamFromRowInputStream.cpp @@ -19,22 +19,6 @@ BlockInputStreamFromRowInputStream::BlockInputStreamFromRowInputStream( } -void BlockInputStreamFromRowInputStream::initBlock(Block & res) -{ - for (size_t i = 0; i < sample.columns(); ++i) - { - const
ColumnWithNameAndType & sample_elem = sample.getByPosition(i); - ColumnWithNameAndType res_elem; - - res_elem.column = sample_elem.column->cloneEmpty(); - res_elem.type = sample_elem.type->clone(); - res_elem.name = sample_elem.name; - - res.insert(res_elem); - } -} - - Block BlockInputStreamFromRowInputStream::readImpl() { Block res; @@ -47,7 +31,7 @@ Block BlockInputStreamFromRowInputStream::readImpl() return res; if (!res) - initBlock(res); + res = sample.cloneEmpty(); /* NOTE(review): the removed initBlock called type->clone(), while ColumnWithNameAndType::cloneEmpty assigns type directly - confirm that sharing the type object is intended */ if (row.size() != sample.columns()) throw Exception("Number of columns doesn't match", ErrorCodes::NUMBER_OF_COLUMNS_DOESNT_MATCH); diff --git a/dbms/src/DataStreams/tests/aggregating_stream.cpp b/dbms/src/DataStreams/tests/aggregating_stream.cpp new file mode 100644 index 00000000000..64129bf6710 --- /dev/null +++ b/dbms/src/DataStreams/tests/aggregating_stream.cpp @@ -0,0 +1,138 @@ +#include +#include + +#include + +#include + +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include + + +class OneBlockInputStream : public DB::IBlockInputStream /* Test helper: read() returns the wrapped block on the first call and an empty DB::Block afterwards. */ +{ +private: + const DB::Block & block; /* stored as a reference, so the caller's block must outlive this stream */ + bool has_been_read; +public: + OneBlockInputStream(const DB::Block & block_) : block(block_), has_been_read(false) {} + + DB::Block read() + { + if (!has_been_read) + { + has_been_read = true; + return block; + } + else + return DB::Block(); + } + + DB::String getName() const { return "OneBlockInputStream"; } +}; + + +int main(int argc, char ** argv) +{ + try + { + size_t n = argc == 2 ?
atoi(argv[1]) : 10; /* n = number of rows to generate; 10 unless exactly one argument is given */ + + DB::Block block; /* input block; columns x, s1 and s2 are inserted below */ + + DB::ColumnWithNameAndType column_x; + column_x.name = "x"; + column_x.type = new DB::DataTypeInt16; + DB::ColumnInt16 * x = new DB::ColumnInt16; + column_x.column = x; + std::vector & vec_x = x->getData(); + + vec_x.resize(n); + for (size_t i = 0; i < n; ++i) + vec_x[i] = i % 9; + + block.insert(column_x); + + const char * strings[] = {"abc", "def", "abcd", "defg", "ac"}; + + DB::ColumnWithNameAndType column_s1; + column_s1.name = "s1"; + column_s1.type = new DB::DataTypeString; + column_s1.column = new DB::ColumnString; + + for (size_t i = 0; i < n; ++i) + column_s1.column->insert(strings[i % 5]); + + block.insert(column_s1); + + DB::ColumnWithNameAndType column_s2; + column_s2.name = "s2"; + column_s2.type = new DB::DataTypeString; + column_s2.column = new DB::ColumnString; + + for (size_t i = 0; i < n; ++i) + column_s2.column->insert(strings[i % 3]); + + block.insert(column_s2); + + DB::ColumnNumbers key_column_numbers; + key_column_numbers.push_back(0); /* the only aggregation key is column number 0 */ + //key_column_numbers.push_back(1); + + DB::AggregateFunctionFactory factory; + + DB::AggregateDescriptions aggregate_descriptions(1); + aggregate_descriptions[0].function = factory.get("count"); + + Poco::SharedPtr result_types = new DB::DataTypes; + boost::assign::push_back(*result_types) + (new DB::DataTypeInt16) + // (new DB::DataTypeString) + (new DB::DataTypeVarUInt) + ; + + DB::BlockInputStreamPtr stream = new OneBlockInputStream(block); + stream = new DB::AggregatingBlockInputStream(stream, key_column_numbers, aggregate_descriptions); + stream = new DB::FinalizingAggregatedBlockInputStream(stream); + + DB::WriteBufferFromOStream ob(std::cout); + DB::TabSeparatedRowOutputStream out(ob, result_types); + + { + Poco::Stopwatch stopwatch; + stopwatch.start(); + + DB::copyData(*stream, out); + + stopwatch.stop(); + std::cout << std::fixed << std::setprecision(2) + << "Elapsed " << stopwatch.elapsed() / 1000000.0 << " sec."
+ << ", " << n * 1000000 / stopwatch.elapsed() << " rows/sec." /* NOTE(review): divides by stopwatch.elapsed(); presumably a division by zero if the measured time is 0 - confirm */ + << std::endl; + } + + std::cout << std::endl; + stream->dumpTree(std::cout); + std::cout << std::endl; + } + catch (const DB::Exception & e) + { + std::cerr << e.message() << std::endl; + } + + return 0; /* also reached after the catch block, so the exit code is 0 even when a DB::Exception was caught */ +} diff --git a/dbms/src/Interpreters/Aggregate.cpp b/dbms/src/Interpreters/Aggregator.cpp similarity index 71% rename from dbms/src/Interpreters/Aggregate.cpp rename to dbms/src/Interpreters/Aggregator.cpp index 69d0c052e21..1469734e29c 100644 --- a/dbms/src/Interpreters/Aggregate.cpp +++ b/dbms/src/Interpreters/Aggregator.cpp @@ -1,4 +1,7 @@ -#include +#include +#include + +#include namespace DB @@ -10,7 +13,7 @@ namespace DB * Без оптимизации по количеству ключей. * Результат хранится в оперативке и должен полностью помещаться в оперативку. */ -AggregatedData Aggregate::execute(BlockInputStreamPtr stream) +AggregatedData Aggregator::execute(BlockInputStreamPtr stream) { AggregatedData res; @@ -42,6 +45,31 @@ AggregatedData Aggregate::execute(BlockInputStreamPtr stream) for (size_t j = 0; j < aggregate_columns[i].size(); ++j) aggregate_columns[i][j] = block.getByPosition(aggregates[i].arguments[j]).column; + /// Build the sample block describing the result: empty clones of the key columns, then one column per aggregate function, named like name(arg1,arg2) + if (!sample) + { + for (size_t i = 0, size = keys_size; i < size; ++i) + sample.insert(block.getByPosition(keys[i]).cloneEmpty()); + + for (size_t i = 0; i < aggregates_size; ++i) + { + ColumnWithNameAndType col; + col.name = aggregates[i].function->getName() + "("; + for (size_t j = 0; j < aggregate_columns[i].size(); ++j) + { + if (j != 0) + col.name += ","; + col.name += block.getByPosition(aggregates[i].arguments[j]).name; + } + col.name += ")"; + + col.type = new DataTypeAggregateFunction; + col.column = new ColumnAggregateFunction; + + sample.insert(col); + } + } + size_t rows = block.rows(); /// Для всех строчек diff --git a/dbms/src/Interpreters/Expression.cpp b/dbms/src/Interpreters/Expression.cpp index 170ee448cea..d074952644a 100644
--- a/dbms/src/Interpreters/Expression.cpp +++ b/dbms/src/Interpreters/Expression.cpp @@ -87,7 +87,10 @@ void Expression::checkTypes(ASTPtr ast) /// Получаем типы результата if (node->aggregate_function) - node->return_types.push_back(node->aggregate_function->getReturnType(argument_types)); + { + node->aggregate_function->setArguments(argument_types); /* the argument types are passed via setArguments() first; getReturnType() is then called with no arguments */ + node->return_types.push_back(node->aggregate_function->getReturnType()); + } else node->return_types = node->function->getReturnTypes(argument_types); } diff --git a/dbms/src/Interpreters/tests/aggregate.cpp b/dbms/src/Interpreters/tests/aggregate.cpp index 2f07d435192..9106b7ab7d6 100644 --- a/dbms/src/Interpreters/tests/aggregate.cpp +++ b/dbms/src/Interpreters/tests/aggregate.cpp @@ -11,7 +11,7 @@ #include -#include +#include #include @@ -94,7 +94,7 @@ int main(int argc, char ** argv) DB::AggregateDescriptions aggregate_descriptions(1); aggregate_descriptions[0].function = factory.get("count"); - DB::Aggregate aggregator(key_column_numbers, aggregate_descriptions); + DB::Aggregator aggregator(key_column_numbers, aggregate_descriptions); { Poco::Stopwatch stopwatch;