// ClickHouse/dbms/include/DB/AggregateFunctions/AggregateFunctionQuantileExactWeighted.h
//
// 302 lines
// 8.0 KiB
// C++
// Raw Normal View History
//
#pragma once

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include <DB/Common/HashTable/HashMap.h>
#include <DB/DataTypes/DataTypesNumber.h>
#include <DB/DataTypes/DataTypeArray.h>
#include <DB/AggregateFunctions/IBinaryAggregateFunction.h>
#include <DB/AggregateFunctions/QuantilesCommon.h>
#include <DB/Columns/ColumnArray.h>
namespace DB
{
2017-03-09 00:56:38 +00:00
/** The state is a hash table of the form: value -> how many times it happened.
*/
template <typename T>
struct AggregateFunctionQuantileExactWeightedData
{
using Key = T;
using Weight = UInt64;
2017-03-09 00:56:38 +00:00
/// When creating, the hash table must be small.
using Map = HashMap<
Key, Weight,
HashCRC32<Key>,
HashTableGrower<4>,
HashTableAllocatorWithStackMemory<sizeof(std::pair<Key, Weight>) * (1 << 3)>
>;
Map map;
};
2017-03-09 00:56:38 +00:00
/** Exactly calculates a quantile over a set of values, for each of which a weight is given - how many times the value was encountered.
* You can consider a set of pairs `values, weight` - as a set of histograms,
* where value is the value rounded to the middle of the column, and weight is the height of the column.
* The argument type can only be a numeric type (including date and date-time).
* The result type is the same as the argument type.
*/
template <typename ValueType, typename WeightType>
class AggregateFunctionQuantileExactWeighted final
: public IBinaryAggregateFunction<
AggregateFunctionQuantileExactWeightedData<ValueType>,
AggregateFunctionQuantileExactWeighted<ValueType, WeightType>>
{
private:
double level;
DataTypePtr type;
public:
AggregateFunctionQuantileExactWeighted(double level_ = 0.5) : level(level_) {}
String getName() const override { return "quantileExactWeighted"; }
DataTypePtr getReturnType() const override
{
return type;
}
void setArgumentsImpl(const DataTypes & arguments)
{
type = arguments[0];
}
void setParameters(const Array & params) override
{
if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
Squashed commit of the following: commit c567d4e1fe8d54e6363e47548f1e3927cc5ee78f Author: Alexey Milovidov <milovidov@yandex-team.ru> Date: Fri Jan 6 20:35:01 2017 +0300 Style [#METR-2944]. commit 26bf3e1228e03f46c29b13edb0e3770bd453e3f1 Author: Alexey Milovidov <milovidov@yandex-team.ru> Date: Fri Jan 6 20:33:11 2017 +0300 Miscellaneous [#METR-2944]. commit eb946f4c6fd4bb0e9e5c7fb1468d36be3dfca5a5 Author: Alexey Milovidov <milovidov@yandex-team.ru> Date: Fri Jan 6 20:30:19 2017 +0300 Miscellaneous [#METR-2944]. commit 78c867a14744b5af2db8d37caf7804fc2057ea51 Author: Alexey Milovidov <milovidov@yandex-team.ru> Date: Fri Jan 6 20:11:41 2017 +0300 Miscellaneous [#METR-2944]. commit 6604c5c83cfcedc81c8da4da026711920d5963b4 Author: Alexey Milovidov <milovidov@yandex-team.ru> Date: Fri Jan 6 19:56:15 2017 +0300 Miscellaneous [#METR-2944]. commit 23fbf05c1d4bead636458ec21b05a101b1152e33 Author: Alexey Milovidov <milovidov@yandex-team.ru> Date: Fri Jan 6 19:47:52 2017 +0300 Miscellaneous [#METR-2944]. commit 98772faf11a7d450d473f7fa84f8a9ae24f7b59b Author: Alexey Milovidov <milovidov@yandex-team.ru> Date: Fri Jan 6 19:46:05 2017 +0300 Miscellaneous [#METR-2944]. commit 3dc636ab9f9359dbeac2e8d997ae563d4ca147e2 Author: Alexey Milovidov <milovidov@yandex-team.ru> Date: Fri Jan 6 19:39:46 2017 +0300 Miscellaneous [#METR-2944]. commit 3e16aee95482f374ee3eda1a4dbe9ba5cdce02e8 Author: Alexey Milovidov <milovidov@yandex-team.ru> Date: Fri Jan 6 19:38:03 2017 +0300 Miscellaneous [#METR-2944]. commit ae7e7e90eb1f82bd0fe0f887708d08b9e7755612 Author: Alexey Milovidov <milovidov@yandex-team.ru> Date: Fri Jan 6 19:34:15 2017 +0300 Miscellaneous [#METR-2944].
2017-01-06 17:41:19 +00:00
level = applyVisitor(FieldVisitorConvertToNumber<Float64>(), params[0]);
}
void addImpl(AggregateDataPtr place, const IColumn & column_value, const IColumn & column_weight, size_t row_num, Arena *) const
{
this->data(place)
.map[static_cast<const ColumnVector<ValueType> &>(column_value).getData()[row_num]]
+= static_cast<const ColumnVector<WeightType> &>(column_weight).getData()[row_num];
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
auto & map = this->data(place).map;
const auto & rhs_map = this->data(rhs).map;
for (const auto & pair : rhs_map)
map[pair.first] += pair.second;
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(place).map.write(buf);
}
2016-09-22 23:26:08 +00:00
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
typename AggregateFunctionQuantileExactWeightedData<ValueType>::Map::Reader reader(buf);
auto & map = this->data(place).map;
while (reader.next())
{
const auto & pair = reader.get();
2016-03-12 04:01:03 +00:00
map[pair.first] = pair.second;
}
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
auto & map = this->data(place).map;
size_t size = map.size();
if (0 == size)
{
static_cast<ColumnVector<ValueType> &>(to).getData().push_back(ValueType());
return;
}
2017-03-09 04:26:17 +00:00
/// Copy the data to a temporary array to get the element you need in order.
using Pair = typename AggregateFunctionQuantileExactWeightedData<ValueType>::Map::value_type;
2015-11-16 20:33:43 +00:00
std::unique_ptr<Pair[]> array_holder(new Pair[size]);
Pair * array = array_holder.get();
size_t i = 0;
UInt64 sum_weight = 0;
for (const auto & pair : map)
{
sum_weight += pair.second;
array[i] = pair;
++i;
}
std::sort(array, array + size, [](const Pair & a, const Pair & b) { return a.first < b.first; });
2016-03-13 19:00:59 +00:00
UInt64 threshold = std::ceil(sum_weight * level);
UInt64 accumulated = 0;
const Pair * it = array;
const Pair * end = array + size;
while (it < end)
{
accumulated += it->second;
if (accumulated >= threshold)
break;
++it;
}
if (it == end)
--it;
static_cast<ColumnVector<ValueType> &>(to).getData().push_back(it->first);
}
};
2017-03-09 00:56:38 +00:00
/** Same, but allows you to calculate several quantiles at once.
* For this, takes as parameters several levels. Example: quantilesExactWeighted(0.5, 0.8, 0.9, 0.95)(ConnectTiming, Weight).
* Returns an array of results.
*/
template <typename ValueType, typename WeightType>
class AggregateFunctionQuantilesExactWeighted final
: public IBinaryAggregateFunction<
AggregateFunctionQuantileExactWeightedData<ValueType>,
AggregateFunctionQuantilesExactWeighted<ValueType, WeightType>>
{
private:
QuantileLevels<double> levels;
DataTypePtr type;
public:
String getName() const override { return "quantilesExactWeighted"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(type);
}
void setArgumentsImpl(const DataTypes & arguments)
{
type = arguments[0];
}
void setParameters(const Array & params) override
{
levels.set(params);
}
void addImpl(AggregateDataPtr place, const IColumn & column_value, const IColumn & column_weight, size_t row_num, Arena *) const
{
this->data(place)
.map[static_cast<const ColumnVector<ValueType> &>(column_value).getData()[row_num]]
+= static_cast<const ColumnVector<WeightType> &>(column_weight).getData()[row_num];
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
auto & map = this->data(place).map;
const auto & rhs_map = this->data(rhs).map;
for (const auto & pair : rhs_map)
map[pair.first] += pair.second;
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(place).map.write(buf);
}
2016-09-22 23:26:08 +00:00
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
typename AggregateFunctionQuantileExactWeightedData<ValueType>::Map::Reader reader(buf);
auto & map = this->data(place).map;
while (reader.next())
{
const auto & pair = reader.get();
2016-03-12 04:01:03 +00:00
map[pair.first] = pair.second;
}
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
auto & map = this->data(place).map;
size_t size = map.size();
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets();
size_t num_levels = levels.size();
offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + num_levels);
if (!num_levels)
return;
typename ColumnVector<ValueType>::Container_t & data_to = static_cast<ColumnVector<ValueType> &>(arr_to.getData()).getData();
size_t old_size = data_to.size();
data_to.resize(old_size + num_levels);
if (0 == size)
{
for (size_t i = 0; i < num_levels; ++i)
data_to[old_size + i] = ValueType();
return;
}
2017-03-09 04:26:17 +00:00
/// Copy the data to a temporary array to get the element you need in order.
using Pair = typename AggregateFunctionQuantileExactWeightedData<ValueType>::Map::value_type;
2015-11-16 23:49:18 +00:00
std::unique_ptr<Pair[]> array_holder(new Pair[size]);
Pair * array = array_holder.get();
size_t i = 0;
UInt64 sum_weight = 0;
for (const auto & pair : map)
{
sum_weight += pair.second;
array[i] = pair;
++i;
}
std::sort(array, array + size, [](const Pair & a, const Pair & b) { return a.first < b.first; });
UInt64 accumulated = 0;
const Pair * it = array;
const Pair * end = array + size;
size_t level_index = 0;
2016-03-13 19:00:59 +00:00
UInt64 threshold = std::ceil(sum_weight * levels.levels[levels.permutation[level_index]]);
while (it < end)
{
accumulated += it->second;
while (accumulated >= threshold)
{
2016-03-13 18:15:41 +00:00
data_to[old_size + levels.permutation[level_index]] = it->first;
++level_index;
if (level_index == num_levels)
return;
2016-03-13 19:00:59 +00:00
threshold = std::ceil(sum_weight * levels.levels[levels.permutation[level_index]]);
}
++it;
}
while (level_index < num_levels)
{
2016-03-13 18:15:41 +00:00
data_to[old_size + levels.permutation[level_index]] = array[size - 1].first;
++level_index;
}
}
};
}