mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-23 16:12:01 +00:00
Merge pull request #15874 from nikitamikhaylov/welch-t-test
Student and Welch t-test
This commit is contained in:
commit
58b4342998
339
base/glibc-compatibility/musl/lgammal.c
Normal file
339
base/glibc-compatibility/musl/lgammal.c
Normal file
@ -0,0 +1,339 @@
|
||||
/* origin: OpenBSD /usr/src/lib/libm/src/ld80/e_lgammal.c */
|
||||
/*
|
||||
* ====================================================
|
||||
* Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
|
||||
*
|
||||
* Developed at SunPro, a Sun Microsystems, Inc. business.
|
||||
* Permission to use, copy, modify, and distribute this
|
||||
* software is freely granted, provided that this notice
|
||||
* is preserved.
|
||||
* ====================================================
|
||||
*/
|
||||
/*
|
||||
* Copyright (c) 2008 Stephen L. Moshier <steve@moshier.net>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
/* lgammal(x)
|
||||
* Reentrant version of the logarithm of the Gamma function
|
||||
* with user provide pointer for the sign of Gamma(x).
|
||||
*
|
||||
* Method:
|
||||
* 1. Argument Reduction for 0 < x <= 8
|
||||
* Since gamma(1+s)=s*gamma(s), for x in [0,8], we may
|
||||
* reduce x to a number in [1.5,2.5] by
|
||||
* lgamma(1+s) = log(s) + lgamma(s)
|
||||
* for example,
|
||||
* lgamma(7.3) = log(6.3) + lgamma(6.3)
|
||||
* = log(6.3*5.3) + lgamma(5.3)
|
||||
* = log(6.3*5.3*4.3*3.3*2.3) + lgamma(2.3)
|
||||
* 2. Polynomial approximation of lgamma around its
|
||||
* minimun ymin=1.461632144968362245 to maintain monotonicity.
|
||||
* On [ymin-0.23, ymin+0.27] (i.e., [1.23164,1.73163]), use
|
||||
* Let z = x-ymin;
|
||||
* lgamma(x) = -1.214862905358496078218 + z^2*poly(z)
|
||||
* 2. Rational approximation in the primary interval [2,3]
|
||||
* We use the following approximation:
|
||||
* s = x-2.0;
|
||||
* lgamma(x) = 0.5*s + s*P(s)/Q(s)
|
||||
* Our algorithms are based on the following observation
|
||||
*
|
||||
* zeta(2)-1 2 zeta(3)-1 3
|
||||
* lgamma(2+s) = s*(1-Euler) + --------- * s - --------- * s + ...
|
||||
* 2 3
|
||||
*
|
||||
* where Euler = 0.5771... is the Euler constant, which is very
|
||||
* close to 0.5.
|
||||
*
|
||||
* 3. For x>=8, we have
|
||||
* lgamma(x)~(x-0.5)log(x)-x+0.5*log(2pi)+1/(12x)-1/(360x**3)+....
|
||||
* (better formula:
|
||||
* lgamma(x)~(x-0.5)*(log(x)-1)-.5*(log(2pi)-1) + ...)
|
||||
* Let z = 1/x, then we approximation
|
||||
* f(z) = lgamma(x) - (x-0.5)(log(x)-1)
|
||||
* by
|
||||
* 3 5 11
|
||||
* w = w0 + w1*z + w2*z + w3*z + ... + w6*z
|
||||
*
|
||||
* 4. For negative x, since (G is gamma function)
|
||||
* -x*G(-x)*G(x) = pi/sin(pi*x),
|
||||
* we have
|
||||
* G(x) = pi/(sin(pi*x)*(-x)*G(-x))
|
||||
* since G(-x) is positive, sign(G(x)) = sign(sin(pi*x)) for x<0
|
||||
* Hence, for x<0, signgam = sign(sin(pi*x)) and
|
||||
* lgamma(x) = log(|Gamma(x)|)
|
||||
* = log(pi/(|x*sin(pi*x)|)) - lgamma(-x);
|
||||
* Note: one should avoid compute pi*(-x) directly in the
|
||||
* computation of sin(pi*(-x)).
|
||||
*
|
||||
* 5. Special Cases
|
||||
* lgamma(2+s) ~ s*(1-Euler) for tiny s
|
||||
* lgamma(1)=lgamma(2)=0
|
||||
* lgamma(x) ~ -log(x) for tiny x
|
||||
* lgamma(0) = lgamma(inf) = inf
|
||||
* lgamma(-integer) = +-inf
|
||||
*
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include "libm.h"
|
||||
|
||||
|
||||
#if LDBL_MANT_DIG == 53 && LDBL_MAX_EXP == 1024
|
||||
double lgamma_r(double x, int *sg);
|
||||
|
||||
long double lgammal_r(long double x, int *sg)
|
||||
{
|
||||
return lgamma_r(x, sg);
|
||||
}
|
||||
#elif LDBL_MANT_DIG == 64 && LDBL_MAX_EXP == 16384
|
||||
|
||||
static const long double pi = 3.14159265358979323846264L,
|
||||
|
||||
/* lgam(1+x) = 0.5 x + x a(x)/b(x)
|
||||
-0.268402099609375 <= x <= 0
|
||||
peak relative error 6.6e-22 */
|
||||
a0 = -6.343246574721079391729402781192128239938E2L,
|
||||
a1 = 1.856560238672465796768677717168371401378E3L,
|
||||
a2 = 2.404733102163746263689288466865843408429E3L,
|
||||
a3 = 8.804188795790383497379532868917517596322E2L,
|
||||
a4 = 1.135361354097447729740103745999661157426E2L,
|
||||
a5 = 3.766956539107615557608581581190400021285E0L,
|
||||
|
||||
b0 = 8.214973713960928795704317259806842490498E3L,
|
||||
b1 = 1.026343508841367384879065363925870888012E4L,
|
||||
b2 = 4.553337477045763320522762343132210919277E3L,
|
||||
b3 = 8.506975785032585797446253359230031874803E2L,
|
||||
b4 = 6.042447899703295436820744186992189445813E1L,
|
||||
/* b5 = 1.000000000000000000000000000000000000000E0 */
|
||||
|
||||
|
||||
tc = 1.4616321449683623412626595423257213284682E0L,
|
||||
tf = -1.2148629053584961146050602565082954242826E-1, /* double precision */
|
||||
/* tt = (tail of tf), i.e. tf + tt has extended precision. */
|
||||
tt = 3.3649914684731379602768989080467587736363E-18L,
|
||||
/* lgam ( 1.4616321449683623412626595423257213284682E0 ) =
|
||||
-1.2148629053584960809551455717769158215135617312999903886372437313313530E-1 */
|
||||
|
||||
/* lgam (x + tc) = tf + tt + x g(x)/h(x)
|
||||
-0.230003726999612341262659542325721328468 <= x
|
||||
<= 0.2699962730003876587373404576742786715318
|
||||
peak relative error 2.1e-21 */
|
||||
g0 = 3.645529916721223331888305293534095553827E-18L,
|
||||
g1 = 5.126654642791082497002594216163574795690E3L,
|
||||
g2 = 8.828603575854624811911631336122070070327E3L,
|
||||
g3 = 5.464186426932117031234820886525701595203E3L,
|
||||
g4 = 1.455427403530884193180776558102868592293E3L,
|
||||
g5 = 1.541735456969245924860307497029155838446E2L,
|
||||
g6 = 4.335498275274822298341872707453445815118E0L,
|
||||
|
||||
h0 = 1.059584930106085509696730443974495979641E4L,
|
||||
h1 = 2.147921653490043010629481226937850618860E4L,
|
||||
h2 = 1.643014770044524804175197151958100656728E4L,
|
||||
h3 = 5.869021995186925517228323497501767586078E3L,
|
||||
h4 = 9.764244777714344488787381271643502742293E2L,
|
||||
h5 = 6.442485441570592541741092969581997002349E1L,
|
||||
/* h6 = 1.000000000000000000000000000000000000000E0 */
|
||||
|
||||
|
||||
/* lgam (x+1) = -0.5 x + x u(x)/v(x)
|
||||
-0.100006103515625 <= x <= 0.231639862060546875
|
||||
peak relative error 1.3e-21 */
|
||||
u0 = -8.886217500092090678492242071879342025627E1L,
|
||||
u1 = 6.840109978129177639438792958320783599310E2L,
|
||||
u2 = 2.042626104514127267855588786511809932433E3L,
|
||||
u3 = 1.911723903442667422201651063009856064275E3L,
|
||||
u4 = 7.447065275665887457628865263491667767695E2L,
|
||||
u5 = 1.132256494121790736268471016493103952637E2L,
|
||||
u6 = 4.484398885516614191003094714505960972894E0L,
|
||||
|
||||
v0 = 1.150830924194461522996462401210374632929E3L,
|
||||
v1 = 3.399692260848747447377972081399737098610E3L,
|
||||
v2 = 3.786631705644460255229513563657226008015E3L,
|
||||
v3 = 1.966450123004478374557778781564114347876E3L,
|
||||
v4 = 4.741359068914069299837355438370682773122E2L,
|
||||
v5 = 4.508989649747184050907206782117647852364E1L,
|
||||
/* v6 = 1.000000000000000000000000000000000000000E0 */
|
||||
|
||||
|
||||
/* lgam (x+2) = .5 x + x s(x)/r(x)
|
||||
0 <= x <= 1
|
||||
peak relative error 7.2e-22 */
|
||||
s0 = 1.454726263410661942989109455292824853344E6L,
|
||||
s1 = -3.901428390086348447890408306153378922752E6L,
|
||||
s2 = -6.573568698209374121847873064292963089438E6L,
|
||||
s3 = -3.319055881485044417245964508099095984643E6L,
|
||||
s4 = -7.094891568758439227560184618114707107977E5L,
|
||||
s5 = -6.263426646464505837422314539808112478303E4L,
|
||||
s6 = -1.684926520999477529949915657519454051529E3L,
|
||||
|
||||
r0 = -1.883978160734303518163008696712983134698E7L,
|
||||
r1 = -2.815206082812062064902202753264922306830E7L,
|
||||
r2 = -1.600245495251915899081846093343626358398E7L,
|
||||
r3 = -4.310526301881305003489257052083370058799E6L,
|
||||
r4 = -5.563807682263923279438235987186184968542E5L,
|
||||
r5 = -3.027734654434169996032905158145259713083E4L,
|
||||
r6 = -4.501995652861105629217250715790764371267E2L,
|
||||
/* r6 = 1.000000000000000000000000000000000000000E0 */
|
||||
|
||||
|
||||
/* lgam(x) = ( x - 0.5 ) * log(x) - x + LS2PI + 1/x w(1/x^2)
|
||||
x >= 8
|
||||
Peak relative error 1.51e-21
|
||||
w0 = LS2PI - 0.5 */
|
||||
w0 = 4.189385332046727417803e-1L,
|
||||
w1 = 8.333333333333331447505E-2L,
|
||||
w2 = -2.777777777750349603440E-3L,
|
||||
w3 = 7.936507795855070755671E-4L,
|
||||
w4 = -5.952345851765688514613E-4L,
|
||||
w5 = 8.412723297322498080632E-4L,
|
||||
w6 = -1.880801938119376907179E-3L,
|
||||
w7 = 4.885026142432270781165E-3L;
|
||||
|
||||
|
||||
long double lgammal_r(long double x, int *sg) {
|
||||
long double t, y, z, nadj, p, p1, p2, q, r, w;
|
||||
union ldshape u = {x};
|
||||
uint32_t ix = (u.i.se & 0x7fffU)<<16 | u.i.m>>48;
|
||||
int sign = u.i.se >> 15;
|
||||
int i;
|
||||
|
||||
*sg = 1;
|
||||
|
||||
/* purge off +-inf, NaN, +-0, tiny and negative arguments */
|
||||
if (ix >= 0x7fff0000)
|
||||
return x * x;
|
||||
if (ix < 0x3fc08000) { /* |x|<2**-63, return -log(|x|) */
|
||||
if (sign) {
|
||||
*sg = -1;
|
||||
x = -x;
|
||||
}
|
||||
return -logl(x);
|
||||
}
|
||||
if (sign) {
|
||||
x = -x;
|
||||
t = sin(pi * x);
|
||||
if (t == 0.0)
|
||||
return 1.0 / (x-x); /* -integer */
|
||||
if (t > 0.0)
|
||||
*sg = -1;
|
||||
else
|
||||
t = -t;
|
||||
nadj = logl(pi / (t * x));
|
||||
}
|
||||
|
||||
/* purge off 1 and 2 (so the sign is ok with downward rounding) */
|
||||
if ((ix == 0x3fff8000 || ix == 0x40008000) && u.i.m == 0) {
|
||||
r = 0;
|
||||
} else if (ix < 0x40008000) { /* x < 2.0 */
|
||||
if (ix <= 0x3ffee666) { /* 8.99993896484375e-1 */
|
||||
/* lgamma(x) = lgamma(x+1) - log(x) */
|
||||
r = -logl(x);
|
||||
if (ix >= 0x3ffebb4a) { /* 7.31597900390625e-1 */
|
||||
y = x - 1.0;
|
||||
i = 0;
|
||||
} else if (ix >= 0x3ffced33) { /* 2.31639862060546875e-1 */
|
||||
y = x - (tc - 1.0);
|
||||
i = 1;
|
||||
} else { /* x < 0.23 */
|
||||
y = x;
|
||||
i = 2;
|
||||
}
|
||||
} else {
|
||||
r = 0.0;
|
||||
if (ix >= 0x3fffdda6) { /* 1.73162841796875 */
|
||||
/* [1.7316,2] */
|
||||
y = x - 2.0;
|
||||
i = 0;
|
||||
} else if (ix >= 0x3fff9da6) { /* 1.23162841796875 */
|
||||
/* [1.23,1.73] */
|
||||
y = x - tc;
|
||||
i = 1;
|
||||
} else {
|
||||
/* [0.9, 1.23] */
|
||||
y = x - 1.0;
|
||||
i = 2;
|
||||
}
|
||||
}
|
||||
switch (i) {
|
||||
case 0:
|
||||
p1 = a0 + y * (a1 + y * (a2 + y * (a3 + y * (a4 + y * a5))));
|
||||
p2 = b0 + y * (b1 + y * (b2 + y * (b3 + y * (b4 + y))));
|
||||
r += 0.5 * y + y * p1/p2;
|
||||
break;
|
||||
case 1:
|
||||
p1 = g0 + y * (g1 + y * (g2 + y * (g3 + y * (g4 + y * (g5 + y * g6)))));
|
||||
p2 = h0 + y * (h1 + y * (h2 + y * (h3 + y * (h4 + y * (h5 + y)))));
|
||||
p = tt + y * p1/p2;
|
||||
r += (tf + p);
|
||||
break;
|
||||
case 2:
|
||||
p1 = y * (u0 + y * (u1 + y * (u2 + y * (u3 + y * (u4 + y * (u5 + y * u6))))));
|
||||
p2 = v0 + y * (v1 + y * (v2 + y * (v3 + y * (v4 + y * (v5 + y)))));
|
||||
r += (-0.5 * y + p1 / p2);
|
||||
}
|
||||
} else if (ix < 0x40028000) { /* 8.0 */
|
||||
/* x < 8.0 */
|
||||
i = (int)x;
|
||||
y = x - (double)i;
|
||||
p = y * (s0 + y * (s1 + y * (s2 + y * (s3 + y * (s4 + y * (s5 + y * s6))))));
|
||||
q = r0 + y * (r1 + y * (r2 + y * (r3 + y * (r4 + y * (r5 + y * (r6 + y))))));
|
||||
r = 0.5 * y + p / q;
|
||||
z = 1.0;
|
||||
/* lgamma(1+s) = log(s) + lgamma(s) */
|
||||
switch (i) {
|
||||
case 7:
|
||||
z *= (y + 6.0); /* FALLTHRU */
|
||||
case 6:
|
||||
z *= (y + 5.0); /* FALLTHRU */
|
||||
case 5:
|
||||
z *= (y + 4.0); /* FALLTHRU */
|
||||
case 4:
|
||||
z *= (y + 3.0); /* FALLTHRU */
|
||||
case 3:
|
||||
z *= (y + 2.0); /* FALLTHRU */
|
||||
r += logl(z);
|
||||
break;
|
||||
}
|
||||
} else if (ix < 0x40418000) { /* 2^66 */
|
||||
/* 8.0 <= x < 2**66 */
|
||||
t = logl(x);
|
||||
z = 1.0 / x;
|
||||
y = z * z;
|
||||
w = w0 + z * (w1 + y * (w2 + y * (w3 + y * (w4 + y * (w5 + y * (w6 + y * w7))))));
|
||||
r = (x - 0.5) * (t - 1.0) + w;
|
||||
} else /* 2**66 <= x <= inf */
|
||||
r = x * (logl(x) - 1.0);
|
||||
if (sign)
|
||||
r = nadj - r;
|
||||
return r;
|
||||
}
|
||||
#elif LDBL_MANT_DIG == 113 && LDBL_MAX_EXP == 16384
|
||||
// TODO: broken implementation to make things compile
|
||||
double lgamma_r(double x, int *sg);
|
||||
|
||||
long double lgammal_r(long double x, int *sg)
|
||||
{
|
||||
return lgamma_r(x, sg);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
int signgam_lgammal;
|
||||
|
||||
long double lgammal(long double x)
|
||||
{
|
||||
return lgammal_r(x, &signgam_lgammal);
|
||||
}
|
||||
|
@ -53,6 +53,7 @@ RUN apt-get update \
|
||||
ninja-build \
|
||||
psmisc \
|
||||
python3 \
|
||||
python3-pip \
|
||||
python3-lxml \
|
||||
python3-requests \
|
||||
python3-termcolor \
|
||||
@ -62,6 +63,8 @@ RUN apt-get update \
|
||||
unixodbc \
|
||||
--yes --no-install-recommends
|
||||
|
||||
RUN pip3 install numpy scipy pandas
|
||||
|
||||
# This symlink required by gcc to find lld compiler
|
||||
RUN ln -s /usr/bin/lld-${LLVM_VERSION} /usr/bin/ld.lld
|
||||
|
||||
|
@ -268,7 +268,10 @@ TESTS_TO_SKIP=(
|
||||
00974_query_profiler
|
||||
|
||||
# Look at DistributedFilesToInsert, so cannot run in parallel.
|
||||
01457_DistributedFilesToInsert
|
||||
01460_DistributedFilesToInsert
|
||||
|
||||
# Require python libraries like scipy, pandas and numpy
|
||||
01322_ttest_scipy
|
||||
)
|
||||
|
||||
time clickhouse-test -j 8 --order=random --no-long --testname --shard --zookeeper --skip "${TESTS_TO_SKIP[@]}" 2>&1 | ts '%Y-%m-%d %H:%M:%S' | tee "$FASTTEST_OUTPUT/test_log.txt"
|
||||
|
@ -16,6 +16,7 @@ RUN apt-get update -y \
|
||||
python3-lxml \
|
||||
python3-requests \
|
||||
python3-termcolor \
|
||||
python3-pip \
|
||||
qemu-user-static \
|
||||
sudo \
|
||||
telnet \
|
||||
@ -23,6 +24,8 @@ RUN apt-get update -y \
|
||||
unixodbc \
|
||||
wget
|
||||
|
||||
RUN pip3 install numpy scipy pandas
|
||||
|
||||
RUN mkdir -p /tmp/clickhouse-odbc-tmp \
|
||||
&& wget -nv -O - ${odbc_driver_url} | tar --strip-components=1 -xz -C /tmp/clickhouse-odbc-tmp \
|
||||
&& cp /tmp/clickhouse-odbc-tmp/lib64/*.so /usr/local/lib/ \
|
||||
|
@ -58,6 +58,7 @@ RUN apt-get --allow-unauthenticated update -y \
|
||||
python3-lxml \
|
||||
python3-requests \
|
||||
python3-termcolor \
|
||||
python3-pip \
|
||||
qemu-user-static \
|
||||
sudo \
|
||||
telnet \
|
||||
@ -68,6 +69,8 @@ RUN apt-get --allow-unauthenticated update -y \
|
||||
wget \
|
||||
zlib1g-dev
|
||||
|
||||
RUN pip3 install numpy scipy pandas
|
||||
|
||||
RUN mkdir -p /tmp/clickhouse-odbc-tmp \
|
||||
&& wget -nv -O - ${odbc_driver_url} | tar --strip-components=1 -xz -C /tmp/clickhouse-odbc-tmp \
|
||||
&& cp /tmp/clickhouse-odbc-tmp/lib64/*.so /usr/local/lib/ \
|
||||
|
52
src/AggregateFunctions/AggregateFunctionStudentTTest.cpp
Normal file
52
src/AggregateFunctions/AggregateFunctionStudentTTest.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionStudentTTest.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include "registerAggregateFunctions.h"
|
||||
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <DataTypes/DataTypeAggregateFunction.h>
|
||||
|
||||
|
||||
// the return type is boolean (we use UInt8 as we do not have boolean in clickhouse)
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionStudentTTest(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
{
|
||||
assertBinary(name, argument_types);
|
||||
assertNoParameters(name, parameters);
|
||||
|
||||
AggregateFunctionPtr res;
|
||||
|
||||
if (isDecimal(argument_types[0]) || isDecimal(argument_types[1]))
|
||||
{
|
||||
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
else
|
||||
{
|
||||
res.reset(createWithTwoNumericTypes<AggregateFunctionStudentTTest>(*argument_types[0], *argument_types[1], argument_types));
|
||||
}
|
||||
|
||||
if (!res)
|
||||
{
|
||||
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("studentTTest", createAggregateFunctionStudentTTest, AggregateFunctionFactory::CaseInsensitive);
|
||||
}
|
||||
}
|
253
src/AggregateFunctions/AggregateFunctionStudentTTest.h
Normal file
253
src/AggregateFunctions/AggregateFunctionStudentTTest.h
Normal file
@ -0,0 +1,253 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <Core/Types.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <limits>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
template <typename X = Float64, typename Y = Float64>
|
||||
struct AggregateFunctionStudentTTestData final
|
||||
{
|
||||
size_t size_x = 0;
|
||||
size_t size_y = 0;
|
||||
X sum_x = static_cast<X>(0);
|
||||
Y sum_y = static_cast<Y>(0);
|
||||
X square_sum_x = static_cast<X>(0);
|
||||
Y square_sum_y = static_cast<Y>(0);
|
||||
Float64 mean_x = static_cast<Float64>(0);
|
||||
Float64 mean_y = static_cast<Float64>(0);
|
||||
|
||||
void add(X x, Y y)
|
||||
{
|
||||
sum_x += x;
|
||||
sum_y += y;
|
||||
size_x++;
|
||||
size_y++;
|
||||
mean_x = static_cast<Float64>(sum_x) / size_x;
|
||||
mean_y = static_cast<Float64>(sum_y) / size_y;
|
||||
square_sum_x += x * x;
|
||||
square_sum_y += y * y;
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionStudentTTestData &other)
|
||||
{
|
||||
sum_x += other.sum_x;
|
||||
sum_y += other.sum_y;
|
||||
size_x += other.size_x;
|
||||
size_y += other.size_y;
|
||||
mean_x = static_cast<Float64>(sum_x) / size_x;
|
||||
mean_y = static_cast<Float64>(sum_y) / size_y;
|
||||
square_sum_x += other.square_sum_x;
|
||||
square_sum_y += other.square_sum_y;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer &buf) const
|
||||
{
|
||||
writeBinary(mean_x, buf);
|
||||
writeBinary(mean_y, buf);
|
||||
writeBinary(sum_x, buf);
|
||||
writeBinary(sum_y, buf);
|
||||
writeBinary(square_sum_x, buf);
|
||||
writeBinary(square_sum_y, buf);
|
||||
writeBinary(size_x, buf);
|
||||
writeBinary(size_y, buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer &buf)
|
||||
{
|
||||
readBinary(mean_x, buf);
|
||||
readBinary(mean_y, buf);
|
||||
readBinary(sum_x, buf);
|
||||
readBinary(sum_y, buf);
|
||||
readBinary(square_sum_x, buf);
|
||||
readBinary(square_sum_y, buf);
|
||||
readBinary(size_x, buf);
|
||||
readBinary(size_y, buf);
|
||||
}
|
||||
|
||||
size_t getSizeY() const
|
||||
{
|
||||
return size_y;
|
||||
}
|
||||
|
||||
size_t getSizeX() const
|
||||
{
|
||||
return size_x;
|
||||
}
|
||||
|
||||
Float64 getSSquared() const
|
||||
{
|
||||
/// The original formulae looks like
|
||||
/// \frac{\sum_{i = 1}^{n_x}{(x_i - \bar{x}) ^ 2} + \sum_{i = 1}^{n_y}{(y_i - \bar{y}) ^ 2}}{n_x + n_y - 2}
|
||||
/// But we made some mathematical transformations not to store original sequences.
|
||||
/// Also we dropped sqrt, because later it will be squared later.
|
||||
const Float64 all_x = square_sum_x + size_x * std::pow(mean_x, 2) - 2 * mean_x * sum_x;
|
||||
const Float64 all_y = square_sum_y + size_y * std::pow(mean_y, 2) - 2 * mean_y * sum_y;
|
||||
return static_cast<Float64>(all_x + all_y) / (size_x + size_y - 2);
|
||||
}
|
||||
|
||||
|
||||
Float64 getTStatisticSquared() const
|
||||
{
|
||||
return std::pow(mean_x - mean_y, 2) / getStandartErrorSquared();
|
||||
}
|
||||
|
||||
Float64 getTStatistic() const
|
||||
{
|
||||
return (mean_x - mean_y) / std::sqrt(getStandartErrorSquared());
|
||||
}
|
||||
|
||||
Float64 getStandartErrorSquared() const
|
||||
{
|
||||
if (size_x == 0 || size_y == 0)
|
||||
throw Exception("Division by zero encountered in Aggregate function StudentTTest", ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
return getSSquared() * (1.0 / static_cast<Float64>(size_x) + 1.0 / static_cast<Float64>(size_y));
|
||||
}
|
||||
|
||||
Float64 getDegreesOfFreedom() const
|
||||
{
|
||||
return static_cast<Float64>(size_x + size_y - 2);
|
||||
}
|
||||
|
||||
static Float64 integrateSimpson(Float64 a, Float64 b, std::function<Float64(Float64)> func)
|
||||
{
|
||||
const size_t iterations = std::max(1e6, 1e4 * std::abs(std::round(b)));
|
||||
const long double h = (b - a) / iterations;
|
||||
Float64 sum_odds = 0.0;
|
||||
for (size_t i = 1; i < iterations; i += 2)
|
||||
sum_odds += func(a + i * h);
|
||||
Float64 sum_evens = 0.0;
|
||||
for (size_t i = 2; i < iterations; i += 2)
|
||||
sum_evens += func(a + i * h);
|
||||
return (func(a) + func(b) + 2 * sum_evens + 4 * sum_odds) * h / 3;
|
||||
}
|
||||
|
||||
Float64 getPValue() const
|
||||
{
|
||||
const Float64 v = getDegreesOfFreedom();
|
||||
const Float64 t = getTStatisticSquared();
|
||||
auto f = [&v] (double x) { return std::pow(x, v/2 - 1) / std::sqrt(1 - x); };
|
||||
Float64 numenator = integrateSimpson(0, v / (t + v), f);
|
||||
Float64 denominator = std::exp(std::lgammal(v/2) + std::lgammal(0.5) - std::lgammal(v/2 + 0.5));
|
||||
return numenator / denominator;
|
||||
}
|
||||
|
||||
std::pair<Float64, Float64> getResult() const
|
||||
{
|
||||
return std::make_pair(getTStatistic(), getPValue());
|
||||
}
|
||||
};
|
||||
|
||||
/// Returns tuple of (t-statistic, p-value)
|
||||
/// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf
|
||||
template <typename X = Float64, typename Y = Float64>
|
||||
class AggregateFunctionStudentTTest :
|
||||
public IAggregateFunctionDataHelper<AggregateFunctionStudentTTestData<X, Y>,AggregateFunctionStudentTTest<X, Y>>
|
||||
{
|
||||
|
||||
public:
|
||||
AggregateFunctionStudentTTest(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionStudentTTestData<X, Y>, AggregateFunctionStudentTTest<X, Y>> ({arguments}, {})
|
||||
{}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "studentTTest";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
};
|
||||
|
||||
Strings names
|
||||
{
|
||||
"t-statistic",
|
||||
"p-value"
|
||||
};
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
auto col_x = assert_cast<const ColumnVector<X> *>(columns[0]);
|
||||
auto col_y = assert_cast<const ColumnVector<Y> *>(columns[1]);
|
||||
|
||||
X x = col_x->getData()[row_num];
|
||||
Y y = col_y->getData()[row_num];
|
||||
|
||||
this->data(place).add(x, y);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * /*arena*/) const override
|
||||
{
|
||||
size_t size_x = this->data(place).getSizeX();
|
||||
size_t size_y = this->data(place).getSizeY();
|
||||
|
||||
if (size_x < 2 || size_y < 2)
|
||||
{
|
||||
throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
Float64 t_statistic = 0.0;
|
||||
Float64 p_value = 0.0;
|
||||
std::tie(t_statistic, p_value) = this->data(place).getResult();
|
||||
|
||||
/// Because p-value is a probability.
|
||||
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||
|
||||
column_stat.getData().push_back(t_statistic);
|
||||
column_value.getData().push_back(p_value);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
};
|
49
src/AggregateFunctions/AggregateFunctionWelchTTest.cpp
Normal file
49
src/AggregateFunctions/AggregateFunctionWelchTTest.cpp
Normal file
@ -0,0 +1,49 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionWelchTTest.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include "registerAggregateFunctions.h"
|
||||
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <DataTypes/DataTypeAggregateFunction.h>
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionWelchTTest(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
{
|
||||
assertBinary(name, argument_types);
|
||||
assertNoParameters(name, parameters);
|
||||
|
||||
AggregateFunctionPtr res;
|
||||
|
||||
if (isDecimal(argument_types[0]) || isDecimal(argument_types[1]))
|
||||
{
|
||||
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
else
|
||||
{
|
||||
res.reset(createWithTwoNumericTypes<AggregateFunctionWelchTTest>(*argument_types[0], *argument_types[1], argument_types));
|
||||
}
|
||||
|
||||
if (!res)
|
||||
{
|
||||
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
void registerAggregateFunctionWelchTTest(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("welchTTest", createAggregateFunctionWelchTTest, AggregateFunctionFactory::CaseInsensitive);
|
||||
}
|
||||
}
|
264
src/AggregateFunctions/AggregateFunctionWelchTTest.h
Normal file
264
src/AggregateFunctions/AggregateFunctionWelchTTest.h
Normal file
@ -0,0 +1,264 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <Core/Types.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <limits>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
template <typename X = Float64, typename Y = Float64>
|
||||
struct AggregateFunctionWelchTTestData final
|
||||
{
|
||||
size_t size_x = 0;
|
||||
size_t size_y = 0;
|
||||
X sum_x = static_cast<X>(0);
|
||||
Y sum_y = static_cast<Y>(0);
|
||||
X square_sum_x = static_cast<X>(0);
|
||||
Y square_sum_y = static_cast<Y>(0);
|
||||
Float64 mean_x = static_cast<Float64>(0);
|
||||
Float64 mean_y = static_cast<Float64>(0);
|
||||
|
||||
void add(X x, Y y)
|
||||
{
|
||||
sum_x += x;
|
||||
sum_y += y;
|
||||
size_x++;
|
||||
size_y++;
|
||||
mean_x = static_cast<Float64>(sum_x) / size_x;
|
||||
mean_y = static_cast<Float64>(sum_y) / size_y;
|
||||
square_sum_x += x * x;
|
||||
square_sum_y += y * y;
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionWelchTTestData &other)
|
||||
{
|
||||
sum_x += other.sum_x;
|
||||
sum_y += other.sum_y;
|
||||
size_x += other.size_x;
|
||||
size_y += other.size_y;
|
||||
mean_x = static_cast<Float64>(sum_x) / size_x;
|
||||
mean_y = static_cast<Float64>(sum_y) / size_y;
|
||||
square_sum_x += other.square_sum_x;
|
||||
square_sum_y += other.square_sum_y;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer &buf) const
|
||||
{
|
||||
writeBinary(mean_x, buf);
|
||||
writeBinary(mean_y, buf);
|
||||
writeBinary(sum_x, buf);
|
||||
writeBinary(sum_y, buf);
|
||||
writeBinary(square_sum_x, buf);
|
||||
writeBinary(square_sum_y, buf);
|
||||
writeBinary(size_x, buf);
|
||||
writeBinary(size_y, buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer &buf)
|
||||
{
|
||||
readBinary(mean_x, buf);
|
||||
readBinary(mean_y, buf);
|
||||
readBinary(sum_x, buf);
|
||||
readBinary(sum_y, buf);
|
||||
readBinary(square_sum_x, buf);
|
||||
readBinary(square_sum_y, buf);
|
||||
readBinary(size_x, buf);
|
||||
readBinary(size_y, buf);
|
||||
}
|
||||
|
||||
size_t getSizeY() const
|
||||
{
|
||||
return size_y;
|
||||
}
|
||||
|
||||
size_t getSizeX() const
|
||||
{
|
||||
return size_x;
|
||||
}
|
||||
|
||||
Float64 getSxSquared() const
|
||||
{
|
||||
/// The original formulae looks like \frac{1}{size_x - 1} \sum_{i = 1}^{size_x}{(x_i - \bar{x}) ^ 2}
|
||||
/// But we made some mathematical transformations not to store original sequences.
|
||||
/// Also we dropped sqrt, because later it will be squared later.
|
||||
return static_cast<Float64>(square_sum_x + size_x * std::pow(mean_x, 2) - 2 * mean_x * sum_x) / (size_x - 1);
|
||||
}
|
||||
|
||||
Float64 getSySquared() const
|
||||
{
|
||||
/// The original formulae looks like \frac{1}{size_y - 1} \sum_{i = 1}^{size_y}{(y_i - \bar{y}) ^ 2}
|
||||
/// But we made some mathematical transformations not to store original sequences.
|
||||
/// Also we dropped sqrt, because later it will be squared later.
|
||||
return static_cast<Float64>(square_sum_y + size_y * std::pow(mean_y, 2) - 2 * mean_y * sum_y) / (size_y - 1);
|
||||
}
|
||||
|
||||
Float64 getTStatisticSquared() const
|
||||
{
|
||||
if (size_x == 0 || size_y == 0)
|
||||
{
|
||||
throw Exception("Division by zero encountered in Aggregate function WelchTTest", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
return std::pow(mean_x - mean_y, 2) / (getSxSquared() / size_x + getSySquared() / size_y);
|
||||
}
|
||||
|
||||
Float64 getTStatistic() const
|
||||
{
|
||||
if (size_x == 0 || size_y == 0)
|
||||
{
|
||||
throw Exception("Division by zero encountered in Aggregate function WelchTTest", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
return (mean_x - mean_y) / std::sqrt(getSxSquared() / size_x + getSySquared() / size_y);
|
||||
}
|
||||
|
||||
Float64 getDegreesOfFreedom() const
|
||||
{
|
||||
auto sx = getSxSquared();
|
||||
auto sy = getSySquared();
|
||||
Float64 numerator = std::pow(sx / size_x + sy / size_y, 2);
|
||||
Float64 denominator_first = std::pow(sx, 2) / (std::pow(size_x, 2) * (size_x - 1));
|
||||
Float64 denominator_second = std::pow(sy, 2) / (std::pow(size_y, 2) * (size_y - 1));
|
||||
return numerator / (denominator_first + denominator_second);
|
||||
}
|
||||
|
||||
static Float64 integrateSimpson(Float64 a, Float64 b, std::function<Float64(Float64)> func)
|
||||
{
|
||||
size_t iterations = std::max(1e6, 1e4 * std::abs(std::round(b)));
|
||||
double h = (b - a) / iterations;
|
||||
Float64 sum_odds = 0.0;
|
||||
for (size_t i = 1; i < iterations; i += 2)
|
||||
sum_odds += func(a + i * h);
|
||||
Float64 sum_evens = 0.0;
|
||||
for (size_t i = 2; i < iterations; i += 2)
|
||||
sum_evens += func(a + i * h);
|
||||
return (func(a) + func(b) + 2 * sum_evens + 4 * sum_odds) * h / 3;
|
||||
}
|
||||
|
||||
Float64 getPValue() const
|
||||
{
|
||||
const Float64 v = getDegreesOfFreedom();
|
||||
const Float64 t = getTStatisticSquared();
|
||||
auto f = [&v] (double x) { return std::pow(x, v/2 - 1) / std::sqrt(1 - x); };
|
||||
Float64 numenator = integrateSimpson(0, v / (t + v), f);
|
||||
Float64 denominator = std::exp(std::lgammal(v/2) + std::lgammal(0.5) - std::lgammal(v/2 + 0.5));
|
||||
return numenator / denominator;
|
||||
}
|
||||
|
||||
std::pair<Float64, Float64> getResult() const
|
||||
{
|
||||
return std::make_pair(getTStatistic(), getPValue());
|
||||
}
|
||||
};
|
||||
|
||||
/// Returns tuple of (t-statistic, p-value)
|
||||
/// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf
|
||||
template <typename X = Float64, typename Y = Float64>
|
||||
class AggregateFunctionWelchTTest :
|
||||
public IAggregateFunctionDataHelper<AggregateFunctionWelchTTestData<X, Y>,AggregateFunctionWelchTTest<X, Y>>
|
||||
{
|
||||
|
||||
public:
|
||||
AggregateFunctionWelchTTest(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionWelchTTestData<X, Y>, AggregateFunctionWelchTTest<X, Y>> ({arguments}, {})
|
||||
{}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "welchTTest";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
};
|
||||
|
||||
Strings names
|
||||
{
|
||||
"t-statistic",
|
||||
"p-value"
|
||||
};
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
auto col_x = assert_cast<const ColumnVector<X> *>(columns[0]);
|
||||
auto col_y = assert_cast<const ColumnVector<Y> *>(columns[1]);
|
||||
|
||||
X x = col_x->getData()[row_num];
|
||||
Y y = col_y->getData()[row_num];
|
||||
|
||||
this->data(place).add(x, y);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * /*arena*/) const override
|
||||
{
|
||||
size_t size_x = this->data(place).getSizeX();
|
||||
size_t size_y = this->data(place).getSizeY();
|
||||
|
||||
if (size_x < 2 || size_y < 2)
|
||||
{
|
||||
throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
Float64 t_statistic = 0.0;
|
||||
Float64 p_value = 0.0;
|
||||
std::tie(t_statistic, p_value) = this->data(place).getResult();
|
||||
|
||||
/// Because p-value is a probability.
|
||||
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||
|
||||
column_stat.getData().push_back(t_statistic);
|
||||
column_value.getData().push_back(p_value);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
};
|
@ -45,6 +45,8 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionMoving(factory);
|
||||
registerAggregateFunctionCategoricalIV(factory);
|
||||
registerAggregateFunctionAggThrow(factory);
|
||||
registerAggregateFunctionWelchTTest(factory);
|
||||
registerAggregateFunctionStudentTTest(factory);
|
||||
registerAggregateFunctionRankCorrelation(factory);
|
||||
}
|
||||
|
||||
|
@ -35,6 +35,8 @@ void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &)
|
||||
void registerAggregateFunctionMoving(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionCategoricalIV(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionAggThrow(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionWelchTTest(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionRankCorrelation(AggregateFunctionFactory &);
|
||||
|
||||
class AggregateFunctionCombinatorFactory;
|
||||
|
@ -42,6 +42,7 @@ SRCS(
|
||||
AggregateFunctionState.cpp
|
||||
AggregateFunctionStatistics.cpp
|
||||
AggregateFunctionStatisticsSimple.cpp
|
||||
AggregateFunctionStudentTTest.cpp
|
||||
AggregateFunctionSum.cpp
|
||||
AggregateFunctionSumMap.cpp
|
||||
AggregateFunctionTimeSeriesGroupSum.cpp
|
||||
@ -49,6 +50,7 @@ SRCS(
|
||||
AggregateFunctionUniqCombined.cpp
|
||||
AggregateFunctionUniq.cpp
|
||||
AggregateFunctionUniqUpTo.cpp
|
||||
AggregateFunctionWelchTTest.cpp
|
||||
AggregateFunctionWindowFunnel.cpp
|
||||
parseAggregateFunctionParameters.cpp
|
||||
registerAggregateFunctions.cpp
|
||||
|
4
tests/queries/0_stateless/01322_student_ttest.reference
Normal file
4
tests/queries/0_stateless/01322_student_ttest.reference
Normal file
@ -0,0 +1,4 @@
|
||||
-2.610898982580138 0.00916587538237954
|
||||
-2.610898982580134 0.0091658753823792
|
||||
-28.740781574102936 7.667329672103986e-133
|
||||
-28.74078157410298 0
|
19
tests/queries/0_stateless/01322_student_ttest.sql
Normal file
19
tests/queries/0_stateless/01322_student_ttest.sql
Normal file
File diff suppressed because one or more lines are too long
108
tests/queries/0_stateless/01322_ttest_scipy.python
Normal file
108
tests/queries/0_stateless/01322_ttest_scipy.python
Normal file
@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import io
|
||||
import sys
|
||||
import requests
|
||||
import time
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
|
||||
CLICKHOUSE_HOST = os.environ.get('CLICKHOUSE_HOST', '127.0.0.1')
|
||||
CLICKHOUSE_PORT_HTTP = os.environ.get('CLICKHOUSE_PORT_HTTP', '8123')
|
||||
CLICKHOUSE_SERVER_URL_STR = 'http://' + ':'.join(str(s) for s in [CLICKHOUSE_HOST, CLICKHOUSE_PORT_HTTP]) + "/"
|
||||
|
||||
class ClickHouseClient:
|
||||
def __init__(self, host = CLICKHOUSE_SERVER_URL_STR):
|
||||
self.host = host
|
||||
|
||||
def query(self, query, connection_timeout = 1500):
|
||||
NUMBER_OF_TRIES = 30
|
||||
DELAY = 10
|
||||
|
||||
for i in range(NUMBER_OF_TRIES):
|
||||
r = requests.post(
|
||||
self.host,
|
||||
params = {'timeout_before_checking_execution_speed': 120, 'max_execution_time': 6000},
|
||||
timeout = connection_timeout,
|
||||
data = query)
|
||||
if r.status_code == 200:
|
||||
return r.text
|
||||
else:
|
||||
print('ATTENTION: try #%d failed' % i)
|
||||
if i != (NUMBER_OF_TRIES-1):
|
||||
print(query)
|
||||
print(r.text)
|
||||
time.sleep(DELAY*(i+1))
|
||||
else:
|
||||
raise ValueError(r.text)
|
||||
|
||||
def query_return_df(self, query, connection_timeout = 1500):
|
||||
data = self.query(query, connection_timeout)
|
||||
df = pd.read_csv(io.StringIO(data), sep = '\t')
|
||||
return df
|
||||
|
||||
def query_with_data(self, query, content):
|
||||
content = content.encode('utf-8')
|
||||
r = requests.post(self.host, data=content)
|
||||
result = r.text
|
||||
if r.status_code == 200:
|
||||
return result
|
||||
else:
|
||||
raise ValueError(r.text)
|
||||
|
||||
def test_and_check(name, a, b, t_stat, p_value):
|
||||
client = ClickHouseClient()
|
||||
client.query("DROP TABLE IF EXISTS ttest;")
|
||||
client.query("CREATE TABLE ttest (left Float64, right Float64) ENGINE = Memory;");
|
||||
client.query("INSERT INTO ttest VALUES {};".format(", ".join(['({},{})'.format(i, j) for i,j in zip(a, b)])))
|
||||
|
||||
real = client.query_return_df(
|
||||
"SELECT roundBankers({}(left, right).1, 16) as t_stat, ".format(name) +
|
||||
"roundBankers({}(left, right).2, 16) as p_value ".format(name) +
|
||||
"FROM ttest FORMAT TabSeparatedWithNames;")
|
||||
real_t_stat = real['t_stat'][0]
|
||||
real_p_value = real['p_value'][0]
|
||||
assert(abs(real_t_stat - np.float64(t_stat) < 1e-2)), "clickhouse_t_stat {}, scipy_t_stat {}".format(real_t_stat, t_stat)
|
||||
assert(abs(real_p_value - np.float64(p_value)) < 1e-2), "clickhouse_p_value {}, scipy_p_value {}".format(real_p_value, p_value)
|
||||
client.query("DROP TABLE IF EXISTS ttest;")
|
||||
|
||||
|
||||
def test_student():
|
||||
rvs1 = np.round(stats.norm.rvs(loc=1, scale=5,size=500), 5)
|
||||
rvs2 = np.round(stats.norm.rvs(loc=10, scale=5,size=500), 5)
|
||||
s, p = stats.ttest_ind(rvs1, rvs2, equal_var = True)
|
||||
test_and_check("studentTTest", rvs1, rvs2, s, p)
|
||||
|
||||
rvs1 = np.round(stats.norm.rvs(loc=0, scale=5,size=500), 5)
|
||||
rvs2 = np.round(stats.norm.rvs(loc=0, scale=5,size=500), 5)
|
||||
s, p = stats.ttest_ind(rvs1, rvs2, equal_var = True)
|
||||
test_and_check("studentTTest", rvs1, rvs2, s, p)
|
||||
|
||||
|
||||
rvs1 = np.round(stats.norm.rvs(loc=0, scale=10,size=65536), 5)
|
||||
rvs2 = np.round(stats.norm.rvs(loc=5, scale=1,size=65536), 5)
|
||||
s, p = stats.ttest_ind(rvs1, rvs2, equal_var = True)
|
||||
test_and_check("studentTTest", rvs1, rvs2, s, p)
|
||||
|
||||
def test_welch():
|
||||
rvs1 = np.round(stats.norm.rvs(loc=1, scale=15,size=500), 5)
|
||||
rvs2 = np.round(stats.norm.rvs(loc=10, scale=5,size=500), 5)
|
||||
s, p = stats.ttest_ind(rvs1, rvs2, equal_var = True)
|
||||
test_and_check("studentTTest", rvs1, rvs2, s, p)
|
||||
|
||||
rvs1 = np.round(stats.norm.rvs(loc=0, scale=7,size=500), 5)
|
||||
rvs2 = np.round(stats.norm.rvs(loc=0, scale=3,size=500), 5)
|
||||
s, p = stats.ttest_ind(rvs1, rvs2, equal_var = True)
|
||||
test_and_check("studentTTest", rvs1, rvs2, s, p)
|
||||
|
||||
|
||||
rvs1 = np.round(stats.norm.rvs(loc=0, scale=10,size=65536), 5)
|
||||
rvs2 = np.round(stats.norm.rvs(loc=5, scale=1,size=65536), 5)
|
||||
s, p = stats.ttest_ind(rvs1, rvs2, equal_var = True)
|
||||
test_and_check("studentTTest", rvs1, rvs2, s, p)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_student()
|
||||
test_welch()
|
||||
print("Ok.")
|
1
tests/queries/0_stateless/01322_ttest_scipy.reference
Normal file
1
tests/queries/0_stateless/01322_ttest_scipy.reference
Normal file
@ -0,0 +1 @@
|
||||
Ok.
|
8
tests/queries/0_stateless/01322_ttest_scipy.sh
Executable file
8
tests/queries/0_stateless/01322_ttest_scipy.sh
Executable file
@ -0,0 +1,8 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||
. "$CURDIR"/../shell_config.sh
|
||||
|
||||
# We should have correct env vars from shell_config.sh to run this test
|
||||
|
||||
python3 "$CURDIR"/01322_ttest_scipy.python
|
10
tests/queries/0_stateless/01322_welch_ttest.reference
Normal file
10
tests/queries/0_stateless/01322_welch_ttest.reference
Normal file
@ -0,0 +1,10 @@
|
||||
0.021378001462867
|
||||
0.0213780014628671
|
||||
0.090773324285671
|
||||
0.0907733242891952
|
||||
0.00339907162713746
|
||||
0.0033990715715539
|
||||
-0.5028215369186904 0.6152361677168877
|
||||
-0.5028215369187079 0.6152361677170834
|
||||
14.971190998235835 5.898143508382202e-44
|
||||
14.971190998235837 0
|
37
tests/queries/0_stateless/01322_welch_ttest.sql
Normal file
37
tests/queries/0_stateless/01322_welch_ttest.sql
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user