Merging abandoned pull request with "boundingRatio" function #3139

This commit is contained in:
Alexey Milovidov 2018-12-20 18:14:32 +03:00
parent 113ff56384
commit 731d76821d
3 changed files with 99 additions and 71 deletions

View File

@ -1,18 +1,10 @@
#pragma once
#include <iostream>
#include <sstream>
#include <unordered_set>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeNumberBase.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include <Common/FieldVisitors.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Common/ArenaAllocator.h>
#include <Common/typeid_cast.h>
#include <ext/range.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/IAggregateFunction.h>
@ -20,61 +12,80 @@
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
/** Tracks the leftmost and rightmost (x, y) data points.
*/
struct AggregateFunctionBoundingRatioData
{
using TimestampEvent = std::pair<UInt32, Float64>;
bool is_first = false;
TimestampEvent first_event;
TimestampEvent last_event;
void add(UInt32 timestamp, Float64 f)
struct Point
{
if (is_first)
Float64 x;
Float64 y;
};
bool empty = true;
Point left;
Point right;
void add(Float64 x, Float64 y)
{
Point point{x, y};
if (empty)
{
first_event = TimestampEvent{timestamp, f};
is_first = true;
left = point;
right = point;
empty = false;
}
else
else if (point.x < left.x)
{
last_event = TimestampEvent{timestamp, f};
left = point;
}
else if (point.x > right.x)
{
right = point;
}
}
void merge(const AggregateFunctionBoundingRatioData & other)
{
// if the arg is earlier than us, replace us with them
if (other.first_event.first < first_event.first)
if (empty)
{
first_event = other.first_event;
*this = other;
}
// if the arg is _later_ than us, replace us with them
if (other.last_event.first > last_event.second)
else
{
last_event = other.last_event;
if (other.left.x < left.x)
left = other.left;
if (other.right.x > right.x)
right = other.right;
}
}
void serialize(WriteBuffer & buf) const
{
writeBinary(is_first, buf);
writeBinary(first_event.first, buf);
writeBinary(first_event.second, buf);
writeBinary(empty, buf);
writeBinary(last_event.first, buf);
writeBinary(last_event.second, buf);
if (!empty)
{
writePODBinary(left, buf);
writePODBinary(right, buf);
}
}
void deserialize(ReadBuffer & buf)
{
readBinary(is_first, buf);
readBinary(empty, buf);
readBinary(first_event.first, buf);
readBinary(first_event.second, buf);
readBinary(last_event.first, buf);
readBinary(last_event.second, buf);
if (!empty)
{
readPODBinary(left, buf);
readPODBinary(right, buf);
}
}
};
@ -82,21 +93,15 @@ struct AggregateFunctionBoundingRatioData
class AggregateFunctionBoundingRatio final : public IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>
{
private:
/* implements a basic derivative function
*
* (y2 - y1) / (x2 - x1)
*/
/** Calculates the slope of a line between leftmost and rightmost data points.
* (y2 - y1) / (x2 - x1)
*/
Float64 getBoundingRatio(const AggregateFunctionBoundingRatioData & data) const
{
if (data.first_event.first == 0)
return 0;
if (data.last_event.first == 0)
return 0;
// void divide by zero in denominator
if (data.last_event.first == data.first_event.first)
return 0;
if (data.empty)
return std::numeric_limits<Float64>::quiet_NaN();
return (data.last_event.second - data.first_event.second) / (data.last_event.first - data.first_event.first);
return (data.right.y - data.left.y) / (data.right.x - data.left.x);
}
public:
@ -107,21 +112,14 @@ public:
AggregateFunctionBoundingRatio(const DataTypes & arguments)
{
const auto x_arg = arguments.at(0).get();
const auto y_arg = arguments.at(0).get();
const auto time_arg = arguments.at(0).get();
if (!typeid_cast<const DataTypeDateTime *>(time_arg) && !typeid_cast<const DataTypeUInt32 *>(time_arg))
throw Exception {"Illegal type " + time_arg->getName() + " of first argument of aggregate function " + getName()
+ ", must be DateTime or UInt32"};
const auto number_arg = arguments.at(1).get();
if (!number_arg->isNumber())
throw Exception {"Illegal type " + number_arg->getName() + " of argument " + toString(1) + " of aggregate function " + getName()
+ ", must be a Number",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!x_arg->isValueRepresentedByNumber() || !y_arg->isValueRepresentedByNumber())
throw Exception("Illegal types of arguments of aggregate function " + getName() + ", must have number representation.",
ErrorCodes::BAD_ARGUMENTS);
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeFloat64>();
@ -129,9 +127,10 @@ public:
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
{
const auto timestamp = static_cast<const ColumnVector<UInt32> *>(columns[0])->getData()[row_num];
const auto value = static_cast<const ColumnVector<Float64> *>(columns[1])->getData()[row_num];
data(place).add(timestamp, value);
/// TODO Inefficient.
const auto x = applyVisitor(FieldVisitorConvertToNumber<Float64>(), (*columns[0])[row_num]);
const auto y = applyVisitor(FieldVisitorConvertToNumber<Float64>(), (*columns[1])[row_num]);
data(place).add(x, y);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override

View File

@ -1,2 +1,21 @@
1
1
1.5
1.5
1.5
0 1.5
1 1.5
2 1.5
3 1.5
4 1.5
5 1.5
6 1.5
7 1.5
8 1.5
9 1.5
0 1.5
1.5
nan
nan
1

View File

@ -3,14 +3,24 @@ drop table if exists rate_test;
create table rate_test (timestamp UInt32, event UInt32) engine=Memory;
insert into rate_test values (0,1000),(1,1001),(2,1002),(3,1003),(4,1004),(5,1005),(6,1006),(7,1007),(8,1008);
select 1.0 = rate(timestamp, event) from rate_test;
select 1.0 = boundingRatio(timestamp, event) from rate_test;
drop table if exists rate_test2;
create table rate_test2 (uid UInt32 default 1,timestamp DateTime, event UInt32) engine=Memory;
insert into rate_test2(timestamp, event) values ('2018-01-01 01:01:01',1001),('2018-01-01 01:01:02',1002),('2018-01-01 01:01:03',1003),('2018-01-01 01:01:04',1004),('2018-01-01 01:01:05',1005),('2018-01-01 01:01:06',1006),('2018-01-01 01:01:07',1007),('2018-01-01 01:01:08',1008);
insert into rate_test2(timestamp, event) values ('2018-01-01 01:01:01',1001),('2018-01-01 01:01:02',1002),('2018-01-01 01:01:03',1003),('2018-01-01 01:01:04',1004),('2018-01-01 01:01:05',1005),('2018-01-01 01:01:06',1006),('2018-01-01 01:01:07',1007),('2018-01-01 01:01:08',1008);
select 1.0 = rate(timestamp, event ) from rate_test2;
select 1.0 = boundingRatio(timestamp, event) from rate_test2;
drop table rate_test;
drop table rate_test2;
SELECT boundingRatio(number, number * 1.5) FROM numbers(10);
SELECT boundingRatio(1000 + number, number * 1.5) FROM numbers(10);
SELECT boundingRatio(1000 + number, number * 1.5 - 111) FROM numbers(10);
SELECT number % 10 AS k, boundingRatio(1000 + number, number * 1.5 - 111) FROM numbers(100) GROUP BY k WITH TOTALS ORDER BY k;
SELECT boundingRatio(1000 + number, number * 1.5 - 111) FROM numbers(2);
SELECT boundingRatio(1000 + number, number * 1.5 - 111) FROM numbers(1);
SELECT boundingRatio(1000 + number, number * 1.5 - 111) FROM numbers(1) WHERE 0;
SELECT boundingRatio(number, exp(number)) = e() - 1 FROM numbers(2);