Support partial result for aggregating transform during query execution

alexX512 2023-08-07 15:58:14 +00:00
parent de62239a5d
commit 080b4badbd
23 changed files with 332 additions and 38 deletions
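The feature is driven by the settings exercised in the new tests below: partial_result_update_duration_ms controls how often a partial-result snapshot is sent, and max_rows_in_partial_result caps its size. A minimal sketch of a client session, assuming the TCPClient helper from the test helpers directory:

from tcp_client import TCPClient

with TCPClient() as client:
    client.sendQuery(
        "SELECT sum(number) FROM numbers_mt(1e7+1) "
        "SETTINGS max_threads = 1, partial_result_update_duration_ms = 1"
    )
    client.sendEmptyBlock()  # no external tables
    client.readHeader()
    # Partial-result blocks arrive periodically before the final block.
    _, partial = client.readDataWithoutProgress()[0]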

View File

@ -2273,6 +2273,29 @@ Block Aggregator::prepareBlockAndFillWithoutKey(AggregatedDataVariants & data_va
return block;
}
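/// Snapshot variant of prepareBlockAndFillWithoutKey: materializes the current results with
/// insertResultInto, intended to leave the aggregation state intact so that it can be called
/// repeatedly while the query is still running.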
Block Aggregator::prepareBlockAndFillWithoutKeySnapshot(AggregatedDataVariants & data_variants) const
{
size_t rows = 1;
bool final = true;
auto && out_cols
= prepareOutputBlockColumns(params, aggregate_functions, getHeader(final), data_variants.aggregates_pools, final, rows);
auto && [key_columns, raw_key_columns, aggregate_columns, final_aggregate_columns, aggregate_columns_data] = out_cols;
AggregatedDataWithoutKey & data = data_variants.without_key;
/// Always single-thread. It's safe to pass current arena from 'aggregates_pool'.
for (size_t insert_i = 0; insert_i < params.aggregates_size; ++insert_i)
aggregate_functions[insert_i]->insertResultInto(
data + offsets_of_aggregate_states[insert_i],
*final_aggregate_columns[insert_i],
data_variants.aggregates_pool);
Block block = finalizeBlock(params, getHeader(final), std::move(out_cols), final, rows);
return block;
}
template <bool return_single_block>
Aggregator::ConvertToBlockRes<return_single_block>
Aggregator::prepareBlockAndFillSingleLevel(AggregatedDataVariants & data_variants, bool final) const

View File

@ -1210,6 +1210,7 @@ private:
friend class ConvertingAggregatedToChunksSource;
friend class ConvertingAggregatedToChunksWithMergingSource;
friend class AggregatingInOrderTransform;
friend class AggregatingPartialResultTransform;
/// Data structure of source blocks.
Block header;
@ -1391,6 +1392,7 @@ private:
std::atomic<bool> * is_cancelled = nullptr) const;
Block prepareBlockAndFillWithoutKey(AggregatedDataVariants & data_variants, bool final, bool is_overflows) const;
Block prepareBlockAndFillWithoutKeySnapshot(AggregatedDataVariants & data_variants) const;
BlocksList prepareBlocksAndFillTwoLevel(AggregatedDataVariants & data_variants, bool final, ThreadPool * thread_pool) const;
template <bool return_single_block>

View File

@ -40,5 +40,10 @@ std::string IProcessor::statusToName(Status status)
UNREACHABLE();
}
ProcessorPtr IProcessor::getPartialResultProcessorPtr(const ProcessorPtr & current_processor, UInt64 partial_result_limit, UInt64 partial_result_duration_ms)
{
return current_processor->getPartialResultProcessor(current_processor, partial_result_limit, partial_result_duration_ms);
}
}
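Since the virtual getPartialResultProcessor hook moves into the protected section (see the next file), pipeline code now calls this static entry point, which forwards the processor's own shared_ptr so the created partial-result transform can hold a reference to its parent. A rough Python analogue of the pattern, with illustrative names only:

class Processor:
    @staticmethod
    def get_partial_result_processor_ptr(current):
        # public entry point: forwards the owning reference to the hook
        return current._get_partial_result_processor(current)

    def _get_partial_result_processor(self, current):
        # "protected" hook, overridden by processors that support snapshots
        raise NotImplementedError(f"not implemented for {type(self).__name__}")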

View File

@ -164,6 +164,8 @@ public:
static std::string statusToName(Status status);
static ProcessorPtr getPartialResultProcessorPtr(const ProcessorPtr & current_processor, UInt64 partial_result_limit, UInt64 partial_result_duration_ms);
/** Method 'prepare' is responsible for all cheap ("instantaneous": O(1) of data volume, no wait) calculations.
*
* It may access input and output ports,
@ -238,11 +240,6 @@ public:
virtual bool isPartialResultProcessor() const { return false; }
virtual bool supportPartialResultProcessor() const { return false; }
virtual ProcessorPtr getPartialResultProcessor(const ProcessorPtr & /*current_processor*/, UInt64 /*partial_result_limit*/, UInt64 /*partial_result_duration_ms*/)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method 'getPartialResultProcessor' is not implemented for {} processor", getName());
}
/// In case if query was cancelled executor will wait till all processors finish their jobs.
/// Generally, there is no reason to check this flag. However, it may be reasonable for long operations (e.g. i/o).
bool isCancelled() const { return is_cancelled.load(std::memory_order_acquire); }
@ -377,6 +374,11 @@ public:
protected:
virtual void onCancel() {}
virtual ProcessorPtr getPartialResultProcessor(const ProcessorPtr & /*current_processor*/, UInt64 /*partial_result_limit*/, UInt64 /*partial_result_duration_ms*/)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method 'getPartialResultProcessor' is not implemented for {} processor", getName());
}
private:
/// For:
/// - elapsed_us

View File

@ -55,7 +55,6 @@ private:
ColumnRawPtrs extractSortColumns(const Columns & columns) const;
bool sortColumnsEqualAt(const ColumnRawPtrs & current_chunk_sort_columns, UInt64 current_chunk_row_num) const;
bool supportPartialResultProcessor() const override { return true; }
ProcessorPtr getPartialResultProcessor(const ProcessorPtr & current_processor, UInt64 partial_result_limit, UInt64 partial_result_duration_ms) override;
public:
@ -76,6 +75,8 @@ public:
void setRowsBeforeLimitCounter(RowsBeforeLimitCounterPtr counter) override { rows_before_limit_at_least.swap(counter); }
void setInputPortHasCounter(size_t pos) { ports_data[pos].input_port_has_counter = true; }
bool supportPartialResultProcessor() const override { return true; }
};
}

View File

@ -0,0 +1,42 @@
#include <Processors/Transforms/AggregatingPartialResultTransform.h>
namespace DB
{
AggregatingPartialResultTransform::AggregatingPartialResultTransform(
const Block & input_header, const Block & output_header, AggregatingTransformPtr aggregating_transform_,
UInt64 partial_result_limit_, UInt64 partial_result_duration_ms_)
: PartialResultTransform(input_header, output_header, partial_result_limit_, partial_result_duration_ms_)
, aggregating_transform(std::move(aggregating_transform_))
{}
PartialResultTransform::SnapshotResult AggregatingPartialResultTransform::getRealProcessorSnapshot()
{
std::lock_guard lock(aggregating_transform->snapshot_mutex);
auto & params = aggregating_transform->params;
/// Cases that are currently not supported
/// TODO: check that inserting results in prepareBlockAndFillWithoutKey returns values without changing the aggregator state
if (params->params.keys_size != 0 /// has at least one key for aggregation
|| params->aggregator.hasTemporaryData() /// uses external storage for aggregation
|| aggregating_transform->many_data->variants.size() > 1) /// uses more than one stream for aggregation
return {{}, SnapshotStatus::Stopped};
if (aggregating_transform->is_generate_initialized)
return {{}, SnapshotStatus::Stopped};
if (aggregating_transform->variants.empty())
return {{}, SnapshotStatus::NotReady};
auto & aggregator = params->aggregator;
auto prepared_data = aggregator.prepareVariantsToMerge(aggregating_transform->many_data->variants);
AggregatedDataVariantsPtr & first = prepared_data.at(0);
aggregator.mergeWithoutKeyDataImpl(prepared_data);
auto block = aggregator.prepareBlockAndFillWithoutKeySnapshot(*first);
return {convertToChunk(block), SnapshotStatus::Ready};
}
}
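getRealProcessorSnapshot above has three possible outcomes; a rough Python restatement of its decision logic (names are illustrative, not from the C++ code):

def snapshot_status(has_keys, has_temporary_data, num_streams,
                    generate_initialized, has_data):
    if has_keys or has_temporary_data or num_streams > 1:
        return "Stopped"    # unsupported cases: stop snapshotting entirely
    if generate_initialized:
        return "Stopped"    # the final result is already being generated
    if not has_data:
        return "NotReady"   # nothing consumed yet; try again later
    return "Ready"          # merge the variants and emit a snapshot block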

View File

@ -0,0 +1,26 @@
#pragma once
#include <Processors/Transforms/AggregatingTransform.h>
#include <Processors/Transforms/PartialResultTransform.h>
namespace DB
{
class AggregatingPartialResultTransform : public PartialResultTransform
{
public:
using AggregatingTransformPtr = std::shared_ptr<AggregatingTransform>;
AggregatingPartialResultTransform(
const Block & input_header, const Block & output_header, AggregatingTransformPtr aggregating_transform_,
UInt64 partial_result_limit_, UInt64 partial_result_duration_ms_);
String getName() const override { return "AggregatingPartialResultTransform"; }
SnapshotResult getRealProcessorSnapshot() override;
private:
AggregatingTransformPtr aggregating_transform;
};
}

View File

@ -1,3 +1,4 @@
#include <Processors/Transforms/AggregatingPartialResultTransform.h>
#include <Processors/Transforms/AggregatingTransform.h>
#include <Formats/NativeReader.h>
@ -657,6 +658,8 @@ void AggregatingTransform::consume(Chunk chunk)
src_rows += num_rows;
src_bytes += chunk.bytes();
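/// Serialize access to the aggregation state with concurrent partial-result snapshots
/// (AggregatingPartialResultTransform::getRealProcessorSnapshot locks the same mutex).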
std::lock_guard lock(snapshot_mutex);
if (params->params.only_merge)
{
auto block = getInputs().front().getHeader().cloneWithColumns(chunk.detachColumns());
@ -676,6 +679,7 @@ void AggregatingTransform::initGenerate()
if (is_generate_initialized)
return;
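/// Taking the snapshot lock before setting is_generate_initialized ensures that
/// no snapshot is taken after generation of the final result has started.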
std::lock_guard lock(snapshot_mutex);
is_generate_initialized = true;
/// If there was no data and we aggregate without keys, we must return a single row with the result of the empty aggregation.
@ -806,4 +810,12 @@ void AggregatingTransform::initGenerate()
}
}
ProcessorPtr AggregatingTransform::getPartialResultProcessor(const ProcessorPtr & current_processor, UInt64 partial_result_limit, UInt64 partial_result_duration_ms)
{
const auto & input_header = inputs.front().getHeader();
const auto & output_header = outputs.front().getHeader();
auto aggregating_processor = std::dynamic_pointer_cast<AggregatingTransform>(current_processor);
return std::make_shared<AggregatingPartialResultTransform>(input_header, output_header, std::move(aggregating_processor), partial_result_limit, partial_result_duration_ms);
}
}

View File

@ -170,9 +170,13 @@ public:
void work() override;
Processors expandPipeline() override;
bool supportPartialResultProcessor() const override { return true; }
protected:
void consume(Chunk chunk);
ProcessorPtr getPartialResultProcessor(const ProcessorPtr & current_processor, UInt64 partial_result_limit, UInt64 partial_result_duration_ms) override;
private:
/// To read the data that was flushed into the temporary data file.
Processors processors;
@ -212,6 +216,9 @@ private:
bool is_consume_started = false;
friend class AggregatingPartialResultTransform;
std::mutex snapshot_mutex;
void initGenerate();
};

View File

@ -26,13 +26,15 @@ public:
static Block transformHeader(Block header, const ActionsDAG & expression);
bool supportPartialResultProcessor() const override { return true; }
protected:
void transform(Chunk & chunk) override;
bool supportPartialResultProcessor() const override { return true; }
ProcessorPtr getPartialResultProcessor(const ProcessorPtr & current_processor, UInt64 partial_result_limit, UInt64 partial_result_duration_ms) override;
private:
ExpressionActionsPtr expression;
};

View File

@ -33,10 +33,11 @@ public:
void setQuota(const std::shared_ptr<const EnabledQuota> & quota_) { quota = quota_; }
bool supportPartialResultProcessor() const override { return true; }
protected:
void transform(Chunk & chunk) override;
bool supportPartialResultProcessor() const override { return true; }
ProcessorPtr getPartialResultProcessor(const ProcessorPtr & current_processor, UInt64 partial_result_limit, UInt64 partial_result_duration_ms) override;
private:

View File

@ -1,4 +1,3 @@
#include <Processors/Transforms/MergeSortingTransform.h>
#include <Processors/Transforms/MergeSortingPartialResultTransform.h>
namespace DB

View File

@ -1,5 +1,6 @@
#pragma once
#include <Processors/Transforms/MergeSortingTransform.h>
#include <Processors/Transforms/PartialResultTransform.h>
namespace DB

View File

@ -285,17 +285,6 @@ void MergeSortingTransform::remerge()
ProcessorPtr MergeSortingTransform::getPartialResultProcessor(const ProcessorPtr & current_processor, UInt64 partial_result_limit, UInt64 partial_result_duration_ms)
{
if (getName() != current_processor->getName() || current_processor.get() != this)
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"To create partial result processor variable current_processor should use " \
"the same class and pointer as in the original processor with class {} and pointer {}. " \
"But current_processor has another class {} or pointer {} then original.",
getName(),
static_cast<void*>(this),
current_processor->getName(),
static_cast<void*>(current_processor.get()));
const auto & header = inputs.front().getHeader();
auto merge_sorting_processor = std::dynamic_pointer_cast<MergeSortingTransform>(current_processor);
return std::make_shared<MergeSortingPartialResultTransform>(header, std::move(merge_sorting_processor), partial_result_limit, partial_result_duration_ms);

View File

@ -33,6 +33,8 @@ public:
String getName() const override { return "MergeSortingTransform"; }
bool supportPartialResultProcessor() const override { return true; }
protected:
void consume(Chunk chunk) override;
void serialize() override;
@ -40,7 +42,6 @@ protected:
Processors expandPipeline() override;
bool supportPartialResultProcessor() const override { return true; }
ProcessorPtr getPartialResultProcessor(const ProcessorPtr & current_processor, UInt64 partial_result_limit, UInt64 partial_result_duration_ms) override;
private:
@ -61,10 +62,10 @@ private:
/// Merge all accumulated blocks to keep no more than limit rows.
void remerge();
ProcessorPtr external_merging_sorted;
friend class MergeSortingPartialResultTransform;
std::mutex snapshot_mutex;
ProcessorPtr external_merging_sorted;
};
}

View File

@ -3,8 +3,12 @@
namespace DB
{
PartialResultTransform::PartialResultTransform(const Block & header, UInt64 partial_result_limit_, UInt64 partial_result_duration_ms_)
: IProcessor({header}, {header})
: PartialResultTransform(header, header, partial_result_limit_, partial_result_duration_ms_) {}
PartialResultTransform::PartialResultTransform(const Block & input_header, const Block & output_header, UInt64 partial_result_limit_, UInt64 partial_result_duration_ms_)
: IProcessor({input_header}, {output_header})
, input(inputs.front())
, output(outputs.front())
, partial_result_limit(partial_result_limit_)

View File

@ -9,6 +9,7 @@ class PartialResultTransform : public IProcessor
{
public:
PartialResultTransform(const Block & header, UInt64 partial_result_limit_, UInt64 partial_result_duration_ms_);
PartialResultTransform(const Block & input_header, const Block & output_header, UInt64 partial_result_limit_, UInt64 partial_result_duration_ms_);
String getName() const override { return "PartialResultTransform"; }

View File

@ -636,7 +636,7 @@ void Pipe::addPartialResultSimpleTransform(const ProcessorPtr & transform, size_
return;
}
auto partial_result_transform = transform->getPartialResultProcessor(transform, partial_result_limit, partial_result_duration_ms);
auto partial_result_transform = IProcessor::getPartialResultProcessorPtr(transform, partial_result_limit, partial_result_duration_ms);
connectPartialResultPort(partial_result_port, partial_result_transform->getInputs().front());
@ -661,7 +661,7 @@ void Pipe::addPartialResultTransform(const ProcessorPtr & transform)
return;
}
auto partial_result_transform = transform->getPartialResultProcessor(transform, partial_result_limit, partial_result_duration_ms);
auto partial_result_transform = IProcessor::getPartialResultProcessorPtr(transform, partial_result_limit, partial_result_duration_ms);
auto & inputs = partial_result_transform->getInputs();
if (inputs.size() != partial_result_ports.size())

View File

@ -8,13 +8,13 @@ import sys
CURDIR = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.join(CURDIR, "helpers"))
from tcp_client import TCPClient, assertPacket
from tcp_client import TCPClient
def main():
with TCPClient() as client:
client.sendQuery(
f"SELECT number FROM numbers_mt(1e7+1) ORDER BY -number LIMIT 15 SETTINGS max_threads = 1, partial_result_update_duration_ms=1, max_rows_in_partial_result=10"
"SELECT number FROM numbers_mt(1e7+1) ORDER BY -number LIMIT 15 SETTINGS max_threads = 1, partial_result_update_duration_ms = 1, max_rows_in_partial_result = 10"
)
# external tables
@ -23,13 +23,13 @@ def main():
# Partial result
_, partial_result = client.readDataWithoutProgress()[0]
assert_message = (
"There should be at least one block of data with partial result"
)
assert len(partial_result) > 0, assert_message
assert len(partial_result) > 0, "Expected at least one block with a non-empty partial result before getting the full result"
while True:
assert all(
a >= b for a, b in zip(partial_result, partial_result[1:])
), "Partial result always should be sorted for this test"
_, new_partial_result = client.readDataWithoutProgress(
need_print_info=False
)[0]
@ -37,15 +37,22 @@ def main():
break
data_size = len(partial_result)
assert_message = f"New block contains more info about the full data so sorted results should not be less then in the previous iteration. New result {new_partial_result}. Previous result {partial_result}"
assert all(
partial_result[i] <= new_partial_result[i] for i in range(data_size)
), assert_message
), f"New partial result values should always be greater then old one because a new block contains more information about the full data. New result {new_partial_result}. Previous result {partial_result}"
partial_result = new_partial_result
# Full result
_, full_result = client.readDataWithoutProgress()[0]
data_size = len(partial_result)
assert all(
partial_result[i] <= full_result[i] for i in range(data_size)
), f"Full result values should always be greater then partial result values. Full result {full_result}. Partial result {partial_result}"
for result in full_result:
print(result)
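The pairwise check in the loop above pairs each element with its successor via zip(partial_result, partial_result[1:]), so the assertion verifies that every partial-result block is sorted in descending order. A tiny self-contained illustration:

seq = [5, 4, 4, 2]
assert all(a >= b for a, b in zip(seq, seq[1:]))  # non-increasing order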

View File

@ -0,0 +1,88 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
CURDIR = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.join(CURDIR, "helpers"))
from tcp_client import TCPClient, assertPacket
def run_query_without_errors(query, support_partial_result, invariants=None):
if invariants is None:
invariants = {}
with TCPClient() as client:
client.sendQuery(query)
# external tables
client.sendEmptyBlock()
client.readHeader()
# Partial result
partial_results = client.readDataWithoutProgress()
if support_partial_result:
assert len(partial_results[0][1]) > 0, "Expected at least one block with a non-empty partial result before getting the full result"
while True:
new_partial_results = client.readDataWithoutProgress(need_print_info=False)
if len(new_partial_results[0][1]) == 0:
break
for new_result, old_result in zip(new_partial_results, partial_results):
assert new_result[0] == old_result[0], "Keys in blocks should be in the same order"
key = new_result[0]
if key in invariants:
old_value = old_result[1]
new_value = new_result[1]
assert invariants[key](old_value, new_value), f"Problem with the invariant between old and new versions of a partial result for key: {key}. Old value {old_value}, new value {new_value}"
else:
assert len(partial_results[0][1]) == 0, "Expected no non-empty partial result blocks before getting the full result"
# Full result
full_results = client.readDataWithoutProgress()
if support_partial_result:
for full_result, partial_result in zip(full_results, partial_results):
assert full_result[0] == partial_result[0], "Keys in blocks should be in the same order"
key = full_result[0]
if key in invariants:
full_value = full_result[1]
partial_value = partial_result[1]
assert invariants[key](partial_value, full_value), f"Problem with the invariant between full and partial result for key: {key}. Partial value {partial_value}. Full value {full_value}"
for key, value in full_results:
if isinstance(value[0], int):
print(key, value)
def supported_scenarios():
query = "select median(number), stddevSamp(number), stddevPop(number), max(number), min(number), any(number), count(number), avg(number), sum(number) from numbers_mt(1e7+1) settings max_threads = 1, partial_result_update_duration_ms = 1"
invariants = {
"median(number)": lambda old_value, new_value: old_value <= new_value,
"max(number)": lambda old_value, new_value: old_value <= new_value,
"min(number)": lambda old_value, new_value: old_value >= new_value,
"count(number)": lambda old_value, new_value: old_value <= new_value,
"avg(number)": lambda old_value, new_value: old_value <= new_value,
"sum(number)": lambda old_value, new_value: old_value <= new_value,
}
run_query_without_errors(query, support_partial_result=True, invariants=invariants)
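# These invariants hold because each new snapshot is computed over a superset
# of the rows seen so far: a running max can only grow, a running min can only
# shrink, and count/sum of non-negative numbers can only grow (median and avg
# also grow here because the source numbers arrive in increasing order).
# A self-contained illustration of the idea (illustrative helper, not part of the test flow):
def _illustrate_monotonic_invariants():
    prev_max, prev_min, prev_sum = None, None, None
    for prefix_len in range(10, 60, 10):
        data = list(range(prefix_len))  # a growing prefix, as in the test query
        assert prev_max is None or prev_max <= max(data)
        assert prev_min is None or prev_min >= min(data)
        assert prev_sum is None or prev_sum <= sum(data)
        prev_max, prev_min, prev_sum = max(data), min(data), sum(data)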
def unsupported_scenarios():
# Currently the aggregator for partial results supports only single-threaded aggregation without keys
# Update this test when multithreaded aggregation or aggregation with GROUP BY is supported for partial result updates
multithread_query = "select sum(number) from numbers_mt(1e7+1) settings max_threads = 2, partial_result_update_duration_ms = 1"
run_query_without_errors(multithread_query, support_partial_result=False)
group_with_key_query = "select mod2, sum(number) from numbers_mt(1e7+1) group by number % 2 as mod2 settings max_threads = 1, partial_result_update_duration_ms = 1"
run_query_without_errors(group_with_key_query, support_partial_result=False)
def main():
supported_scenarios()
unsupported_scenarios()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,53 @@
Rows 0 Columns 9
Column median(number) type Float64
Column stddevSamp(number) type Float64
Column stddevPop(number) type Float64
Column max(number) type UInt64
Column min(number) type UInt64
Column any(number) type UInt64
Column count(number) type UInt64
Column avg(number) type Float64
Column sum(number) type UInt64
Rows 1 Columns 9
Column median(number) type Float64
Column stddevSamp(number) type Float64
Column stddevPop(number) type Float64
Column max(number) type UInt64
Column min(number) type UInt64
Column any(number) type UInt64
Column count(number) type UInt64
Column avg(number) type Float64
Column sum(number) type UInt64
Rows 1 Columns 9
Column median(number) type Float64
Column stddevSamp(number) type Float64
Column stddevPop(number) type Float64
Column max(number) type UInt64
Column min(number) type UInt64
Column any(number) type UInt64
Column count(number) type UInt64
Column avg(number) type Float64
Column sum(number) type UInt64
max(number) [10000000]
min(number) [0]
any(number) [0]
count(number) [10000001]
sum(number) [50000005000000]
Rows 0 Columns 1
Column sum(number) type UInt64
Rows 0 Columns 1
Column sum(number) type UInt64
Rows 1 Columns 1
Column sum(number) type UInt64
sum(number) [50000005000000]
Rows 0 Columns 2
Column mod2 type UInt8
Column sum(number) type UInt64
Rows 0 Columns 2
Column mod2 type UInt8
Column sum(number) type UInt64
Rows 2 Columns 2
Column mod2 type UInt8
Column sum(number) type UInt64
mod2 [0, 1]
sum(number) [25000005000000, 25000000000000]

View File

@ -0,0 +1,8 @@
#!/usr/bin/env bash
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CURDIR"/../shell_config.sh
# We should have correct env vars from shell_config.sh to run this test
python3 "$CURDIR"/02834_partial_aggregating_result_during_query_execution.python

View File

@ -1,6 +1,7 @@
import socket
import os
import uuid
import struct
CLICKHOUSE_HOST = os.environ.get("CLICKHOUSE_HOST", "127.0.0.1")
CLICKHOUSE_PORT = int(os.environ.get("CLICKHOUSE_PORT_TCP", "9000"))
@ -105,6 +106,15 @@ class TCPClient(object):
def readUInt64(self):
return self.readUInt(8)
def readFloat16(self):
return struct.unpack("e", self.readStrict(2))[0]
def readFloat32(self):
return struct.unpack("f", self.readStrict(4))[0]
def readFloat64(self):
return struct.unpack("d", self.readStrict(8))[0]
def readVarUInt(self):
x = 0
for i in range(9):
@ -250,12 +260,22 @@ class TCPClient(object):
print("Column {} type {}".format(col_name, type_name))
def readRow(self, row_type, rows):
if row_type == "UInt64":
row = [self.readUInt64() for _ in range(rows)]
supported_row_types = {
"UInt8": self.readUInt8,
"UInt16": self.readUInt16,
"UInt32": self.readUInt32,
"UInt64": self.readUInt64,
"Float16": self.readFloat16,
"Float32": self.readFloat32,
"Float64": self.readFloat64,
}
if row_type in supported_row_types:
read_type = supported_row_types[row_type]
row = [read_type() for _ in range(rows)]
return row
else:
raise RuntimeError(
"Currently python version of tcp client doesn't support the following type of row: {}".format(
"Current python version of tcp client doesn't support the following type of row: {}".format(
row_type
)
)
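
For reference, the struct format characters used by the new read helpers: "e" is a half-precision float (2 bytes), "f" single precision (4 bytes), and "d" double precision (8 bytes). struct.unpack always returns a tuple, which is why the helpers index [0]:

import struct
assert struct.calcsize("e") == 2 and struct.calcsize("f") == 4 and struct.calcsize("d") == 8
assert struct.unpack("d", struct.pack("d", 1.5))[0] == 1.5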