From 688d277ad4bcd3826f5ceb40ca32c1ca9c894132 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Wed, 14 Mar 2018 07:36:41 +0300 Subject: [PATCH] Adaptation #2012 --- .../AggregateFunctionIntersectionsMax.cpp | 148 ---------------- .../AggregateFunctionIntersectionsMax.h | 129 -------------- .../AggregateFunctionMaxIntersections.cpp | 32 ++++ .../AggregateFunctionMaxIntersections.h | 167 ++++++++++++++++++ .../registerAggregateFunctions.cpp | 4 +- ...ntersections_aggregate_functions.reference | 2 + ...0605_intersections_aggregate_functions.sql | 20 +++ 7 files changed, 223 insertions(+), 279 deletions(-) delete mode 100644 dbms/src/AggregateFunctions/AggregateFunctionIntersectionsMax.cpp delete mode 100644 dbms/src/AggregateFunctions/AggregateFunctionIntersectionsMax.h create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.cpp create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h create mode 100644 dbms/tests/queries/0_stateless/00605_intersections_aggregate_functions.reference create mode 100644 dbms/tests/queries/0_stateless/00605_intersections_aggregate_functions.sql diff --git a/dbms/src/AggregateFunctions/AggregateFunctionIntersectionsMax.cpp b/dbms/src/AggregateFunctions/AggregateFunctionIntersectionsMax.cpp deleted file mode 100644 index e9e6c739422..00000000000 --- a/dbms/src/AggregateFunctions/AggregateFunctionIntersectionsMax.cpp +++ /dev/null @@ -1,148 +0,0 @@ -#include -#include -#include - -#include -#include - -namespace DB -{ -template -typename Intersections::PointsMap::iterator Intersections::insert_point(const T & v) -{ - auto res = points.emplace(v, 0); - auto & i = res.first; - if (!res.second) - return i; - if (i == points.begin()) - return i; - auto prev = i; - prev--; - i->second = prev->second; - return i; -} - -template -void Intersections::add(const T & start, const T & end, T weight) -{ - auto sp = insert_point(start); - auto ep = end ? 
insert_point(end) : points.end(); - do - { - sp->second += weight; - if (sp->second > max_weight) - { - max_weight = sp->second; - max_weight_pos = sp->first; - } - } while (++sp != ep); -} - -template -void Intersections::merge(const Intersections & other) -{ - if (other.points.empty()) - return; - - typename PointsMap::const_iterator prev, i = other.points.begin(); - prev = i; - i++; - - while (i != other.points.end()) - { - add(prev->first, i->first, prev->second); - prev = i; - i++; - } - - if (prev != other.points.end()) - add(prev->first, 0, prev->second); -} - -template -void Intersections::serialize(WriteBuffer & buf) const -{ - writeBinary(points.size(), buf); - for (const auto & p : points) - { - writeBinary(p.first, buf); - writeBinary(p.second, buf); - } -} - -template -void Intersections::deserialize(ReadBuffer & buf) -{ - std::size_t size; - T point; - T weight; - - readBinary(size, buf); - for (std::size_t i = 0; i < size; ++i) - { - readBinary(point, buf); - readBinary(weight, buf); - points.emplace(point, weight); - } -} - -void AggregateFunctionIntersectionsMax::_add( - AggregateDataPtr place, const IColumn & column_start, const IColumn & column_end, size_t row_num) const -{ - PointType start_time, end_time; - Field tmp_start_time_field, tmp_end_time_field; - - column_start.get(row_num, tmp_start_time_field); - if (tmp_start_time_field.isNull()) - return; - start_time = tmp_start_time_field.template get(); - if (0 == start_time) - return; - - column_end.get(row_num, tmp_end_time_field); - if (tmp_end_time_field.isNull()) - { - end_time = 0; - } - else - { - end_time = tmp_end_time_field.template get(); - if (0 != end_time) - { - if (end_time == start_time) - { - end_time = 0; - } - else if (end_time < start_time) - { - return; - } - } - } - - data(place).add(start_time, end_time); -} - -namespace -{ - AggregateFunctionPtr createAggregateFunctionIntersectionsMax( - const std::string & name, const DataTypes & argument_types, const Array & 
parameters) - { - assertBinary(name, argument_types); - return std::make_shared(argument_types, parameters, false); - } - - AggregateFunctionPtr createAggregateFunctionIntersectionsMaxPos( - const std::string & name, const DataTypes & argument_types, const Array & parameters) - { - assertBinary(name, argument_types); - return std::make_shared(argument_types, parameters, true); - } -} - -void registerAggregateFunctionIntersectionsMax(AggregateFunctionFactory & factory) -{ - factory.registerFunction("intersectionsMax", createAggregateFunctionIntersectionsMax, AggregateFunctionFactory::CaseInsensitive); - factory.registerFunction("intersectionsMaxPos", createAggregateFunctionIntersectionsMaxPos, AggregateFunctionFactory::CaseInsensitive); -} -} diff --git a/dbms/src/AggregateFunctions/AggregateFunctionIntersectionsMax.h b/dbms/src/AggregateFunctions/AggregateFunctionIntersectionsMax.h deleted file mode 100644 index c15ae1ae33d..00000000000 --- a/dbms/src/AggregateFunctions/AggregateFunctionIntersectionsMax.h +++ /dev/null @@ -1,129 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include - -#include - -#include - -#include - -namespace DB -{ -namespace ErrorCodes -{ - extern const int AGGREGATE_FUNCTION_DOESNT_ALLOW_PARAMETERS; - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; -} - -template -class Intersections final -{ - using PointsMap = std::map; - - PointsMap points; - T max_weight; - T max_weight_pos; - - typename PointsMap::iterator insert_point(const T & v); - -public: - Intersections() : max_weight(0) {} - - void add(const T & start, const T & end, T weight = 1); - void merge(const Intersections & other); - - void serialize(WriteBuffer & buf) const; - void deserialize(ReadBuffer & buf); - - T max() const - { - return max_weight; - } - T max_pos() const - { - return max_weight_pos; - } -}; - -class AggregateFunctionIntersectionsMax final - : public IAggregateFunctionDataHelper, 
AggregateFunctionIntersectionsMax> -{ - using PointType = UInt64; - - bool return_position; - void _add(AggregateDataPtr place, const IColumn & column_start, const IColumn & column_end, size_t row_num) const; - -public: - AggregateFunctionIntersectionsMax(const DataTypes & arguments, const Array & params, bool return_position) - : return_position(return_position) - { - if (!params.empty()) - { - throw Exception( - "Aggregate function " + getName() + " does not allow paremeters.", ErrorCodes::AGGREGATE_FUNCTION_DOESNT_ALLOW_PARAMETERS); - } - - if (arguments.size() != 2) - throw Exception("Aggregate function " + getName() + " requires two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - - if (!arguments[0]->isValueRepresentedByInteger()) - throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; - - if (!arguments[1]->isValueRepresentedByInteger()) - throw Exception{getName() + ": second argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; - - if (!arguments[0]->equals(*arguments[1])) - throw Exception{getName() + ": arguments must have the same type", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; - } - - String getName() const override - { - return "IntersectionsMax"; - } - - DataTypePtr getReturnType() const override - { - return std::make_shared(); - } - - void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override - { - _add(place, *columns[0], *columns[1], row_num); - } - - void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override - { - this->data(place).merge(data(rhs)); - } - - void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override - { - this->data(place).serialize(buf); - } - - void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override - { - this->data(place).deserialize(buf); - } - - void insertResultInto(ConstAggregateDataPtr place, IColumn & to) 
const override
-    {
-        auto & ret = static_cast(to).getData();
-        ret.push_back(data(place).max());
-        if (return_position)
-            ret.push_back(data(place).max_pos());
-    }
-
-    const char * getHeaderFilePath() const override
-    {
-        return __FILE__;
-    }
-};
-}
diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.cpp b/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.cpp
new file mode 100644
index 00000000000..e81dfe98259
--- /dev/null
+++ b/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.cpp
@@ -0,0 +1,44 @@
+/// NOTE(review): include targets were lost in extraction; restored from the names used in this file - confirm.
+#include <AggregateFunctions/AggregateFunctionFactory.h>
+#include <AggregateFunctions/AggregateFunctionMaxIntersections.h>
+#include <AggregateFunctions/FactoryHelpers.h>
+#include <AggregateFunctions/Helpers.h>
+
+
+namespace DB
+{
+
+namespace
+{
+    /// Creates maxIntersections or maxIntersectionsPosition (chosen by `kind`), instantiated for the numeric type
+    /// of the first argument. Requires exactly two arguments and no parameters.
+    AggregateFunctionPtr createAggregateFunctionMaxIntersections(
+        AggregateFunctionIntersectionsKind kind,
+        const std::string & name, const DataTypes & argument_types, const Array & parameters)
+    {
+        assertBinary(name, argument_types);
+        assertNoParameters(name, parameters);
+
+        AggregateFunctionPtr res{createWithNumericType<AggregateFunctionIntersectionsMax>(*argument_types[0], kind, argument_types)};
+
+        /// createWithNumericType presumably yields nullptr for a type it does not dispatch on (see Helpers.h - confirm);
+        /// report that here instead of handing a null function to the factory.
+        if (!res)
+            throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName()
+                + " of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
+
+        return res;
+    }
+}
+
+/// Registers maxIntersections and maxIntersectionsPosition; both share one creator and differ only in `kind`.
+void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory & factory)
+{
+    factory.registerFunction("maxIntersections", [](const std::string & name, const DataTypes & argument_types, const Array & parameters)
+        { return createAggregateFunctionMaxIntersections(AggregateFunctionIntersectionsKind::Count, name, argument_types, parameters); });
+
+    factory.registerFunction("maxIntersectionsPosition", [](const std::string & name, const DataTypes & argument_types, const Array & parameters)
+        { return createAggregateFunctionMaxIntersections(AggregateFunctionIntersectionsKind::Position, name, argument_types, parameters); });
+}
+
+}
diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h b/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h
new file mode 100644
index 00000000000..ebc283c03fc
--- /dev/null
+++ b/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h
@@ -0,0 +1,191 @@
+#pragma once
+
+/// NOTE(review): include targets were lost in extraction; restored from the names used in this file - confirm.
+#include <common/logger_useful.h>
+
+#include <DataTypes/DataTypesNumber.h>
+#include <Columns/ColumnsNumber.h>
+
+#include <IO/ReadHelpers.h>
+#include <IO/WriteHelpers.h>
+
+#include <Common/ArenaAllocator.h>
+
+#include <AggregateFunctions/IAggregateFunction.h>
+
+/// Largest number of points that deserialize() accepts in one state.
+#define AGGREGATE_FUNCTION_MAX_INTERSECTIONS_MAX_ARRAY_SIZE 0xFFFFFF
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+    extern const int TOO_LARGE_ARRAY_SIZE;
+}
+
+
+/** maxIntersections: returns maximum count of the intersected intervals defined by start_column and end_column values,
+  * maxIntersectionsPosition: returns leftmost position of maximum intersection of intervals.
+  *
+  * Each row contributes two weighted points: (start, +1) and (end, -1). The result is computed by sorting all points
+  * and scanning a running sum of the weights. At one position every start is counted before any end, so intervals
+  * that merely touch (one ends where another begins) are considered intersecting at that point.
+  */
+
+/// Similar to GroupArrayNumericData.
+template <typename T>
+struct MaxIntersectionsData
+{
+    /// Left or right end of the interval and signed weight; with positive sign for begin of interval and negative sign for end of interval.
+    using Value = std::pair<T, Int64>;
+
+    // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
+    using Allocator = MixedArenaAllocator<4096>;
+    /// NOTE(review): template arguments restored by analogy with GroupArrayNumericData - confirm the initial size.
+    using Array = PODArray<Value, 32, Allocator>;
+
+    Array value;
+};
+
+/// Selects which of the two results the function returns.
+enum class AggregateFunctionIntersectionsKind
+{
+    Count,      /// maxIntersections: the maximum of the running sum of weights.
+    Position    /// maxIntersectionsPosition: the first position at which that maximum is reached.
+};
+
+/// Aggregate function over a pair of columns (interval start, interval end) that both hold PointType values.
+template <typename PointType>
+class AggregateFunctionIntersectionsMax final
+    : public IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>
+{
+private:
+    AggregateFunctionIntersectionsKind kind;
+
+public:
+    /// Throws ILLEGAL_TYPE_OF_ARGUMENT unless both arguments are numbers of the same type.
+    AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
+        : kind(kind_)
+    {
+        if (!arguments[0]->isNumber())
+            throw Exception{getName() + ": first argument must be a number", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
+
+        if (!arguments[1]->isNumber())
+            throw Exception{getName() + ": second argument must be a number", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
+
+        if (!arguments[0]->equals(*arguments[1]))
+            throw Exception{getName() + ": arguments must have the same type", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
+    }
+
+    String getName() const override
+    {
+        return kind == AggregateFunctionIntersectionsKind::Count
+            ? "maxIntersections"
+            : "maxIntersectionsPosition";
+    }
+
+    /// UInt64 for the count; the type of the arguments for the position.
+    DataTypePtr getReturnType() const override
+    {
+        if (kind == AggregateFunctionIntersectionsKind::Count)
+            return std::make_shared<DataTypeUInt64>();
+        else
+            return std::make_shared<DataTypeNumber<PointType>>();
+    }
+
+    /// Appends both ends of the row's interval to the state; nothing is sorted or validated at this stage.
+    void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
+    {
+        PointType left = static_cast<const ColumnVector<PointType> &>(*columns[0]).getData()[row_num];
+        PointType right = static_cast<const ColumnVector<PointType> &>(*columns[1]).getData()[row_num];
+
+        this->data(place).value.push_back(std::make_pair(left, Int64(1)), arena);
+        this->data(place).value.push_back(std::make_pair(right, Int64(-1)), arena);
+    }
+
+    /// Adds all points of rhs to the points of place.
+    void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
+    {
+        auto & cur_elems = this->data(place);
+        auto & rhs_elems = this->data(rhs);
+
+        cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.end(), arena);
+    }
+
+    /// Format: VarUInt number of points, followed by the raw bytes of the points array.
+    void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
+    {
+        const auto & value = this->data(place).value;
+        size_t size = value.size();
+        writeVarUInt(size, buf);
+        buf.write(reinterpret_cast<const char *>(&value[0]), size * sizeof(value[0]));
+    }
+
+    /// Reads the format written by serialize(); throws TOO_LARGE_ARRAY_SIZE when the stored size exceeds the limit.
+    void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
+    {
+        size_t size = 0;
+        readVarUInt(size, buf);
+
+        if (unlikely(size > AGGREGATE_FUNCTION_MAX_INTERSECTIONS_MAX_ARRAY_SIZE))
+            throw Exception("Too large array size", ErrorCodes::TOO_LARGE_ARRAY_SIZE);
+
+        auto & value = this->data(place).value;
+
+        value.resize(size, arena);
+        /// readStrict instead of read: a truncated state must raise an error rather than be accepted partially read.
+        buf.readStrict(reinterpret_cast<char *>(&value[0]), size * sizeof(value[0]));
+    }
+
+    /// Sorts the collected points, scans the running sum of weights and appends the requested result to `to`.
+    void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
+    {
+        Int64 current_intersections = 0;
+        Int64 max_intersections = 0;
+        PointType position_of_max_intersections = 0;
+
+        /// const_cast because we will sort the array
+        auto & array = const_cast<typename MaxIntersectionsData<PointType>::Array &>(this->data(place).value);
+
+        /// Order by position; at the same position put interval starts (+1) before interval ends (-1).
+        /// Comparing positions alone leaves the order of ties unspecified (std::sort is not stable), which made
+        /// the result depend on the order in which rows were aggregated whenever intervals shared an endpoint.
+        /// TODO NaNs?
+        std::sort(array.begin(), array.end(), [](const auto & a, const auto & b)
+        {
+            return a.first < b.first || (a.first == b.first && a.second > b.second);
+        });
+
+        for (const auto & point_weight : array)
+        {
+            current_intersections += point_weight.second;
+            /// Strict comparison: a later position with the same count does not replace the leftmost one.
+            if (current_intersections > max_intersections)
+            {
+                max_intersections = current_intersections;
+                position_of_max_intersections = point_weight.first;
+            }
+        }
+
+        if (kind == AggregateFunctionIntersectionsKind::Count)
+        {
+            auto & result_column = static_cast<ColumnUInt64 &>(to).getData();
+            result_column.push_back(max_intersections);
+        }
+        else
+        {
+            auto & result_column = static_cast<ColumnVector<PointType> &>(to).getData();
+            result_column.push_back(position_of_max_intersections);
+        }
+    }
+
+    const char * getHeaderFilePath() const override
+    {
+        return __FILE__;
+    }
+};
+
+}
diff --git a/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp b/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp
index d054107c705..644543e4f2e 100644
--- a/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp
+++ b/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp
@@ -23,7 +23,7 @@ void registerAggregateFunctionsUniq(AggregateFunctionFactory &);
 void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory &);
 void registerAggregateFunctionTopK(AggregateFunctionFactory &);
 void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
