This commit is contained in:
Alexey Milovidov 2018-03-14 07:36:41 +03:00
parent c6aac7c03c
commit 688d277ad4
7 changed files with 223 additions and 279 deletions

View File

@ -1,148 +0,0 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionIntersectionsMax.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
namespace DB
{
template <typename T>
typename Intersections<T>::PointsMap::iterator Intersections<T>::insert_point(const T & v)
{
auto res = points.emplace(v, 0);
auto & i = res.first;
if (!res.second)
return i;
if (i == points.begin())
return i;
auto prev = i;
prev--;
i->second = prev->second;
return i;
}
template <typename T>
void Intersections<T>::add(const T & start, const T & end, T weight)
{
auto sp = insert_point(start);
auto ep = end ? insert_point(end) : points.end();
do
{
sp->second += weight;
if (sp->second > max_weight)
{
max_weight = sp->second;
max_weight_pos = sp->first;
}
} while (++sp != ep);
}
template <typename T>
void Intersections<T>::merge(const Intersections & other)
{
if (other.points.empty())
return;
typename PointsMap::const_iterator prev, i = other.points.begin();
prev = i;
i++;
while (i != other.points.end())
{
add(prev->first, i->first, prev->second);
prev = i;
i++;
}
if (prev != other.points.end())
add(prev->first, 0, prev->second);
}
template <typename T>
void Intersections<T>::serialize(WriteBuffer & buf) const
{
writeBinary(points.size(), buf);
for (const auto & p : points)
{
writeBinary(p.first, buf);
writeBinary(p.second, buf);
}
}
template <typename T>
void Intersections<T>::deserialize(ReadBuffer & buf)
{
std::size_t size;
T point;
T weight;
readBinary(size, buf);
for (std::size_t i = 0; i < size; ++i)
{
readBinary(point, buf);
readBinary(weight, buf);
points.emplace(point, weight);
}
}
void AggregateFunctionIntersectionsMax::_add(
AggregateDataPtr place, const IColumn & column_start, const IColumn & column_end, size_t row_num) const
{
PointType start_time, end_time;
Field tmp_start_time_field, tmp_end_time_field;
column_start.get(row_num, tmp_start_time_field);
if (tmp_start_time_field.isNull())
return;
start_time = tmp_start_time_field.template get<PointType>();
if (0 == start_time)
return;
column_end.get(row_num, tmp_end_time_field);
if (tmp_end_time_field.isNull())
{
end_time = 0;
}
else
{
end_time = tmp_end_time_field.template get<PointType>();
if (0 != end_time)
{
if (end_time == start_time)
{
end_time = 0;
}
else if (end_time < start_time)
{
return;
}
}
}
data(place).add(start_time, end_time);
}
namespace
{
AggregateFunctionPtr createAggregateFunctionIntersectionsMax(
const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertBinary(name, argument_types);
return std::make_shared<AggregateFunctionIntersectionsMax>(argument_types, parameters, false);
}
AggregateFunctionPtr createAggregateFunctionIntersectionsMaxPos(
const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertBinary(name, argument_types);
return std::make_shared<AggregateFunctionIntersectionsMax>(argument_types, parameters, true);
}
}
void registerAggregateFunctionIntersectionsMax(AggregateFunctionFactory & factory)
{
factory.registerFunction("intersectionsMax", createAggregateFunctionIntersectionsMax, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("intersectionsMaxPos", createAggregateFunctionIntersectionsMaxPos, AggregateFunctionFactory::CaseInsensitive);
}
}

View File

@ -1,129 +0,0 @@
#pragma once
#include <common/logger_useful.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnNullable.h>
#include <Common/Allocator.h>
#include <AggregateFunctions/IAggregateFunction.h>
namespace DB
{
namespace ErrorCodes
{
extern const int AGGREGATE_FUNCTION_DOESNT_ALLOW_PARAMETERS;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
template <typename T>
class Intersections final
{
using PointsMap = std::map<T, T>;
PointsMap points;
T max_weight;
T max_weight_pos;
typename PointsMap::iterator insert_point(const T & v);
public:
Intersections() : max_weight(0) {}
void add(const T & start, const T & end, T weight = 1);
void merge(const Intersections & other);
void serialize(WriteBuffer & buf) const;
void deserialize(ReadBuffer & buf);
T max() const
{
return max_weight;
}
T max_pos() const
{
return max_weight_pos;
}
};
class AggregateFunctionIntersectionsMax final
: public IAggregateFunctionDataHelper<Intersections<UInt64>, AggregateFunctionIntersectionsMax>
{
using PointType = UInt64;
bool return_position;
void _add(AggregateDataPtr place, const IColumn & column_start, const IColumn & column_end, size_t row_num) const;
public:
AggregateFunctionIntersectionsMax(const DataTypes & arguments, const Array & params, bool return_position)
: return_position(return_position)
{
if (!params.empty())
{
throw Exception(
"Aggregate function " + getName() + " does not allow paremeters.", ErrorCodes::AGGREGATE_FUNCTION_DOESNT_ALLOW_PARAMETERS);
}
if (arguments.size() != 2)
throw Exception("Aggregate function " + getName() + " requires two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!arguments[0]->isValueRepresentedByInteger())
throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!arguments[1]->isValueRepresentedByInteger())
throw Exception{getName() + ": second argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!arguments[0]->equals(*arguments[1]))
throw Exception{getName() + ": arguments must have the same type", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
String getName() const override
{
return "IntersectionsMax";
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
_add(place, *columns[0], *columns[1], row_num);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).merge(data(rhs));
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(place).serialize(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
this->data(place).deserialize(buf);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
auto & ret = static_cast<ColumnUInt64 &>(to).getData();
ret.push_back(data(place).max());
if (return_position)
ret.push_back(data(place).max_pos());
}
const char * getHeaderFilePath() const override
{
return __FILE__;
}
};
}

View File

@ -0,0 +1,32 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionMaxIntersections.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
namespace DB
{
namespace
{
AggregateFunctionPtr createAggregateFunctionMaxIntersections(
AggregateFunctionIntersectionsKind kind,
const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertBinary(name, argument_types);
assertNoParameters(name, parameters);
return AggregateFunctionPtr{createWithNumericType<AggregateFunctionIntersectionsMax>(*argument_types[0], kind, argument_types)};
}
}
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory & factory)
{
factory.registerFunction("maxIntersections", [](const std::string & name, const DataTypes & argument_types, const Array & parameters)
{ return createAggregateFunctionMaxIntersections(AggregateFunctionIntersectionsKind::Count, name, argument_types, parameters); });
factory.registerFunction("maxIntersectionsPosition", [](const std::string & name, const DataTypes & argument_types, const Array & parameters)
{ return createAggregateFunctionMaxIntersections(AggregateFunctionIntersectionsKind::Position, name, argument_types, parameters); });
}
}

View File

@ -0,0 +1,167 @@
#pragma once
#include <common/logger_useful.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Common/ArenaAllocator.h>
#include <AggregateFunctions/IAggregateFunction.h>
#define AGGREGATE_FUNCTION_MAX_INTERSECTIONS_MAX_ARRAY_SIZE 0xFFFFFF
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int TOO_LARGE_ARRAY_SIZE;
}
/** maxIntersections: returns maximum count of the intersected intervals defined by start_column and end_column values,
* maxIntersectionsPosition: returns leftmost position of maximum intersection of intervals.
*/
/// Similar to GroupArrayNumericData.
template <typename T>
struct MaxIntersectionsData
{
/// Left or right end of the interval and signed weight; with positive sign for begin of interval and negative sign for end of interval.
using Value = std::pair<T, Int64>;
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
using Allocator = MixedArenaAllocator<4096>;
using Array = PODArray<Value, 32, Allocator>;
Array value;
};
enum class AggregateFunctionIntersectionsKind
{
Count,
Position
};
template <typename PointType>
class AggregateFunctionIntersectionsMax final
: public IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>
{
private:
AggregateFunctionIntersectionsKind kind;
public:
AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
: kind(kind_)
{
if (!arguments[0]->isNumber())
throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!arguments[1]->isNumber())
throw Exception{getName() + ": second argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!arguments[0]->equals(*arguments[1]))
throw Exception{getName() + ": arguments must have the same type", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
String getName() const override
{
return kind == AggregateFunctionIntersectionsKind::Count
? "maxIntersections"
: "maxIntersectionsPosition";
}
DataTypePtr getReturnType() const override
{
if (kind == AggregateFunctionIntersectionsKind::Count)
return std::make_shared<DataTypeUInt64>();
else
return std::make_shared<DataTypeNumber<PointType>>();
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
PointType left = static_cast<const ColumnVector<PointType> &>(*columns[0]).getData()[row_num];
PointType right = static_cast<const ColumnVector<PointType> &>(*columns[1]).getData()[row_num];
this->data(place).value.push_back(std::make_pair(left, Int64(1)), arena);
this->data(place).value.push_back(std::make_pair(right, Int64(-1)), arena);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
auto & cur_elems = this->data(place);
auto & rhs_elems = this->data(rhs);
cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.end(), arena);
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
const auto & value = this->data(place).value;
size_t size = value.size();
writeVarUInt(size, buf);
buf.write(reinterpret_cast<const char *>(&value[0]), size * sizeof(value[0]));
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
{
size_t size = 0;
readVarUInt(size, buf);
if (unlikely(size > AGGREGATE_FUNCTION_MAX_INTERSECTIONS_MAX_ARRAY_SIZE))
throw Exception("Too large array size", ErrorCodes::TOO_LARGE_ARRAY_SIZE);
auto & value = this->data(place).value;
value.resize(size, arena);
buf.read(reinterpret_cast<char *>(&value[0]), size * sizeof(value[0]));
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
Int64 current_intersections = 0;
Int64 max_intersections = 0;
PointType position_of_max_intersections = 0;
/// const_cast because we will sort the array
auto & array = const_cast<typename MaxIntersectionsData<PointType>::Array &>(this->data(place).value);
/// TODO NaNs?
std::sort(array.begin(), array.end(), [](const auto & a, const auto & b) { return a.first < b.first; });
for (const auto & point_weight : array)
{
current_intersections += point_weight.second;
if (current_intersections > max_intersections)
{
max_intersections = current_intersections;
position_of_max_intersections = point_weight.first;
}
}
if (kind == AggregateFunctionIntersectionsKind::Count)
{
auto & result_column = static_cast<ColumnUInt64 &>(to).getData();
result_column.push_back(max_intersections);
}
else
{
auto & result_column = static_cast<ColumnVector<PointType> &>(to).getData();
result_column.push_back(position_of_max_intersections);
}
}
const char * getHeaderFilePath() const override
{
return __FILE__;
}
};
}

View File

@ -23,7 +23,7 @@ void registerAggregateFunctionsUniq(AggregateFunctionFactory &);
void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory &); void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory &);
void registerAggregateFunctionTopK(AggregateFunctionFactory &); void registerAggregateFunctionTopK(AggregateFunctionFactory &);
void registerAggregateFunctionsBitwise(AggregateFunctionFactory &); void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
void registerAggregateFunctionIntersectionsMax(AggregateFunctionFactory &); void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &); void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &); void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
@ -54,7 +54,7 @@ void registerAggregateFunctions()
registerAggregateFunctionUniqUpTo(factory); registerAggregateFunctionUniqUpTo(factory);
registerAggregateFunctionTopK(factory); registerAggregateFunctionTopK(factory);
registerAggregateFunctionsBitwise(factory); registerAggregateFunctionsBitwise(factory);
registerAggregateFunctionIntersectionsMax(factory); registerAggregateFunctionsMaxIntersections(factory);
} }
{ {

View File

@ -0,0 +1,20 @@
DROP TABLE IF EXISTS test.test;
CREATE TABLE test.test(start Integer, end Integer) engine = Memory;
INSERT INTO test.test(start,end) VALUES (1,3),(2,7),(3,999),(4,7),(5,8);
/*
1 2 3 4 5 6 7 8 9
------------------>
1---3
2---------7
3-------------
4-----7
5-----8
------------------>
1 2 3 3 4 4 4 2 1 //intersections count for each point
*/
SELECT maxIntersections(start,end) FROM test.test;
SELECT maxIntersectionsPosition(start,end) FROM test.test;
DROP TABLE test.test;