Refactor AggregatingSortedAlgorithm.

This commit is contained in:
Nikolai Kochetov 2020-04-13 17:42:58 +03:00
parent 9ce0607de7
commit 377e16c00c
2 changed files with 264 additions and 258 deletions

View File

@ -1,6 +1,5 @@
#include <Processors/Merges/AggregatingSortedAlgorithm.h>
#include <Columns/ColumnAggregateFunction.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeCustomSimpleAggregateFunction.h>
#include <DataTypes/DataTypeLowCardinality.h>
@ -8,9 +7,71 @@
namespace DB
{
namespace
/// Stores information for aggregation of AggregateFunction columns.
/// Pairs the position of an AggregateFunction column with a raw pointer
/// to the column that currently accumulates the merged states.
struct AggregatingSortedAlgorithm::AggregateDescription
{
AggregatingSortedAlgorithm::ColumnsDefinition defineColumns(
/// Column the merged states are written to. Not owned; nullptr until it is
/// assigned by AggregatingMergedData::initAggregateDescription().
ColumnAggregateFunction * column = nullptr;
/// Index of this column; used to address both the input columns and the merged columns.
const size_t column_number = 0;
AggregateDescription() = default;
explicit AggregateDescription(size_t col_number) : column_number(col_number) {}
};
/// Stores information for aggregation of SimpleAggregateFunction columns.
/// Besides the column description it owns the buffer holding the aggregation
/// state of the current group and tracks whether that state is constructed.
struct AggregatingSortedAlgorithm::SimpleAggregateDescription
{
    /// An aggregate function 'anyLast', 'sum'...
    AggregateFunctionPtr function;
    /// Result of function->getAddressOfAddFunction(), cached by the constructor.
    IAggregateFunction::AddFunc add_function = nullptr;

    /// Index of this column; used to address both the input columns and the merged columns.
    size_t column_number = 0;
    /// Column the aggregation result is inserted into. Not owned; nullptr until
    /// it is assigned by AggregatingMergedData::initAggregateDescription().
    IColumn * column = nullptr;

    /// For LowCardinality, the column is converted to the nested type. nested_type is nullptr if no conversion is needed.
    const DataTypePtr nested_type; /// Nested type for LowCardinality, if it is.
    const DataTypePtr real_type; /// Type in header.

    /// Memory for the aggregation state; sized and aligned for `function` by the constructor.
    AlignedBuffer state;
    /// True while `state` holds a state constructed by createState() and not yet destroyed.
    bool created = false;

    SimpleAggregateDescription(
        AggregateFunctionPtr function_, const size_t column_number_,
        DataTypePtr nested_type_, DataTypePtr real_type_)
        : function(std::move(function_)), column_number(column_number_)
        , nested_type(std::move(nested_type_)), real_type(std::move(real_type_))
    {
        add_function = function->getAddressOfAddFunction();
        state.reset(function->sizeOfData(), function->alignOfData());
    }

    /// Constructs the aggregation state in `state`. Does nothing if it already exists.
    void createState()
    {
        if (created)
            return;

        function->create(state.data());
        created = true;
    }

    /// Destroys the aggregation state. Does nothing if it was not created.
    void destroyState()
    {
        if (!created)
            return;

        function->destroy(state.data());
        created = false;
    }

    /// Explicitly destroy aggregation state if the stream is terminated
    ~SimpleAggregateDescription()
    {
        destroyState();
    }

    SimpleAggregateDescription() = default;

    /// Transfers ownership of the aggregation state together with the `created` flag.
    /// A defaulted move would only copy the flag: the moved-from object would still
    /// claim a live state after giving up `function` and `state`, and its destructor
    /// would then try to destroy that state.
    SimpleAggregateDescription(SimpleAggregateDescription && other)
        : function(std::move(other.function))
        , add_function(other.add_function)
        , column_number(other.column_number)
        , column(other.column)
        , nested_type(other.nested_type) /// const members cannot be moved from, so they are copied.
        , real_type(other.real_type)
        , state(std::move(other.state))
        , created(other.created)
    {
        other.created = false;
    }