-void registerAggregateFunctionIntersectionsMax(AggregateFunctionFactory &);
+void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
 
 void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
 void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
@@ -54,7 +54,7 @@ void registerAggregateFunctions()
         registerAggregateFunctionUniqUpTo(factory);
         registerAggregateFunctionTopK(factory);
         registerAggregateFunctionsBitwise(factory);
-        registerAggregateFunctionIntersectionsMax(factory);
+        registerAggregateFunctionsMaxIntersections(factory);
     }
 
     {
diff --git
a/dbms/tests/queries/0_stateless/00605_intersections_aggregate_functions.reference b/dbms/tests/queries/0_stateless/00605_intersections_aggregate_functions.reference
new file mode 100644
index 00000000000..61c83cba41c
--- /dev/null
+++ b/dbms/tests/queries/0_stateless/00605_intersections_aggregate_functions.reference
@@ -0,0 +1,2 @@
+4
+5
diff --git a/dbms/tests/queries/0_stateless/00605_intersections_aggregate_functions.sql b/dbms/tests/queries/0_stateless/00605_intersections_aggregate_functions.sql
new file mode 100644
index 00000000000..c23583dd8c8
--- /dev/null
+++ b/dbms/tests/queries/0_stateless/00605_intersections_aggregate_functions.sql
@@ -0,0 +1,20 @@
+DROP TABLE IF EXISTS test.test;
+CREATE TABLE test.test(start Integer, end Integer) engine = Memory;
+INSERT INTO test.test(start,end) VALUES (1,3),(2,7),(3,999),(4,7),(5,8);
+
+/* The inserted intervals drawn over points 1..9 (the interval 3..999 runs off the right edge):
+1 2 3 4 5 6 7 8 9
+------------------>
+1---3
+  2---------7
+    3-------------
+      4-----7
+        5-----8
+------------------>
+1 2 3 3 4 4 4 2 1  // number of intervals covering each point: the maximum is 4 and is first reached at point 5
+*/
+
+SELECT maxIntersections(start,end) FROM test.test;
+SELECT maxIntersectionsPosition(start,end) FROM test.test;
+
+DROP TABLE test.test;