Merge pull request #14411 from nikitamikhaylov/rank-corr

Merging #11769 (Rank Correlation Spearman)
This commit is contained in:
Nikita Mikhaylov 2020-09-07 21:18:19 +04:00 committed by GitHub
commit 5d9367aea4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 327 additions and 0 deletions

View File

@ -0,0 +1,51 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionRankCorrelation.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include "registerAggregateFunctions.h"
#include <AggregateFunctions/Helpers.h>
namespace ErrorCodes
{
extern const int NOT_IMPLEMENTED;
}
namespace DB
{
namespace
{
AggregateFunctionPtr createAggregateFunctionRankCorrelation(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertBinary(name, argument_types);
assertNoParameters(name, parameters);
AggregateFunctionPtr res;
if (isDecimal(argument_types[0]) || isDecimal(argument_types[1]))
{
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
}
else
{
res.reset(createWithTwoNumericTypes<AggregateFunctionRankCorrelation>(*argument_types[0], *argument_types[1], argument_types));
}
if (!res)
{
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
}
return res;
}
}
void registerAggregateFunctionRankCorrelation(AggregateFunctionFactory & factory)
{
factory.registerFunction("rankCorr", createAggregateFunctionRankCorrelation, AggregateFunctionFactory::CaseInsensitive);
}
}

View File

