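-- ClickHouse/tests/queries/0_stateless/00947_ml_test.sql
-- Tests the stochasticLinearRegression aggregate function, its -State form and evalMLMethod.
-- (Copied from a web viewer, which reported 51 lines / 39 KiB and labelled the file "MySQL";
-- the dialect is actually ClickHouse SQL.)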
-- Training set for stochasticLinearRegression.
-- The inserted rows appear to follow target = 0.5 * param1 - 2 * param2 + 3
-- (checked against a sample of rows). Every row checked carries the same
-- prediction point predict1 = 20.0, predict2 = 40.0, which is later fed to evalMLMethod.
DROP TABLE IF EXISTS defaults;
CREATE TABLE IF NOT EXISTS defaults
(
    param1 Float64,
    param2 Float64,
    target Float64,
    predict1 Float64,
    predict2 Float64
) ENGINE = Memory;
insert into defaults values (-3.273, -1.452, 4.267, 20.0, 40.0), (0.121, -0.615, 4.290, 20.0, 40.0), (-1.099, 2.755, -3.060, 20.0, 40.0), (1.090, 2.945, -2.346, 20.0, 40.0), (0.305, 2.179, -1.205, 20.0, 40.0), (-0.925, 0.702, 1.134, 20.0, 40.0), (3.178, -1.316, 7.221, 20.0, 40.0), (-2.756, -0.473, 2.569, 20.0, 40.0), (3.665, 2.303, 0.226, 20.0, 40.0), (1.662, 1.951, -0.070, 20.0, 40.0), (2.869, 0.593, 3.249, 20.0, 40.0), (0.818, -0.593, 4.594, 20.0, 40.0), (-1.917, 0.916, 0.209, 20.0, 40.0), (2.706, 1.523, 1.307, 20.0, 40.0), (0.219, 2.162, -1.214, 20.0, 40.0), (-4.510, 1.376, -2.007, 20.0, 40.0), (4.284, -0.515, 6.173, 20.0, 40.0), (-1.101, 2.810, -3.170, 20.0, 40.0), (-1.810, -1.117, 4.329, 20.0, 40.0), (0.055, 1.115, 0.797, 20.0, 40.0), (-2.178, 2.904, -3.898, 20.0, 40.0), (-3.494, -1.814, 4.882, 20.0, 40.0), (3.027, 0.476, 3.562, 20.0, 40.0), (-1.434, 1.151, -0.018, 20.0, 40.0), (1.180, 0.992, 1.606, 20.0, 40.0), (0.015, 0.971, 1.067, 20.0, 40.0), (-0.511, -0.875, 4.495, 20.0, 40.0), (0.961, 2.348, -1.216, 20.0, 40.0), (-2.279, 0.038, 1.785, 20.0, 40.0), (-1.568, -0.248, 2.712, 20.0, 40.0), (-0.496, 0.366, 2.020, 20.0, 40.0), (1.177, -1.401, 6.390, 20.0, 40.0), (2.882, -1.442, 7.325, 20.0, 40.0), (-1.066, 1.817, -1.167, 20.0, 40.0), (-2.144, 2.791, -3.655, 20.0, 40.0), (-4.370, 2.228, -3.642, 20.0, 40.0), (3.996, 2.775, -0.553, 20.0, 40.0), (0.289, 2.055, -0.965, 20.0, 40.0), (-0.588, -1.601, 5.908, 20.0, 40.0), (-1.801, 0.417, 1.265, 20.0, 40.0), (4.375, -1.499, 8.186, 20.0, 40.0), (-2.618, 0.038, 1.615, 20.0, 40.0), (3.616, -0.833, 6.475, 20.0, 40.0), (-4.045, -1.558, 4.094, 20.0, 40.0), (-3.962, 0.636, -0.253, 20.0, 40.0), (3.505, 2.625, -0.497, 20.0, 40.0), (3.029, -0.523, 5.560, 20.0, 40.0), (-3.520, -0.474, 2.188, 20.0, 40.0), (2.430, -1.469, 7.154, 20.0, 40.0), (1.547, -1.654, 7.082, 20.0, 40.0), (-1.370, 0.575, 1.165, 20.0, 40.0), (-1.869, -1.555, 5.176, 20.0, 40.0), (3.536, 2.841, -0.913, 20.0, 40.0), (-3.810, 1.220, -1.344, 20.0, 40.0), (-1.971, 
1.462, -0.910, 20.0, 40.0), (-0.243, 0.167, 2.545, 20.0, 40.0), (-1.403, 2.645, -2.991, 20.0, 40.0), (0.532, -0.114, 3.494, 20.0, 40.0), (-1.678, 0.975, 0.212, 20.0, 40.0), (-0.656, 2.140, -1.609, 20.0, 40.0), (1.743, 2.631, -1.390, 20.0, 40.0), (2.586, 2.943, -1.593, 20.0, 40.0), (-0.512, 2.969, -3.195, 20.0, 40.0), (2.283, -0.100, 4.342, 20.0, 40.0), (-4.293, 0.872, -0.890, 20.0, 40.0), (3.411, 1.300, 2.106, 20.0, 40.0), (-0.281, 2.951, -3.042, 20.0, 40.0), (-4.442, 0.384, 0.012, 20.0, 40.0), (1.194, 1.746, 0.104, 20.0, 40.0), (-1.152, 1.862, -1.300, 20.0, 40.0), (1.362, -1.341, 6.363, 20.0, 40.0), (-4.488, 2.618, -4.481, 20.0, 40.0), (3.419, -0.564, 5.837, 20.0, 40.0), (-3.392, 0.396, 0.512, 20.0, 40.0), (-1.629, -0.909, 4.003, 20.0, 40.0), (4.447, -1.088, 7.399, 20.0, 40.0), (-1.232, 1.699, -1.014, 20.0, 40.0), (-1.286, -0.609, 3.575, 20.0, 40.0), (2.437, 2.796, -1.374, 20.0, 40.0), (-4.864, 1.989, -3.410, 20.0, 40.0), (-1.716, -1.399, 4.940, 20.0, 40.0), (-3.084, 1.858, -2.259, 20.0, 40.0), (2.828, -0.319, 5.053, 20.0, 40.0), (-1.226, 2.586, -2.786, 20.0, 40.0), (2.456, 0.092, 4.044, 20.0, 40.0), (-0.989, 2.375, -2.245, 20.0, 40.0), (3.268, 0.935, 2.765, 20.0, 40.0), (-4.128, -1.995, 4.927, 20.0, 40.0), (-1.083, 2.197, -1.935, 20.0, 40.0), (-3.471, -1.198, 3.660, 20.0, 40.0), (4.617, -1.136, 7.579, 20.0, 40.0), (2.054, -1.675, 7.378, 20.0, 40.0), (4.106, 2.326, 0.402, 20.0, 40.0), (1.558, 0.310, 3.158, 20.0, 40.0), (0.792, 0.900, 1.596, 20.0, 40.0), (-3.229, 0.300, 0.785, 20.0, 40.0), (3.787, -0.793, 6.479, 20.0, 40.0), (1.786, 2.288, -0.684, 20.0, 40.0), (2.643, 0.223, 3.875, 20.0, 40.0), (-3.592, 2.122, -3.040, 20.0, 40.0), (4.519, -1.760, 8.779, 20.0, 40.0), (3.221, 2.255, 0.101, 20.0, 40.0), (4.151, 1.788, 1.500, 20.0, 40.0), (-1.033, -1.195, 4.874, 20.0, 40.0), (-1.636, -1.037, 4.257, 20.0, 40.0), (-3.548, 1.911, -2.596, 20.0, 40.0), (4.829, -0.293, 6.001, 20.0, 40.0), (-4.684, -1.664, 3.986, 20.0, 40.0), (4.531, -0.503, 6.271, 20.0, 40.0), (-3.503, 
-1.606, 4.460, 20.0, 40.0), (-2.036, -1.522, 5.027, 20.0, 40.0), (-0.473, -0.617, 3.997, 20.0, 40.0), (-1.5
-- Train once and persist the aggregate state so evalMLMethod can reuse it.
-- Parameters (per the ClickHouse docs): learning rate 0.03, L2 regularization
-- coefficient 0.00001, mini-batch size 2, weight-update method 'Nesterov'.
-- The first argument (target) is the label; param1 and param2 are the features.
DROP TABLE IF EXISTS model;
CREATE TABLE model ENGINE = Memory AS
SELECT stochasticLinearRegressionState(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2) AS state
FROM defaults;
-- Apply the stored model to the point (predict1, predict2) = (20, 40).
-- With weights near (0.5, -2) and bias near 3, the prediction is
-- 0.5 * 20 - 2 * 40 + 3 = -67.
-- LIMIT 1 without ORDER BY is acceptable here only because the rows use the
-- same predict1/predict2 values, so any row yields the same prediction.
SELECT ans > -67.01 AND ans < -66.9
FROM
(
    WITH (SELECT state FROM model) AS model
    SELECT evalMLMethod(model, predict1, predict2) AS ans
    FROM defaults
    LIMIT 1
);
-- Check that the fitted coefficients are close to the ones that generated the data.
-- The result array is [weight of param1, weight of param2, bias], expected ≈ [0.5, -2, 3].
SELECT ans > 0.49 AND ans < 0.51
FROM
(
    SELECT stochasticLinearRegression(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2)[1] AS ans
    FROM defaults
);
SELECT ans > -2.01 AND ans < -1.99
FROM
(
    SELECT stochasticLinearRegression(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2)[2] AS ans
    FROM defaults
);
SELECT ans > 2.99 AND ans < 3.01
FROM
(
    SELECT stochasticLinearRegression(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2)[3] AS ans
    FROM defaults
);
-- Check GROUP BY: stochasticLinearRegression must train an independent model per user_id.
DROP TABLE IF EXISTS grouptest;
CREATE TABLE IF NOT EXISTS grouptest
(
    user_id UInt32,
    p1 Float64,
    p2 Float64,
    target Float64
) ENGINE = Memory;
INSERT INTO grouptest VALUES
(1, 1.732, 3.653, 11.422), (1, 2.150, 2.103, 7.609), (1, 0.061, 3.310, 7.052), (1, 1.030, 3.671, 10.075), (1, 1.879, 0.578, 2.492), (1, 0.922, 2.552, 6.499), (1, 1.145, -0.095, -0.993), (1, 1.920, 0.373, 1.959), (1, 0.458, 0.094, -1.801), (1, -0.118, 3.273, 6.582), (1, 2.667, 1.472, 6.752), (1, -0.387, -0.529, -5.360), (1, 2.219, 1.790, 6.810), (1, -0.754, 2.139, 1.908), (1, -0.446, -0.668, -5.896), (1, 1.729, 0.914, 3.199), (1, 2.908, -0.420, 1.556), (1, 1.645, 3.581, 11.034), (1, 0.358, -0.950, -5.136), (1, -0.467, 2.339, 3.084), (1, 3.629, 2.959, 13.135), (1, 2.393, 0.926, 4.563), (1, -0.945, 0.281, -4.047), (1, 3.688, -0.570, 2.667), (1, 3.016, 1.775, 8.356), (1, 2.571, 0.139, 2.559), (1, 2.999, 0.956, 5.866), (1, 1.754, -0.809, -1.920), (1, 3.943, 0.382, 6.030), (1, -0.970, 2.315, 2.004), (1, 1.503, 0.790, 2.376), (1, -0.775, 2.563, 3.139), (1, 1.211, 0.113, -0.240), (1, 3.058, 0.977, 6.048), (1, 2.729, 1.634, 7.360), (1, 0.307, 2.759, 5.893), (1, 3.272, 0.181, 4.089), (1, 1.192, 1.963, 5.273), (1, 0.931, 1.447, 3.203), (1, 3.835, 3.447, 15.011), (1, 0.709, 0.008, -1.559), (1, 3.155, -0.676, 1.283), (1, 2.342, 1.047, 4.824), (1, 2.059, 1.262, 4.903), (1, 2.797, 0.855, 5.159), (1, 0.387, 0.645, -0.292), (1, 1.418, 0.408, 1.060), (1, 2.719, -0.826, -0.039), (1, 2.735, 3.736, 13.678), (1, 0.205, 0.777, -0.260), (1, 3.117, 2.063, 9.424), (1, 0.601, 0.178, -1.263), (1, 0.064, 0.157, -2.401), (1, 3.104, -0.455, 1.842), (1, -0.253, 0.672, -1.490), (1, 2.592, -0.408, 0.961), (1, -0.909, 1.314, -0.878), (1, 0.625, 2.594, 6.031), (1, 2.749, -0.210, 1.869), (1, -0.469, 1.532, 0.657), (1, 1.954, 1.827, 6.388), (1, -0.528, 1.136, -0.647), (1, 0.802, -0.583, -3.146), (1, -0.176, 1.584, 1.400), (1, -0.705, -0.785, -6.766), (1, 1.660, 2.365, 7.416), (1, 2.278, 3.977, 13.485), (1, 2.846, 3.845, 14.229), (1, 3.588, -0.401, 2.974), (1, 3.525, 3.831, 15.542), (1, 0.191, 3.312, 7.318), (1, 2.615, -0.287, 1.370), (1, 2.701, -0.446, 1.064), (1, 2.065, -0.556, -0.538), (1, 2.572, 
3.618, 12.997), (1, 3.743, -0.708, 2.362), (1, 3.734, 2.319, 11.425), (1, 3.768, 2.777, 12.866), (1, 3.203, 0.958, 6.280), (1, 1.512, 2.635, 7.927), (1, 2.194, 2.323, 8.356), (1, -0.726, 2.729, 3.735), (1, 0.020, 1.704, 2.152), (1, 2.173, 2.856, 9.912), (1, 3.124, 1.705, 8.364), (1, -0.834, 2.142, 1.759), (1, -0.702, 3.024, 4.666), (1, 1.393, 0.583, 1.535), (1, 2.136, 3.770, 12.581), (1, -0.445, 0.991, -0.917), (1, 0.244, -0.835, -5.016), (1, 2.789, 0.691, 4.652), (1, 0.246, 2.661, 5.475), (1, 3.793, 2.671, 12.601), (1, 1.645, -0.973, -2.627), (1, 2.405, 1.842, 7.336), (1, 3.221, 3.109, 12.769), (1, -0.638, 3.220, 5.385), (1, 1.836, 3.025, 9.748), (1, -0.660, 1.818, 1.133), (1, 0.901, 0.981, 1.744), (1, -0.236, 3.087, 5.789), (1, 1.744, 3.864, 12.078), (1, -0.166, 3.186, 6.226), (1, 3.536, -0.090, 3.803), (1, 3.284, 2.026, 9.648), (1, 1.327, 2.822, 8.119), (1, -0.709, 0.105, -4.104), (1, 0.509, -0.989, -4.949), (1, 0.180, -0.934, -5.440), (1, 3.522, 1.374, 8.168), (1, 1.497, -0.764, -2.297), (1, 1.696, 2.364, 7.482), (1, -0.202, -0.032, -3.500), (1, 3.109, -0.138, 2.804), (1, -0.238, 2.992, 5.501), (1, 1.639, 1.634, 5.181), (1, 1.919, 0.341, 1.859), (1, -0.563, 1.750, 1.124), (1, 0.886, 3.589, 9.539), (1, 3.619, 3.020, 13.299), (1, 1.703, -0.493, -1.073), (1, 2.364, 3.764, 13.022), (1, 1.820, 1.854, 6.201), (1, 1.437, -0.765, -2.421), (1, 1.396, 0.959, 2.668), (1, 2.608, 2.032, 8.312), (1, 0.333, -0.040, -2.455), (1, 3.441, 0.824, 6.355), (1, 1.303, 2.767, 7.908), (1, 1.359, 2.404, 6.932), (1, 0.674, 0.241, -0.930), (1, 2.708, -0.077, 2.183), (1, 3.821, 3.215, 14.287), (1, 3.316, 1.591, 8.404), (1, -0.848, 1.145, -1.259), (1, 3.455, 3.081, 13.153), (1, 2.568, 0.259, 2.914), (1, 2.866, 2.636, 10.642), (1, 2.776, -0.309, 1.626), (1, 2.087, 0.619, 3.031), (1, 1.682, 1.201, 3.967), (1, 3.800, 2.600, 12.399), (1, 3.344, -0.780, 1.347), (1, 1.053, -0.817, -3.346), (1, 0.805, 3.085, 7.865), (1, 0.173, 0.069, -2.449), (1, 2.018, 1.309, 4.964), (1, 3.713, 3.804, 15.838), 
(1, 3.805, -0.063, 4.421), (1, 3.587, 2.854, 12.738), (1, 2.426, -0.179, 1.315), (1, 0.535, 0.572, -0.
-- GROUP BY emits groups in an unspecified order, so picking a group with a bare
-- LIMIT offset is nondeterministic. Sort by the grouping key first; both groups
-- are still aggregated in the same query, so the per-group states are still tested together.
-- The user_id = 1 rows follow target = 2 * p1 + 3 * p2 - 3 and therefore sort first.
-- NOTE(review): assumes grouptest holds exactly two users and that the other user's
-- id is greater than 1; that user's rows (expected weights -1, 6 and bias 10) are
-- not visible here — confirm against the inserted data.
SELECT ANS[1] > -1.1 AND ANS[1] < -0.9 AND ANS[2] > 5.9 AND ANS[2] < 6.1 AND ANS[3] > 9.9 AND ANS[3] < 10.1
FROM
(
    SELECT stochasticLinearRegression(0.05, 0, 1, 'SGD')(target, p1, p2) AS ANS
    FROM grouptest
    GROUP BY user_id
    ORDER BY user_id
    LIMIT 1, 1
);
SELECT ANS[1] > 1.9 AND ANS[1] < 2.1 AND ANS[2] > 2.9 AND ANS[2] < 3.1 AND ANS[3] > -3.1 AND ANS[3] < -2.9
FROM
(
    SELECT stochasticLinearRegression(0.05, 0, 1, 'SGD')(target, p1, p2) AS ANS
    FROM grouptest
    GROUP BY user_id
    ORDER BY user_id
    LIMIT 0, 1
);
-- Clean up every table this test created; IF EXISTS matches the idempotent setup statements.
DROP TABLE IF EXISTS defaults;
DROP TABLE IF EXISTS model;
DROP TABLE IF EXISTS grouptest;