Return float in segmentLengthSum for float args, add tests

This commit is contained in:
vdimir 2021-05-20 17:43:24 +03:00
parent c8cbde1b89
commit 09d63545b0
No known key found for this signature in database
GPG Key ID: F57B3E10A21DBB31
3 changed files with 39 additions and 21 deletions

View File

@ -111,14 +111,15 @@ template <typename T, typename Data>
class AggregateFunctionSegmentLengthSum final : public IAggregateFunctionDataHelper<Data, AggregateFunctionSegmentLengthSum<T, Data>>
{
private:
UInt64 getSegmentLengthSum(Data & data) const
template <typename TResult>
TResult getSegmentLengthSum(Data & data) const
{
if (data.size() == 0)
return 0;
data.sort();
UInt64 res = 0;
TResult res = 0;
typename Data::Segment cur_segment = data.segments[0];
@ -146,7 +147,12 @@ public:
{
}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt64>(); }
DataTypePtr getReturnType() const override
{
if constexpr (std::is_floating_point_v<T>)
return std::make_shared<DataTypeFloat64>();
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; }
@ -183,7 +189,10 @@ public:
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
assert_cast<ColumnUInt64 &>(to).getData().push_back(getSegmentLengthSum(this->data(place)));
if constexpr (std::is_floating_point_v<T>)
assert_cast<ColumnFloat64 &>(to).getData().push_back(getSegmentLengthSum<Float64>(this->data(place)));
else
assert_cast<ColumnUInt64 &>(to).getData().push_back(getSegmentLengthSum<UInt64>(this->data(place)));
}
};

View File

@ -1,2 +1,6 @@
a 5
b 8
a 5 UInt64
b 8 UInt64
c 3 UInt64
a 1 Float64
a 7201 UInt64
a 7 UInt64

View File

@ -1,21 +1,26 @@
DROP TABLE IF EXISTS segment;
DROP TABLE IF EXISTS fl_segment;
DROP TABLE IF EXISTS dt_segment;
DROP TABLE IF EXISTS date_segment;
CREATE TABLE segment
(
`id` String,
`start` UInt64,
`end` UInt64
)
ENGINE = MergeTree
ORDER BY start;
CREATE TABLE segment ( `id` String, `start` Int64, `end` Int64 ) ENGINE = MergeTree ORDER BY start;
INSERT INTO segment VALUES ('a', 1, 3), ('a', 1, 3), ('a', 2, 4), ('a', 1, 1), ('a', 5, 6), ('a', 5, 7), ('b', 10, 12), ('b', 13, 19), ('b', 14, 16), ('c', -1, 1), ('c', -2, -1);
INSERT INTO segment VALUES ('a', 1, 3), ('a', 2, 4), ('a', 5, 6), ('a', 5, 7), ('b', 10, 12), ('b', 13, 19), ('b', 14, 16);
CREATE TABLE fl_segment ( `id` String, `start` Float, `end` Float ) ENGINE = MergeTree ORDER BY start;
INSERT INTO fl_segment VALUES ('a', 1.1, 3.2), ('a', 1.5, 3.6), ('a', 4.0, 5.0);
SELECT
id,
segmentLengthSum(start, end)
FROM segment
GROUP BY id
ORDER BY id;
CREATE TABLE dt_segment ( `id` String, `start` DateTime, `end` DateTime ) ENGINE = MergeTree ORDER BY start;
INSERT INTO dt_segment VALUES ('a', '2020-01-01 02:11:22', '2020-01-01 03:12:31'), ('a', '2020-01-01 01:12:30', '2020-01-01 02:50:11');
CREATE TABLE date_segment ( `id` String, `start` Date, `end` Date ) ENGINE = MergeTree ORDER BY start;
INSERT INTO date_segment VALUES ('a', '2020-01-01', '2020-01-04'), ('a', '2020-01-03', '2020-01-08 02:50:11');
SELECT id, segmentLengthSum(start, end), toTypeName(segmentLengthSum(start, end)) FROM segment GROUP BY id ORDER BY id;
SELECT id, 3.4 < segmentLengthSum(start, end) AND segmentLengthSum(start, end) < 3.6, toTypeName(segmentLengthSum(start, end)) FROM fl_segment GROUP BY id ORDER BY id;
SELECT id, segmentLengthSum(start, end), toTypeName(segmentLengthSum(start, end)) FROM dt_segment GROUP BY id ORDER BY id;
SELECT id, segmentLengthSum(start, end), toTypeName(segmentLengthSum(start, end)) FROM date_segment GROUP BY id ORDER BY id;
DROP TABLE segment;
DROP TABLE fl_segment;
DROP TABLE dt_segment;
DROP TABLE date_segment;