@ -0,0 +1,234 @@
#pragma once
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnTuple.h>
#include <Common/assert_cast.h>
#include <Common/FieldVisitors.h>
#include <Core/Types.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <limits>
#include <DataTypes/DataTypeArray.h>
#include <Common/ArenaAllocator.h>
#include <type_traits>
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
namespace DB
{
template <template <typename> class Comparator>
struct ComparePairFirst final
{
template <typename X, typename Y>
bool operator()(const std::pair<X, Y> & lhs, const std::pair<X, Y> & rhs) const
{
return Comparator<X>{}(lhs.first, rhs.first);
}
};
template <template <typename> class Comparator>
struct ComparePairSecond final
{
template <typename X, typename Y>
bool operator()(const std::pair<X, Y> & lhs, const std::pair<X, Y> & rhs) const
{
return Comparator<Y>{}(lhs.second, rhs.second);
}
};
template <typename X = Float64, typename Y = Float64>
struct AggregateFunctionRankCorrelationData final
{
size_t size_x = 0;
using Allocator = MixedAlignedArenaAllocator<alignof(std::pair<X, Y>), 4096>;
using Array = PODArray<std::pair<X, Y>, 32, Allocator>;
Array values;
};
template <typename X, typename Y>
class AggregateFunctionRankCorrelation :
public IAggregateFunctionDataHelper<AggregateFunctionRankCorrelationData<X, Y>, AggregateFunctionRankCorrelation<X, Y>>
{
using Data = AggregateFunctionRankCorrelationData<X, Y>;
using Allocator = MixedAlignedArenaAllocator<alignof(std::pair<Float64, Float64>), 4096>;
using Array = PODArray<std::pair<Float64, Float64>, 32, Allocator>;
public:
explicit AggregateFunctionRankCorrelation(const DataTypes & arguments)
:IAggregateFunctionDataHelper<AggregateFunctionRankCorrelationData<X, Y>,AggregateFunctionRankCorrelation<X, Y>> ({arguments}, {})
{}
String getName() const override
{
return "rankCorr";
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNumber<Float64>>();
}
void insert(Data & a, const std::pair<X, Y> & x, Arena * arena) const
{
++a.size_x;
a.values.push_back(x, arena);
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
auto & a = this->data(place);
auto new_x = assert_cast<const ColumnVector<X> &>(*columns[0]).getData()[row_num];
auto new_y = assert_cast<const ColumnVector<Y> &>(*columns[1]).getData()[row_num];
auto new_arg = std::make_pair(new_x, new_y);
a.size_x += 1;
a.values.push_back(new_arg, arena);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
auto & a = this->data(place);
auto & b = this->data(rhs);
if (b.size_x)
for (size_t i = 0; i < b.size_x; ++i)
insert(a, b.values[i], arena);
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
const auto & value = this->data(place).values;
size_t size = this->data(place).size_x;
writeVarUInt(size, buf);
buf.write(reinterpret_cast<const char *>(value.data()), size * sizeof(value[0]));
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
{
size_t size = 0;
readVarUInt(size, buf);
auto & value = this->data(place).values;
value.resize(size, arena);
buf.read(reinterpret_cast<char *>(value.data()), size * sizeof(value[0]));
}
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * /*arena*/) const override
{
const auto & value = this->data(place).values;
size_t size = this->data(place).size_x;
if (size < 2)
{
throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS);
}
//create a copy of values not to format data
PODArrayWithStackMemory<std::pair<Float64, Float64>, 32> tmp_values;
tmp_values.resize(size);
for (size_t j = 0; j < size; ++ j)
tmp_values[j] = static_cast<std::pair<Float64, Float64>>(value[j]);
//sort x_values
std::sort(std::begin(tmp_values), std::end(tmp_values), ComparePairFirst<std::greater>{});
for (size_t j = 0; j < size;)
{
//replace x_values with their ranks
size_t rank = j + 1;
size_t same = 1;
size_t cur_sum = rank;
size_t cur_start = j;
while (j < size - 1)
{
if (tmp_values[j].first == tmp_values[j + 1].first)
{
// rank of (j + 1)th number
rank += 1;
same++;
cur_sum += rank;
j++;
}
else
break;
}
// insert rank is calculated as average of ranks of equal values
Float64 insert_rank = static_cast<Float64>(cur_sum) / same;
for (size_t i = cur_start; i <= j; ++i)
tmp_values[i].first = insert_rank;
j++;
}
//sort y_values
std::sort(std::begin(tmp_values), std::end(tmp_values), ComparePairSecond<std::greater>{});
//replace y_values with their ranks
for (size_t j = 0; j < size;)
{
//replace x_values with their ranks
size_t rank = j + 1;
size_t same = 1;
size_t cur_sum = rank;
size_t cur_start = j;
while (j < size - 1)
{
if (tmp_values[j].second == tmp_values[j + 1].second)
{
// rank of (j + 1)th number
rank += 1;
same++;
cur_sum += rank;
j++;
}
else
{
break;
}
}
// insert rank is calculated as average of ranks of equal values
Float64 insert_rank = static_cast<Float64>(cur_sum) / same;
for (size_t i = cur_start; i <= j; ++i)
tmp_values[i].second = insert_rank;
j++;
}
//count d^2 sum
Float64 answer = static_cast<Float64>(0);
for (size_t j = 0; j < size; ++ j)
answer += (tmp_values[j].first - tmp_values[j].second) * (tmp_values[j].first - tmp_values[j].second);
answer *= 6;
answer /= size * (size * size - 1);
answer = 1 - answer;
auto & column = static_cast<ColumnVector<Float64> &>(to);
column.getData().push_back(answer);
}
};
};

View File

@ -45,6 +45,7 @@ void registerAggregateFunctions()
registerAggregateFunctionMoving(factory);
registerAggregateFunctionCategoricalIV(factory);
registerAggregateFunctionAggThrow(factory);
registerAggregateFunctionRankCorrelation(factory);
}
{

View File

@ -35,6 +35,7 @@ void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &)
void registerAggregateFunctionMoving(AggregateFunctionFactory &);
void registerAggregateFunctionCategoricalIV(AggregateFunctionFactory &);
void registerAggregateFunctionAggThrow(AggregateFunctionFactory &);
void registerAggregateFunctionRankCorrelation(AggregateFunctionFactory &);
class AggregateFunctionCombinatorFactory;
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);

View File

