mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
Merge pull request #45985 from ClickHouse/fix-crash-in-regression
Fix crash in stochasticLinearRegression.
This commit is contained in:
commit
496cacf25e
@ -247,15 +247,8 @@ void Adam::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
|
||||
if (adam_rhs.average_gradient.empty())
|
||||
return;
|
||||
|
||||
if (average_gradient.empty())
|
||||
{
|
||||
if (!average_squared_gradient.empty() ||
|
||||
adam_rhs.average_gradient.size() != adam_rhs.average_squared_gradient.size())
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Average_gradient and average_squared_gradient must have same size");
|
||||
|
||||
average_gradient.resize(adam_rhs.average_gradient.size(), Float64{0.0});
|
||||
average_squared_gradient.resize(adam_rhs.average_squared_gradient.size(), Float64{0.0});
|
||||
}
|
||||
average_gradient.resize(adam_rhs.average_gradient.size(), Float64{0.0});
|
||||
average_squared_gradient.resize(adam_rhs.average_squared_gradient.size(), Float64{0.0});
|
||||
|
||||
for (size_t i = 0; i < average_gradient.size(); ++i)
|
||||
{
|
||||
@ -268,14 +261,8 @@ void Adam::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
|
||||
|
||||
void Adam::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient)
|
||||
{
|
||||
if (average_gradient.empty())
|
||||
{
|
||||
if (!average_squared_gradient.empty())
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Average_gradient and average_squared_gradient must have same size");
|
||||
|
||||
average_gradient.resize(batch_gradient.size(), Float64{0.0});
|
||||
average_squared_gradient.resize(batch_gradient.size(), Float64{0.0});
|
||||
}
|
||||
average_gradient.resize(batch_gradient.size(), Float64{0.0});
|
||||
average_squared_gradient.resize(batch_gradient.size(), Float64{0.0});
|
||||
|
||||
for (size_t i = 0; i != average_gradient.size(); ++i)
|
||||
{
|
||||
@ -328,8 +315,7 @@ void Nesterov::write(WriteBuffer & buf) const
|
||||
void Nesterov::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
|
||||
{
|
||||
const auto & nesterov_rhs = static_cast<const Nesterov &>(rhs);
|
||||
if (accumulated_gradient.empty())
|
||||
accumulated_gradient.resize(nesterov_rhs.accumulated_gradient.size(), Float64{0.0});
|
||||
accumulated_gradient.resize(nesterov_rhs.accumulated_gradient.size(), Float64{0.0});
|
||||
|
||||
for (size_t i = 0; i < accumulated_gradient.size(); ++i)
|
||||
{
|
||||
@ -339,10 +325,7 @@ void Nesterov::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac
|
||||
|
||||
void Nesterov::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient)
|
||||
{
|
||||
if (accumulated_gradient.empty())
|
||||
{
|
||||
accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
|
||||
}
|
||||
accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
|
||||
|
||||
for (size_t i = 0; i < batch_gradient.size(); ++i)
|
||||
{
|
||||
@ -402,10 +385,7 @@ void Momentum::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac
|
||||
void Momentum::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient)
|
||||
{
|
||||
/// batch_size is already checked to be greater than 0
|
||||
if (accumulated_gradient.empty())
|
||||
{
|
||||
accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
|
||||
}
|
||||
accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
|
||||
|
||||
for (size_t i = 0; i < batch_gradient.size(); ++i)
|
||||
{
|
||||
|
@ -149,9 +149,11 @@ public:
|
||||
class Momentum : public IWeightsUpdater
|
||||
{
|
||||
public:
|
||||
Momentum() = default;
|
||||
|
||||
explicit Momentum(Float64 alpha_) : alpha(alpha_) {}
|
||||
explicit Momentum(size_t num_params, Float64 alpha_ = 0.1) : alpha(alpha_)
|
||||
{
|
||||
accumulated_gradient.resize(num_params + 1, 0);
|
||||
}
|
||||
|
||||
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
|
||||
|
||||
@ -170,9 +172,10 @@ private:
|
||||
class Nesterov : public IWeightsUpdater
|
||||
{
|
||||
public:
|
||||
Nesterov() = default;
|
||||
|
||||
explicit Nesterov(Float64 alpha_) : alpha(alpha_) {}
|
||||
explicit Nesterov(size_t num_params, Float64 alpha_ = 0.9) : alpha(alpha_)
|
||||
{
|
||||
accumulated_gradient.resize(num_params + 1, 0);
|
||||
}
|
||||
|
||||
void addToBatch(
|
||||
std::vector<Float64> & batch_gradient,
|
||||
@ -201,10 +204,14 @@ private:
|
||||
class Adam : public IWeightsUpdater
|
||||
{
|
||||
public:
|
||||
Adam()
|
||||
Adam(size_t num_params)
|
||||
{
|
||||
beta1_powered = beta1;
|
||||
beta2_powered = beta2;
|
||||
|
||||
|
||||
average_gradient.resize(num_params + 1, 0);
|
||||
average_squared_gradient.resize(num_params + 1, 0);
|
||||
}
|
||||
|
||||
void addToBatch(
|
||||
@ -338,11 +345,11 @@ public:
|
||||
if (weights_updater_name == "SGD")
|
||||
new_weights_updater = std::make_shared<StochasticGradientDescent>();
|
||||
else if (weights_updater_name == "Momentum")
|
||||
new_weights_updater = std::make_shared<Momentum>();
|
||||
new_weights_updater = std::make_shared<Momentum>(param_num);
|
||||
else if (weights_updater_name == "Nesterov")
|
||||
new_weights_updater = std::make_shared<Nesterov>();
|
||||
new_weights_updater = std::make_shared<Nesterov>(param_num);
|
||||
else if (weights_updater_name == "Adam")
|
||||
new_weights_updater = std::make_shared<Adam>();
|
||||
new_weights_updater = std::make_shared<Adam>(param_num);
|
||||
else
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Illegal name of weights updater (should have been checked earlier)");
|
||||
|
||||
|
14
tests/queries/0_stateless/02552_regression_crash.sql
Normal file
14
tests/queries/0_stateless/02552_regression_crash.sql
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user