rewrite with better code

This commit is contained in:
Amos Bird 2020-12-09 12:44:25 +08:00 committed by Nikita Mikhailov
parent cc75c33064
commit c19ce7dad3
3 changed files with 704 additions and 1004 deletions

View File

@ -1,11 +1,10 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include "registerAggregateFunctions.h"
// TODO include this last because of a broken roaring header. See the comment
// inside.
// TODO include this last because of a broken roaring header. See the comment inside.
#include <AggregateFunctions/AggregateFunctionGroupBitmap.h>
namespace DB
@ -17,46 +16,54 @@ namespace ErrorCodes
namespace
{
template <template <typename> class Data>
AggregateFunctionPtr createAggregateFunctionBitmap(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
template <template <typename> class Data>
AggregateFunctionPtr createAggregateFunctionBitmap(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
if (!argument_types[0]->canBeUsedInBitOperations())
throw Exception(
"The type " + argument_types[0]->getName() + " of argument for aggregate function " + name
+ " is illegal, because it cannot be used in Bitmap operations",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!argument_types[0]->canBeUsedInBitOperations())
throw Exception("The type " + argument_types[0]->getName() + " of argument for aggregate function " + name
+ " is illegal, because it cannot be used in Bitmap operations",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunctionBitmap, Data>(*argument_types[0], argument_types[0]));
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunctionBitmap, Data>(*argument_types[0], argument_types[0]));
if (!res)
throw Exception(
"Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
return res;
}
// Additional aggregate functions to manipulate bitmaps.
template <template <typename, typename> class AggregateFunctionTemplate>
AggregateFunctionPtr
createAggregateFunctionBitmapL2(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
DataTypePtr argument_type_ptr = argument_types[0];
WhichDataType which(*argument_type_ptr);
if (which.idx != TypeIndex::AggregateFunction)
throw Exception(
"Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
template <template <typename, typename> class AggregateFunctionTemplate>
AggregateFunctionPtr createAggregateFunctionBitmapL2(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
DataTypePtr argument_type_ptr = argument_types[0];
WhichDataType which(*argument_type_ptr);
if (which.idx != TypeIndex::AggregateFunction)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const DataTypeAggregateFunction& datatype_aggfunc = dynamic_cast<const DataTypeAggregateFunction&>(*argument_type_ptr);
AggregateFunctionPtr aggfunc = datatype_aggfunc.getFunction();
argument_type_ptr = aggfunc->getArgumentTypes()[0];
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunctionTemplate, AggregateFunctionGroupBitmapData>(*argument_type_ptr, argument_type_ptr));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
const DataTypeAggregateFunction & datatype_aggfunc = dynamic_cast<const DataTypeAggregateFunction &>(*argument_type_ptr);
AggregateFunctionPtr aggfunc = datatype_aggfunc.getFunction();
argument_type_ptr = aggfunc->getArgumentTypes()[0];
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunctionTemplate, AggregateFunctionGroupBitmapData>(
*argument_type_ptr, argument_type_ptr));
if (!res)
throw Exception(
"Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
}
void registerAggregateFunctionsBitmap(AggregateFunctionFactory & factory)

View File

@ -1,31 +1,26 @@
#pragma once
#include <Columns/ColumnVector.h>
#include <Common/assert_cast.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesNumber.h>
#include <Common/assert_cast.h>
// TODO include this last because of a broken roaring header. See the comment inside.
#include <AggregateFunctions/AggregateFunctionGroupBitmapData.h>
namespace DB
{
/// Counts bitmap operation on numbers.
template <typename T, typename Data>
class AggregateFunctionBitmap final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>
{
public:
AggregateFunctionBitmap(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>({type}, {}) {}
AggregateFunctionBitmap(const DataTypePtr & type) : IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>({type}, {}) { }
String getName() const override { return Data::name(); }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNumber<T>>();
}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<T>>(); }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
@ -37,15 +32,9 @@ public:
this->data(place).rbs.merge(this->data(rhs).rbs);
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(place).rbs.write(buf);
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(place).rbs.write(buf); }
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
this->data(place).rbs.read(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override { this->data(place).rbs.read(buf); }
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override
{
@ -59,23 +48,22 @@ class AggregateFunctionBitmapL2 final : public IAggregateFunctionDataHelper<Data
{
public:
AggregateFunctionBitmapL2(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>({type}, {}){}
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>({type}, {})
{
}
String getName() const override { return Data::name(); }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNumber<T>>();
}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<T>>(); }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
Data & data_lhs = this->data(place);
const Data & data_rhs = this->data(assert_cast<const ColumnAggregateFunction &>(*columns[0]).getData()[row_num]);
if (!data_lhs.doneFirst)
if (!data_lhs.init)
{
data_lhs.doneFirst = true;
data_lhs.rbs.rb_or(data_rhs.rbs);
data_lhs.init = true;
data_lhs.rbs.merge(data_rhs.rbs);
}
else
{
@ -88,13 +76,13 @@ public:
Data & data_lhs = this->data(place);
const Data & data_rhs = this->data(rhs);
if (!data_rhs.doneFirst)
if (!data_rhs.init)
return;
if (!data_lhs.doneFirst)
if (!data_lhs.init)
{
data_lhs.doneFirst = true;
data_lhs.rbs.rb_or(data_rhs.rbs);
data_lhs.init = true;
data_lhs.rbs.merge(data_rhs.rbs);
}
else
{
@ -102,15 +90,9 @@ public:
}
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(place).rbs.write(buf);
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(place).rbs.write(buf); }
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
this->data(place).rbs.read(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override { this->data(place).rbs.read(buf); }
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override
{
@ -122,39 +104,30 @@ template <typename Data>
class BitmapAndPolicy
{
public:
static void apply(Data& lhs, const Data& rhs)
{
lhs.rbs.rb_and(rhs.rbs);
}
static void apply(Data & lhs, const Data & rhs) { lhs.rbs.rb_and(rhs.rbs); }
};
template <typename Data>
class BitmapOrPolicy
{
public:
static void apply(Data& lhs, const Data& rhs)
{
lhs.rbs.rb_or(rhs.rbs);
}
static void apply(Data & lhs, const Data & rhs) { lhs.rbs.rb_or(rhs.rbs); }
};
template <typename Data>
class BitmapXorPolicy
{
public:
static void apply(Data& lhs, const Data& rhs)
{
lhs.rbs.rb_xor(rhs.rbs);
}
static void apply(Data & lhs, const Data & rhs) { lhs.rbs.rb_xor(rhs.rbs); }
};
template <typename T, typename Data>
using AggregateFunctionBitmapL2And = AggregateFunctionBitmapL2<T, Data, BitmapAndPolicy<Data> >;
using AggregateFunctionBitmapL2And = AggregateFunctionBitmapL2<T, Data, BitmapAndPolicy<Data>>;
template <typename T, typename Data>
using AggregateFunctionBitmapL2Or = AggregateFunctionBitmapL2<T, Data, BitmapOrPolicy<Data> >;
using AggregateFunctionBitmapL2Or = AggregateFunctionBitmapL2<T, Data, BitmapOrPolicy<Data>>;
template <typename T, typename Data>
using AggregateFunctionBitmapL2Xor = AggregateFunctionBitmapL2<T, Data, BitmapXorPolicy<Data> >;
using AggregateFunctionBitmapL2Xor = AggregateFunctionBitmapL2<T, Data, BitmapXorPolicy<Data>>;
}

File diff suppressed because it is too large Load Diff