ClickHouse/dbms/include/DB/AggregateFunctions/AggregateFunctionArray.h

132 lines
3.8 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#pragma once
#include <DB/Columns/ColumnArray.h>
#include <DB/DataTypes/DataTypeArray.h>
#include <DB/DataTypes/DataTypesNumberFixed.h>
#include <DB/AggregateFunctions/IAggregateFunction.h>
namespace DB
{

/** Not an aggregate function itself, but an adapter over aggregate functions:
  * it turns any aggregate function agg(x) into an aggregate function of the form aggArray(x).
  * The adapted function applies the nested aggregate function to every element of the array argument(s).
  * With several arguments, every argument must be an array; within one row all arrays must
  * have the same length, and elements at the same position are passed to the nested function together.
  */
class AggregateFunctionArray final : public IAggregateFunction
{
private:
    /// Keeps the nested function alive for as long as this adapter exists.
    AggregateFunctionPtr nested_func_owner;
    /// Non-owning shortcut to the object held by nested_func_owner.
    IAggregateFunction * nested_func;
    /// Number of (array) arguments; assigned in setArguments().
    size_t num_arguments = 0;

public:
    AggregateFunctionArray(AggregateFunctionPtr nested_)
        : nested_func_owner(nested_), nested_func(nested_func_owner.get()) {}

    /// Name of the nested function with the "Array" suffix appended.
    String getName() const override
    {
        return nested_func->getName() + "Array";
    }

    /// The result type is exactly that of the nested function.
    DataTypePtr getReturnType() const override
    {
        return nested_func->getReturnType();
    }

    /** Validates that every argument is an array and configures the nested function
      * with the element types of those arrays.
      * Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH when there are no arguments and
      * ILLEGAL_TYPE_OF_ARGUMENT when some argument is not an array.
      */
    void setArguments(const DataTypes & arguments) override
    {
        num_arguments = arguments.size();

        if (0 == num_arguments)
            throw Exception("Array aggregate functions requires at least one argument", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        DataTypes nested_arguments;
        for (size_t i = 0; i < num_arguments; ++i)
        {
            if (const DataTypeArray * array = typeid_cast<const DataTypeArray *>(&*arguments[i]))
                nested_arguments.push_back(array->getNestedType());
            else
                throw Exception("Illegal type " + arguments[i]->getName() + " of argument #" + toString(i + 1) + " for aggregate function " + getName() + ". Must be array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
        }

        nested_func->setArguments(nested_arguments);
    }

    /// Parameters are forwarded unchanged to the nested function.
    void setParameters(const Array & params) override
    {
        nested_func->setParameters(params);
    }

    /// The aggregation state layout is the nested function's; all state management is delegated.
    void create(AggregateDataPtr place) const override
    {
        nested_func->create(place);
    }

    void destroy(AggregateDataPtr place) const noexcept override
    {
        nested_func->destroy(place);
    }

    bool hasTrivialDestructor() const override
    {
        return nested_func->hasTrivialDestructor();
    }

    size_t sizeOfData() const override
    {
        return nested_func->sizeOfData();
    }

    size_t alignOfData() const override
    {
        return nested_func->alignOfData();
    }

    /** Feeds every element of the row_num-th array(s) to the nested function.
      * columns[i] is expected to be a ColumnArray (the declared types were checked in setArguments).
      * Throws SIZES_OF_ARRAYS_DOESNT_MATCH if the arrays of this row have different bounds.
      */
    void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
    {
        /// Variable-length array on the stack (a GCC/Clang extension in C++);
        /// presumably used to avoid a heap allocation per row.
        const IColumn * nested[num_arguments];
        for (size_t i = 0; i < num_arguments; ++i)
            nested[i] = &static_cast<const ColumnArray &>(*columns[i]).getData();

        const IColumn::Offsets_t & offsets = static_cast<const ColumnArray &>(*columns[0]).getOffsets();

        /// offsets[k] marks the end of the k-th array in the flattened data; the first array starts at 0.
        size_t begin = row_num == 0 ? 0 : offsets[row_num - 1];
        size_t end = offsets[row_num];

        /// The same element indices [begin, end) are used for every nested column below,
        /// so all arrays of this row must have identical bounds. Otherwise a shorter array
        /// would be read out of range.
        for (size_t i = 1; i < num_arguments; ++i)
        {
            const IColumn::Offsets_t & ith_offsets = static_cast<const ColumnArray &>(*columns[i]).getOffsets();
            size_t ith_begin = row_num == 0 ? 0 : ith_offsets[row_num - 1];
            if (ith_begin != begin || ith_offsets[row_num] != end)
                throw Exception("Arrays passed to " + getName() + " aggregate function have different sizes", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
        }

        /// Forward the caller's arena (as merge/deserialize do): the nested function may allocate in it.
        for (size_t i = begin; i < end; ++i)
            nested_func->add(place, nested, i, arena);
    }

    void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
    {
        nested_func->merge(place, rhs, arena);
    }

    void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
    {
        nested_func->serialize(place, buf);
    }

    void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
    {
        nested_func->deserialize(place, buf, arena);
    }

    void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
    {
        nested_func->insertResultInto(place, to);
    }

    /// Free-function trampoline returned by getAddressOfAddFunction(); dispatches to add().
    static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena)
    {
        static_cast<const AggregateFunctionArray &>(*that).add(place, columns, row_num, arena);
    }

    IAggregateFunction::AddFunc getAddressOfAddFunction() const override final { return &addFree; }
};

}