    SimpleAggregateDescription(const SimpleAggregateDescription &) = delete;
};
static AggregatingSortedAlgorithm::ColumnsDefinition defineColumns(
const Block & header, const SortDescription & description)
{
AggregatingSortedAlgorithm::ColumnsDefinition def = {};
@ -41,14 +102,14 @@ namespace
continue;
}
if (auto simple_aggr = dynamic_cast<const DataTypeCustomSimpleAggregateFunction *>(column.type->getCustomName()))
if (auto simple = dynamic_cast<const DataTypeCustomSimpleAggregateFunction *>(column.type->getCustomName()))
{
auto type = recursiveRemoveLowCardinality(column.type);
if (type.get() == column.type.get())
type = nullptr;
// simple aggregate function
AggregatingSortedAlgorithm::SimpleAggregateDescription desc(simple_aggr->getFunction(), i, type, column.type);
AggregatingSortedAlgorithm::SimpleAggregateDescription desc(simple->getFunction(), i, type, column.type);
if (desc.function->allocatesMemoryInArena())
def.allocates_memory_in_arena = true;
@ -64,7 +125,7 @@ namespace
return def;
}
MutableColumns getMergedColumns(const Block & header, const AggregatingSortedAlgorithm::ColumnsDefinition & def)
static MutableColumns getMergedColumns(const Block & header, const AggregatingSortedAlgorithm::ColumnsDefinition & def)
{
MutableColumns columns;
columns.resize(header.columns());
@ -81,8 +142,116 @@ namespace
return columns;
}
/// Prepares an input chunk for merging: every constant column is replaced with a full one,
/// and columns of simple aggregate functions that have a nested type lose their LowCardinality wrapper.
static void prepareChunk(Chunk & chunk, const AggregatingSortedAlgorithm::ColumnsDefinition & def)
{
    const auto rows_count = chunk.getNumRows();
    auto chunk_columns = chunk.detachColumns();

    for (auto & col : chunk_columns)
        col = col->convertToFullColumnIfConst();

    for (auto & simple : def.columns_to_simple_aggregate)
    {
        /// nested_type is set only for columns which need the conversion.
        if (!simple.nested_type)
            continue;

        auto & target = chunk_columns[simple.column_number];
        target = recursiveRemoveLowCardinality(target);
    }

    chunk.setColumns(std::move(chunk_columns), rows_count);
}
/// Passes the result columns to MergedData, keeps a reference to the columns definition
/// and points the aggregate descriptions at the freshly stored columns.
AggregatingSortedAlgorithm::AggregatingMergedData::AggregatingMergedData(
    MutableColumns columns_, UInt64 max_block_size_, ColumnsDefinition & def_)
    : MergedData(std::move(columns_), false, max_block_size_)
    , def(def_)
{
    /// Descriptions hold raw pointers into `columns`, so they are bound right away.
    initAggregateDescription();
}
/// Opens a new group using row `row` of `raw_columns` as its first row.
void AggregatingSortedAlgorithm::AggregatingMergedData::startGroup(const ColumnRawPtrs & raw_columns, size_t row)
{
    /// Ordinary (not aggregated) columns take their values from this row.
    for (const auto idx : def.column_numbers_not_to_aggregate)
    {
        auto & destination = columns[idx];
        destination->insertFrom(*raw_columns[idx], row);
    }

    /// Every AggregateFunction column gets a default value; `addRow` merges rows into it afterwards.
    for (auto & aggregate : def.columns_to_aggregate)
        aggregate.column->insertDefault();

    /// Make sure each simple aggregate function has a constructed state for the group.
    for (auto & simple : def.columns_to_simple_aggregate)
        simple.createState();

    /// The arena is replaced with a new one when any function allocates memory in it.
    if (def.allocates_memory_in_arena)
        arena = std::make_unique<Arena>();

    is_group_started = true;
}
/// Closes the current group: stores the results of simple aggregate functions
/// into their columns, destroys the states and counts the group as one merged row.
void AggregatingSortedAlgorithm::AggregatingMergedData::finishGroup()
{
    for (auto & simple : def.columns_to_simple_aggregate)
    {
        auto * place = simple.state.data();
        simple.function->insertResultInto(place, *simple.column);
        simple.destroyState();
    }

    ++merged_rows;
    ++total_merged_rows;
    is_group_started = false;
    /// TODO: sum_blocks_granularity += block_size;
}
/// Adds the row under `cursor` to the current group.
/// Throws LOGICAL_ERROR if no group has been started.
void AggregatingSortedAlgorithm::AggregatingMergedData::addRow(SortCursor & cursor)
{
    if (!is_group_started)
        throw Exception("Can't add a row to the group because it was not started.", ErrorCodes::LOGICAL_ERROR);

    auto & source_columns = cursor->all_columns;
    const auto row = cursor->pos;

    /// AggregateFunction columns: merge the state from the source row into the result column.
    for (auto & aggregate : def.columns_to_aggregate)
        aggregate.column->insertMergeFrom(*source_columns[aggregate.column_number], row);

    /// Simple aggregate functions: feed the source value into the state of the group.
    for (auto & simple : def.columns_to_simple_aggregate)
    {
        auto & source = source_columns[simple.column_number];
        simple.add_function(simple.function.get(), simple.state.data(), &source, row, arena.get());
    }
}
/// Returns the accumulated chunk, converting the columns of simple aggregate functions
/// from the nested type back to the type in header, and rebinds the aggregate descriptions.
/// Throws LOGICAL_ERROR if the current group is not finished.
Chunk AggregatingSortedAlgorithm::AggregatingMergedData::pull()
{
    if (is_group_started)
        throw Exception("Can't pull chunk because group was not finished.", ErrorCodes::LOGICAL_ERROR);

    auto result = MergedData::pull();
    const size_t rows_count = result.getNumRows();
    auto result_columns = result.detachColumns();

    for (auto & simple : def.columns_to_simple_aggregate)
    {
        /// Only columns with a nested type need the conversion.
        if (!simple.nested_type)
            continue;

        auto & converted = result_columns[simple.column_number];
        converted = recursiveTypeConversion(converted, simple.nested_type, simple.real_type);
    }

    result.setColumns(std::move(result_columns), rows_count);

    /// Point the descriptions at the columns held by this object after the pull.
    initAggregateDescription();

    return result;
}
/// Stores in every aggregate description a raw pointer to the corresponding column from `columns`.
void AggregatingSortedAlgorithm::AggregatingMergedData::initAggregateDescription()
{
    for (auto & aggregate : def.columns_to_aggregate)
    {
        auto * raw_column = columns[aggregate.column_number].get();
        aggregate.column = typeid_cast<ColumnAggregateFunction *>(raw_column);
    }

    for (auto & simple : def.columns_to_simple_aggregate)
        simple.column = columns[simple.column_number].get();
}
AggregatingSortedAlgorithm::AggregatingSortedAlgorithm(
const Block & header, size_t num_inputs,
SortDescription description_, size_t max_block_size)
@ -92,33 +261,18 @@ AggregatingSortedAlgorithm::AggregatingSortedAlgorithm(
{
}
/// Prepares an input chunk for merging: every constant column is replaced with a full one,
/// and columns of simple aggregate functions that have a nested type lose their LowCardinality wrapper.
void AggregatingSortedAlgorithm::prepareChunk(Chunk & chunk) const
{
    const auto rows_count = chunk.getNumRows();
    auto chunk_columns = chunk.detachColumns();

    for (auto & col : chunk_columns)
        col = col->convertToFullColumnIfConst();

    for (auto & simple : columns_definition.columns_to_simple_aggregate)
    {
        /// nested_type is set only for columns which need the conversion.
        if (!simple.nested_type)
            continue;

        auto & target = chunk_columns[simple.column_number];
        target = recursiveRemoveLowCardinality(target);
    }

    chunk.setColumns(std::move(chunk_columns), rows_count);
}
/// Prepares the initial chunks with prepareChunk and hands the whole set over to initializeQueue.
void AggregatingSortedAlgorithm::initialize(Chunks chunks)
{
for (auto & chunk : chunks)
/// Empty chunks are not prepared.
if (chunk)
prepareChunk(chunk);
prepareChunk(chunk, columns_definition);
initializeQueue(std::move(chunks));
}
/// Prepares a newly received chunk with prepareChunk and passes it,
/// together with `source_num`, to updateCursor.
void AggregatingSortedAlgorithm::consume(Chunk chunk, size_t source_num)
{
prepareChunk(chunk);
prepareChunk(chunk, columns_definition);
updateCursor(std::move(chunk), source_num);
}
@ -128,18 +282,13 @@ IMergingAlgorithm::Status AggregatingSortedAlgorithm::merge()
while (queue.isValid())
{
bool key_differs;
bool has_previous_group = !last_key.empty();
SortCursor current = queue.current();
{
detail::RowRef current_key;
current_key.set(current);
if (!has_previous_group) /// The first key encountered.
key_differs = true;
else
key_differs = !last_key.hasEqualSortColumnsWith(current_key);
key_differs = last_key.empty() || !last_key.hasEqualSortColumnsWith(current_key);
last_key = current_key;
last_chunk_sort_columns.clear();
@ -147,37 +296,20 @@ IMergingAlgorithm::Status AggregatingSortedAlgorithm::merge()
if (key_differs)
{
/// Write the simple aggregation result for the previous group.
if (merged_data.isGroupStarted())
{
insertSimpleAggregationResult();
merged_data.insertRow();
}
merged_data.finishGroup();
/// if there are enough rows accumulated and the last one is calculated completely
if (merged_data.hasEnoughRows())
{
last_key.reset();
Status(merged_data.pull(columns_definition));
Status(merged_data.pull());
}
/// We will write the data for the group. We copy the values of ordinary columns.
merged_data.initializeRow(current->all_columns, current->pos,
columns_definition.column_numbers_not_to_aggregate);
/// Add the empty aggregation state to the aggregate columns. The state will be updated in the `addRow` function.
for (auto & column_to_aggregate : columns_definition.columns_to_aggregate)
column_to_aggregate.column->insertDefault();
/// Reset simple aggregation states for next row
for (auto & desc : columns_definition.columns_to_simple_aggregate)
desc.createState();
if (columns_definition.allocates_memory_in_arena)
arena = std::make_unique<Arena>();
merged_data.startGroup(current->all_columns, current->pos);
}
addRow(current);
merged_data.addRow(current);
if (!current->isLast())
{
@ -193,35 +325,10 @@ IMergingAlgorithm::Status AggregatingSortedAlgorithm::merge()
/// Write the simple aggregation result for the previous group.
if (merged_data.isGroupStarted())
{
insertSimpleAggregationResult();
merged_data.insertRow();
}
merged_data.finishGroup();
last_chunk_sort_columns.clear();
return Status(merged_data.pull(columns_definition), true);
return Status(merged_data.pull(), true);
}
/// Adds the row under `cursor` to the aggregation states of the current group.
void AggregatingSortedAlgorithm::addRow(SortCursor & cursor)
{
    auto & source_columns = cursor->all_columns;
    const auto row = cursor->pos;

    /// AggregateFunction columns: merge the state from the source row into the result column.
    for (auto & aggregate : columns_definition.columns_to_aggregate)
        aggregate.column->insertMergeFrom(*source_columns[aggregate.column_number], row);

    /// Simple aggregate functions: feed the source value into the state of the group.
    for (auto & simple : columns_definition.columns_to_simple_aggregate)
    {
        auto & source = source_columns[simple.column_number];
        simple.add_function(simple.function.get(), simple.state.data(), &source, row, arena.get());
    }
}
/// Stores the result of every simple aggregate function into its column and destroys the state.
void AggregatingSortedAlgorithm::insertSimpleAggregationResult()
{
    for (auto & simple : columns_definition.columns_to_simple_aggregate)
    {
        auto * place = simple.state.data();
        simple.function->insertResultInto(place, *simple.column);
        simple.destroyState();
    }
}
}

View File

@ -6,6 +6,7 @@
#include <Common/AlignedBuffer.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Common/Arena.h>
namespace DB
{
@ -22,18 +23,14 @@ public:
Status merge() override;
struct SimpleAggregateDescription;
struct AggregateDescription;
/// This structure splits columns into three kinds:
/// * columns which are not aggregate functions and do not need to be aggregated
/// * usual aggregate functions, which store states in ColumnAggregateFunction
/// * simple aggregate functions, which store states in ordinary columns
struct ColumnsDefinition
{
struct AggregateDescription
{
ColumnAggregateFunction * column = nullptr;
const size_t column_number = 0;
AggregateDescription() = default;
explicit AggregateDescription(size_t col_number) : column_number(col_number) {}
};
/// Columns with which numbers should not be aggregated.
ColumnNumbers column_numbers_not_to_aggregate;
std::vector<AggregateDescription> columns_to_aggregate;
@ -47,135 +44,37 @@ private:
/// Specialization for AggregatingSortedAlgorithm.
struct AggregatingMergedData : public MergedData
{
private:
using MergedData::pull;
using MergedData::insertRow;
public:
AggregatingMergedData(MutableColumns columns_, UInt64 max_block_size_, ColumnsDefinition & def)
: MergedData(std::move(columns_), false, max_block_size_)
{
initAggregateDescription(def);
}
AggregatingMergedData(MutableColumns columns_, UInt64 max_block_size_, ColumnsDefinition & def_);
void initializeRow(const ColumnRawPtrs & raw_columns, size_t row, const ColumnNumbers & column_numbers)
{
for (auto column_number : column_numbers)
columns[column_number]->insertFrom(*raw_columns[column_number], row);
is_group_started = true;
}
void startGroup(const ColumnRawPtrs & raw_columns, size_t row);
void finishGroup();
bool isGroupStarted() const { return is_group_started; }
void addRow(SortCursor & cursor);
void insertRow()
{
is_group_started = false;
++total_merged_rows;
++merged_rows;
/// TODO: sum_blocks_granularity += block_size;
}
Chunk pull(ColumnsDefinition & def)
{
auto chunk = pull();
size_t num_rows = chunk.getNumRows();
auto columns_ = chunk.detachColumns();
for (auto & desc : def.columns_to_simple_aggregate)
{
if (desc.nested_type)
{
auto & from_type = desc.nested_type;
auto & to_type = desc.real_type;
columns_[desc.column_number] = recursiveTypeConversion(columns_[desc.column_number], from_type, to_type);
}
}
chunk.setColumns(std::move(columns_), num_rows);
initAggregateDescription(def);
return chunk;
}
Chunk pull();
private:
bool is_group_started = false;
/// Initialize aggregate descriptions with columns.
void initAggregateDescription(ColumnsDefinition & def)
{
for (auto & desc : def.columns_to_simple_aggregate)
desc.column = columns[desc.column_number].get();
for (auto & desc : def.columns_to_aggregate)
desc.column = typeid_cast<ColumnAggregateFunction *>(columns[desc.column_number].get());
}
using MergedData::pull;
};
ColumnsDefinition columns_definition;
AggregatingMergedData merged_data;
ColumnsDefinition & def;
/// Memory pool for SimpleAggregateFunction
/// (only when allocates_memory_in_arena == true).
std::unique_ptr<Arena> arena;
void prepareChunk(Chunk & chunk) const;
void addRow(SortCursor & cursor);
void insertSimpleAggregationResult();
bool is_group_started = false;
public:
/// Stores information for aggregation of SimpleAggregateFunction columns
struct SimpleAggregateDescription
{
/// An aggregate function 'anyLast', 'sum'...
AggregateFunctionPtr function;
IAggregateFunction::AddFunc add_function = nullptr;
size_t column_number = 0;
IColumn * column = nullptr;
/// For LowCardinality, convert is converted to nested type. nested_type is nullptr if no conversion needed.
const DataTypePtr nested_type; /// Nested type for LowCardinality, if it is.
const DataTypePtr real_type; /// Type in header.
AlignedBuffer state;
bool created = false;
SimpleAggregateDescription(
AggregateFunctionPtr function_, const size_t column_number_,
DataTypePtr nested_type_, DataTypePtr real_type_)
: function(std::move(function_)), column_number(column_number_)
, nested_type(std::move(nested_type_)), real_type(std::move(real_type_))
{
add_function = function->getAddressOfAddFunction();
state.reset(function->sizeOfData(), function->alignOfData());
}
void createState()
{
if (created)
return;
function->create(state.data());
created = true;
}
void destroyState()
{
if (!created)
return;
function->destroy(state.data());
created = false;
}
/// Explicitly destroy aggregation state if the stream is terminated
~SimpleAggregateDescription()
{
destroyState();
}
SimpleAggregateDescription() = default;
SimpleAggregateDescription(SimpleAggregateDescription &&) = default;
SimpleAggregateDescription(const SimpleAggregateDescription &) = delete;
/// Initialize aggregate descriptions with columns.
void initAggregateDescription();
};
/// Order between members is important because merged_data has reference to columns_definition.
ColumnsDefinition columns_definition;
AggregatingMergedData merged_data;
};
}