mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 07:31:57 +00:00
Merge pull request #14411 from nikitamikhaylov/rank-corr
Merging #11769 (Rank Correlation Spearman)
This commit is contained in:
commit
5d9367aea4
51
src/AggregateFunctions/AggregateFunctionRankCorrelation.cpp
Normal file
51
src/AggregateFunctions/AggregateFunctionRankCorrelation.cpp
Normal file
@ -0,0 +1,51 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionRankCorrelation.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include "registerAggregateFunctions.h"
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionRankCorrelation(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
{
|
||||
assertBinary(name, argument_types);
|
||||
assertNoParameters(name, parameters);
|
||||
|
||||
AggregateFunctionPtr res;
|
||||
|
||||
if (isDecimal(argument_types[0]) || isDecimal(argument_types[1]))
|
||||
{
|
||||
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
else
|
||||
{
|
||||
res.reset(createWithTwoNumericTypes<AggregateFunctionRankCorrelation>(*argument_types[0], *argument_types[1], argument_types));
|
||||
}
|
||||
|
||||
if (!res)
|
||||
{
|
||||
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
void registerAggregateFunctionRankCorrelation(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("rankCorr", createAggregateFunctionRankCorrelation, AggregateFunctionFactory::CaseInsensitive);
|
||||
}
|
||||
|
||||
}
|
234
src/AggregateFunctions/AggregateFunctionRankCorrelation.h
Normal file
234
src/AggregateFunctions/AggregateFunctionRankCorrelation.h
Normal file
@ -0,0 +1,234 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <Core/Types.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <limits>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
|
||||
#include <Common/ArenaAllocator.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
template <template <typename> class Comparator>
|
||||
struct ComparePairFirst final
|
||||
{
|
||||
template <typename X, typename Y>
|
||||
bool operator()(const std::pair<X, Y> & lhs, const std::pair<X, Y> & rhs) const
|
||||
{
|
||||
return Comparator<X>{}(lhs.first, rhs.first);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <template <typename> class Comparator>
|
||||
struct ComparePairSecond final
|
||||
{
|
||||
template <typename X, typename Y>
|
||||
bool operator()(const std::pair<X, Y> & lhs, const std::pair<X, Y> & rhs) const
|
||||
{
|
||||
return Comparator<Y>{}(lhs.second, rhs.second);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename X = Float64, typename Y = Float64>
|
||||
struct AggregateFunctionRankCorrelationData final
|
||||
{
|
||||
size_t size_x = 0;
|
||||
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(std::pair<X, Y>), 4096>;
|
||||
using Array = PODArray<std::pair<X, Y>, 32, Allocator>;
|
||||
|
||||
Array values;
|
||||
};
|
||||
|
||||
template <typename X, typename Y>
|
||||
class AggregateFunctionRankCorrelation :
|
||||
public IAggregateFunctionDataHelper<AggregateFunctionRankCorrelationData<X, Y>, AggregateFunctionRankCorrelation<X, Y>>
|
||||
{
|
||||
using Data = AggregateFunctionRankCorrelationData<X, Y>;
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(std::pair<Float64, Float64>), 4096>;
|
||||
using Array = PODArray<std::pair<Float64, Float64>, 32, Allocator>;
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionRankCorrelation(const DataTypes & arguments)
|
||||
:IAggregateFunctionDataHelper<AggregateFunctionRankCorrelationData<X, Y>,AggregateFunctionRankCorrelation<X, Y>> ({arguments}, {})
|
||||
{}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "rankCorr";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
|
||||
void insert(Data & a, const std::pair<X, Y> & x, Arena * arena) const
|
||||
{
|
||||
++a.size_x;
|
||||
a.values.push_back(x, arena);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
auto & a = this->data(place);
|
||||
|
||||
auto new_x = assert_cast<const ColumnVector<X> &>(*columns[0]).getData()[row_num];
|
||||
auto new_y = assert_cast<const ColumnVector<Y> &>(*columns[1]).getData()[row_num];
|
||||
|
||||
auto new_arg = std::make_pair(new_x, new_y);
|
||||
|
||||
a.size_x += 1;
|
||||
|
||||
a.values.push_back(new_arg, arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & a = this->data(place);
|
||||
auto & b = this->data(rhs);
|
||||
|
||||
if (b.size_x)
|
||||
for (size_t i = 0; i < b.size_x; ++i)
|
||||
insert(a, b.values[i], arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
{
|
||||
const auto & value = this->data(place).values;
|
||||
size_t size = this->data(place).size_x;
|
||||
writeVarUInt(size, buf);
|
||||
buf.write(reinterpret_cast<const char *>(value.data()), size * sizeof(value[0]));
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
|
||||
{
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
auto & value = this->data(place).values;
|
||||
|
||||
value.resize(size, arena);
|
||||
buf.read(reinterpret_cast<char *>(value.data()), size * sizeof(value[0]));
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * /*arena*/) const override
|
||||
{
|
||||
const auto & value = this->data(place).values;
|
||||
size_t size = this->data(place).size_x;
|
||||
|
||||
if (size < 2)
|
||||
{
|
||||
throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
//create a copy of values not to format data
|
||||
PODArrayWithStackMemory<std::pair<Float64, Float64>, 32> tmp_values;
|
||||
tmp_values.resize(size);
|
||||
for (size_t j = 0; j < size; ++ j)
|
||||
tmp_values[j] = static_cast<std::pair<Float64, Float64>>(value[j]);
|
||||
|
||||
//sort x_values
|
||||
std::sort(std::begin(tmp_values), std::end(tmp_values), ComparePairFirst<std::greater>{});
|
||||
|
||||
for (size_t j = 0; j < size;)
|
||||
{
|
||||
//replace x_values with their ranks
|
||||
size_t rank = j + 1;
|
||||
size_t same = 1;
|
||||
size_t cur_sum = rank;
|
||||
size_t cur_start = j;
|
||||
|
||||
while (j < size - 1)
|
||||
{
|
||||
if (tmp_values[j].first == tmp_values[j + 1].first)
|
||||
{
|
||||
// rank of (j + 1)th number
|
||||
rank += 1;
|
||||
same++;
|
||||
cur_sum += rank;
|
||||
j++;
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
|
||||
// insert rank is calculated as average of ranks of equal values
|
||||
Float64 insert_rank = static_cast<Float64>(cur_sum) / same;
|
||||
for (size_t i = cur_start; i <= j; ++i)
|
||||
tmp_values[i].first = insert_rank;
|
||||
j++;
|
||||
}
|
||||
|
||||
//sort y_values
|
||||
std::sort(std::begin(tmp_values), std::end(tmp_values), ComparePairSecond<std::greater>{});
|
||||
|
||||
//replace y_values with their ranks
|
||||
for (size_t j = 0; j < size;)
|
||||
{
|
||||
//replace x_values with their ranks
|
||||
size_t rank = j + 1;
|
||||
size_t same = 1;
|
||||
size_t cur_sum = rank;
|
||||
size_t cur_start = j;
|
||||
|
||||
while (j < size - 1)
|
||||
{
|
||||
if (tmp_values[j].second == tmp_values[j + 1].second)
|
||||
{
|
||||
// rank of (j + 1)th number
|
||||
rank += 1;
|
||||
same++;
|
||||
cur_sum += rank;
|
||||
j++;
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// insert rank is calculated as average of ranks of equal values
|
||||
Float64 insert_rank = static_cast<Float64>(cur_sum) / same;
|
||||
for (size_t i = cur_start; i <= j; ++i)
|
||||
tmp_values[i].second = insert_rank;
|
||||
j++;
|
||||
}
|
||||
|
||||
//count d^2 sum
|
||||
Float64 answer = static_cast<Float64>(0);
|
||||
for (size_t j = 0; j < size; ++ j)
|
||||
answer += (tmp_values[j].first - tmp_values[j].second) * (tmp_values[j].first - tmp_values[j].second);
|
||||
|
||||
answer *= 6;
|
||||
answer /= size * (size * size - 1);
|
||||
answer = 1 - answer;
|
||||
|
||||
auto & column = static_cast<ColumnVector<Float64> &>(to);
|
||||
column.getData().push_back(answer);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
};
|
@ -45,6 +45,7 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionMoving(factory);
|
||||
registerAggregateFunctionCategoricalIV(factory);
|
||||
registerAggregateFunctionAggThrow(factory);
|
||||
registerAggregateFunctionRankCorrelation(factory);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -35,6 +35,7 @@ void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &)
|
||||
void registerAggregateFunctionMoving(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionCategoricalIV(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionAggThrow(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionRankCorrelation(AggregateFunctionFactory &);
|
||||
|
||||
class AggregateFunctionCombinatorFactory;
|
||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||
|
@ -32,6 +32,7 @@ SRCS(
|
||||
AggregateFunctionNull.cpp
|
||||
AggregateFunctionOrFill.cpp
|
||||
AggregateFunctionQuantile.cpp
|
||||
AggregateFunctionRankCorrelation.cpp
|
||||
AggregateFunctionResample.cpp
|
||||
AggregateFunctionRetention.cpp
|
||||
AggregateFunctionSequenceMatch.cpp
|
||||
|
@ -0,0 +1,10 @@
|
||||
1
|
||||
1
|
||||
-1
|
||||
-1
|
||||
-0.037
|
||||
-0.037
|
||||
-0.108
|
||||
-0.108
|
||||
0.286
|
||||
0.286
|
@ -0,0 +1,29 @@
|
||||
CREATE DATABASE IF NOT EXISTS db_01455_rank_correlation;
|
||||
USE db_01455_rank_correlation;
|
||||
DROP TABLE IF EXISTS moons;
|
||||
DROP TABLE IF EXISTS circles;
|
||||
|
||||
SELECT '1';
|
||||
SELECT rankCorr(number, number) FROM numbers(100);
|
||||
|
||||
SELECT '-1';
|
||||
SELECT rankCorr(number, -1 * number) FROM numbers(100);
|
||||
|
||||
SELECT '-0.037';
|
||||
SELECT roundBankers(rankCorr(exp(number), sin(number)), 3) FROM numbers(100);
|
||||
|
||||
CREATE TABLE moons(a Float64, b Float64) Engine=Memory();
|
||||
INSERT INTO moons VALUES (1.230365,1.291454), (1.93851,0.6499), (1.574085,0.744109), (1.416457,1.41872), (1.90165,1.298199), (2.023844,1.142459), (1.828602,0.636404), (1.568649,1.157387), (1.968863,1.160039), (1.790198,0.860815), (1.238993,0.252486), (1.690338,0.573545), (1.678741,0.739649), (1.363346,0.514698), (1.924442,0.484331), (0.849071,0.585017), (1.859407,1.098124), (1.657176,1.314958), (1.085181,0.761741), (1.184481,0.639135), (1.59856,0.688384), (1.304818,1.212579), (1.913821,0.663551), (1.872619,0.510627), (1.29273,0.795267), (1.767669,0.892397), (1.790311,1.21813), (1.621893,1.229768), (1.525505,0.752643), (1.513535,1.016012), (1.120456,1.427238), (1.71505,0.716654), (1.394756,0.733629), (1.746027,1.422821), (1.5376,1.387397), (1.358968,0.575393), (1.941569,0.572639), (1.904995,0.966926), (1.967455,0.436449), (2.045535,0.582434), (1.365599,0.446582), (2.035874,0.468542), (1.419283,0.739308), (1.718267,0.895579), (1.285871,1.014628), (2.010657,1.631207), (1.78226,0.576882), (1.78274,0.727585), (1.454934,1.285701), (1.657208,0.581418);
|
||||
SELECT '-0.108';
|
||||
SELECT roundBankers(rankCorr(a, b), 3) from moons;
|
||||
|
||||
CREATE TABLE circles(a Float64, b Float64) Engine=Memory();
|
||||
INSERT INTO circles VALUES (1.20848,0.505643), (1.577706,1.726383), (1.945215,1.638926), (0.493616,0.792443), (0.827802,1.41133), (1.012179,1.654582), (1.815329,0.254426), (-0.068102,1.456476), (1.235432,1.565291), (1.269633,1.857153), (0.687433,1.24911), (0.131356,1.610389), (1.991372,0.204134), (1.678587,1.456911), (0.501133,0.68513), (0.924535,0.541514), (0.574115,0.340542), (-0.013384,1.17037), (0.917257,1.799431), (1.364786,0.396457), (1.931339,1.093935), (0.575076,0.427512), (2.084798,1.752707), (0.694029,0.257422), (-0.003821,0.160859), (0.037966,0.217695), (1.986527,1.249144), (1.864518,0.521483), (0.038928,0.175741), (1.855737,1.678827), (0.779503,0.963619), (0.035384,0.238397), (0.136108,0.128737), (0.0581,1.093712), (-0.012542,0.713137), (1.53441,0.447265), (0.198885,1.232961), (1.66781,0.259156), (1.478017,1.256315), (1.148358,1.659979), (0.340698,0.76793), (0.376184,0.578202), (0.251495,1.765917), (1.836389,1.75769), (1.573166,1.753943), (0.448309,0.965337), (1.704437,1.138451), (1.93234,1.723736), (1.412218,0.603027), (1.978789,0.938132);
|
||||
SELECT '0.286';
|
||||
SELECT roundBankers(rankCorr(a, b), 3) from circles;
|
||||
|
||||
DROP TABLE IF EXISTS moons;
|
||||
DROP TABLE IF EXISTS circles;
|
||||
DROP DATABASE IF EXISTS db_01455_rank_correlation;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user