@ -32,6 +32,7 @@ SRCS(
AggregateFunctionNull.cpp
AggregateFunctionOrFill.cpp
AggregateFunctionQuantile.cpp
AggregateFunctionRankCorrelation.cpp
AggregateFunctionResample.cpp
AggregateFunctionRetention.cpp
AggregateFunctionSequenceMatch.cpp

View File

@ -0,0 +1,10 @@
1
1
-1
-1
-0.037
-0.037
-0.108
-0.108
0.286
0.286

View File

@ -0,0 +1,29 @@
CREATE DATABASE IF NOT EXISTS db_01455_rank_correlation;
USE db_01455_rank_correlation;
DROP TABLE IF EXISTS moons;
DROP TABLE IF EXISTS circles;
SELECT '1';
SELECT rankCorr(number, number) FROM numbers(100);
SELECT '-1';
SELECT rankCorr(number, -1 * number) FROM numbers(100);
SELECT '-0.037';
SELECT roundBankers(rankCorr(exp(number), sin(number)), 3) FROM numbers(100);
CREATE TABLE moons(a Float64, b Float64) Engine=Memory();
INSERT INTO moons VALUES (1.230365,1.291454), (1.93851,0.6499), (1.574085,0.744109), (1.416457,1.41872), (1.90165,1.298199), (2.023844,1.142459), (1.828602,0.636404), (1.568649,1.157387), (1.968863,1.160039), (1.790198,0.860815), (1.238993,0.252486), (1.690338,0.573545), (1.678741,0.739649), (1.363346,0.514698), (1.924442,0.484331), (0.849071,0.585017), (1.859407,1.098124), (1.657176,1.314958), (1.085181,0.761741), (1.184481,0.639135), (1.59856,0.688384), (1.304818,1.212579), (1.913821,0.663551), (1.872619,0.510627), (1.29273,0.795267), (1.767669,0.892397), (1.790311,1.21813), (1.621893,1.229768), (1.525505,0.752643), (1.513535,1.016012), (1.120456,1.427238), (1.71505,0.716654), (1.394756,0.733629), (1.746027,1.422821), (1.5376,1.387397), (1.358968,0.575393), (1.941569,0.572639), (1.904995,0.966926), (1.967455,0.436449), (2.045535,0.582434), (1.365599,0.446582), (2.035874,0.468542), (1.419283,0.739308), (1.718267,0.895579), (1.285871,1.014628), (2.010657,1.631207), (1.78226,0.576882), (1.78274,0.727585), (1.454934,1.285701), (1.657208,0.581418);
SELECT '-0.108';
SELECT roundBankers(rankCorr(a, b), 3) from moons;
CREATE TABLE circles(a Float64, b Float64) Engine=Memory();
INSERT INTO circles VALUES (1.20848,0.505643), (1.577706,1.726383), (1.945215,1.638926), (0.493616,0.792443), (0.827802,1.41133), (1.012179,1.654582), (1.815329,0.254426), (-0.068102,1.456476), (1.235432,1.565291), (1.269633,1.857153), (0.687433,1.24911), (0.131356,1.610389), (1.991372,0.204134), (1.678587,1.456911), (0.501133,0.68513), (0.924535,0.541514), (0.574115,0.340542), (-0.013384,1.17037), (0.917257,1.799431), (1.364786,0.396457), (1.931339,1.093935), (0.575076,0.427512), (2.084798,1.752707), (0.694029,0.257422), (-0.003821,0.160859), (0.037966,0.217695), (1.986527,1.249144), (1.864518,0.521483), (0.038928,0.175741), (1.855737,1.678827), (0.779503,0.963619), (0.035384,0.238397), (0.136108,0.128737), (0.0581,1.093712), (-0.012542,0.713137), (1.53441,0.447265), (0.198885,1.232961), (1.66781,0.259156), (1.478017,1.256315), (1.148358,1.659979), (0.340698,0.76793), (0.376184,0.578202), (0.251495,1.765917), (1.836389,1.75769), (1.573166,1.753943), (0.448309,0.965337), (1.704437,1.138451), (1.93234,1.723736), (1.412218,0.603027), (1.978789,0.938132);
SELECT '0.286';
SELECT roundBankers(rankCorr(a, b), 3) from circles;
DROP TABLE IF EXISTS moons;
DROP TABLE IF EXISTS circles;
DROP DATABASE IF EXISTS db_01455_rank_correlation;