Merge pull request #45985 from ClickHouse/fix-crash-in-regression

Fix crash in stochasticLinearRegression.
Commit 496cacf25e by Alexey Milovidov, 2023-02-04 03:01:46 +01:00 (committed by GitHub)
4 changed files with 37 additions and 36 deletions
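The crash traces back to lazily sized updater state: the gradient vectors were only resized on first use, inside update() and merge(), so a state could be empty or sized inconsistently at the moment it was indexed. The commit moves the sizing into the constructors instead. A minimal sketch of the class of hazard this removes, using hypothetical names rather than the ClickHouse code:

#include <cstddef>
#include <iostream>
#include <vector>

/// Hypothetical illustration, not the ClickHouse implementation: the updater
/// state is sized lazily on first use, so it can end up smaller than a
/// gradient passed in later.
struct LazyUpdater
{
    std::vector<double> accumulated_gradient;

    void update(const std::vector<double> & batch_gradient)
    {
        if (accumulated_gradient.empty())   /// old behaviour: size only once
            accumulated_gradient.resize(batch_gradient.size(), 0.0);

        for (size_t i = 0; i < batch_gradient.size(); ++i)
            accumulated_gradient[i] += batch_gradient[i];   /// out of bounds if the state is shorter
    }
};

int main()
{
    LazyUpdater updater;
    updater.update({1.0, 2.0});        /// state is sized to 2 elements
    updater.update({1.0, 2.0, 3.0});   /// 3-element gradient: index 2 is past the end
    std::cout << updater.accumulated_gradient.size() << '\n';
}

The second call indexes past the end of the two-element state; with construction-time sizing, as introduced below, this situation cannot arise.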

src/AggregateFunctions/AggregateFunctionMLMethod.cpp

@@ -247,15 +247,8 @@ void Adam::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
     if (adam_rhs.average_gradient.empty())
         return;
-    if (average_gradient.empty())
-    {
-        if (!average_squared_gradient.empty() ||
-            adam_rhs.average_gradient.size() != adam_rhs.average_squared_gradient.size())
-            throw Exception(ErrorCodes::LOGICAL_ERROR, "Average_gradient and average_squared_gradient must have same size");
-        average_gradient.resize(adam_rhs.average_gradient.size(), Float64{0.0});
-        average_squared_gradient.resize(adam_rhs.average_squared_gradient.size(), Float64{0.0});
-    }
+    average_gradient.resize(adam_rhs.average_gradient.size(), Float64{0.0});
+    average_squared_gradient.resize(adam_rhs.average_squared_gradient.size(), Float64{0.0});
     for (size_t i = 0; i < average_gradient.size(); ++i)
     {
@@ -268,14 +261,8 @@ void Adam::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
 void Adam::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient)
 {
-    if (average_gradient.empty())
-    {
-        if (!average_squared_gradient.empty())
-            throw Exception(ErrorCodes::LOGICAL_ERROR, "Average_gradient and average_squared_gradient must have same size");
-        average_gradient.resize(batch_gradient.size(), Float64{0.0});
-        average_squared_gradient.resize(batch_gradient.size(), Float64{0.0});
-    }
+    average_gradient.resize(batch_gradient.size(), Float64{0.0});
+    average_squared_gradient.resize(batch_gradient.size(), Float64{0.0});
     for (size_t i = 0; i != average_gradient.size(); ++i)
     {
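For context, average_gradient and average_squared_gradient are Adam's running first and second moment estimates and are indexed in lockstep, which is why the deleted checks insisted they have the same size. A simplified sketch of the step they feed, with illustrative names rather than the exact ClickHouse implementation:

#include <cmath>
#include <cstddef>
#include <vector>

/// Simplified Adam step: m is the first moment (average_gradient), v the
/// second moment (average_squared_gradient). beta1_powered = beta1^t and
/// beta2_powered = beta2^t provide the bias correction, mirroring the
/// member names visible in the header diff.
void adam_step(std::vector<double> & w, std::vector<double> & m, std::vector<double> & v,
               const std::vector<double> & g, double lr,
               double beta1, double beta2, double beta1_powered, double beta2_powered, double eps)
{
    for (size_t i = 0; i < w.size(); ++i)
    {
        m[i] = beta1 * m[i] + (1 - beta1) * g[i];            /// first moment update
        v[i] = beta2 * v[i] + (1 - beta2) * g[i] * g[i];     /// second moment update
        /// bias-corrected step; m and v must be indexed in lockstep
        w[i] -= lr * (m[i] / (1 - beta1_powered)) / (std::sqrt(v[i] / (1 - beta2_powered)) + eps);
    }
}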
@@ -328,8 +315,7 @@ void Nesterov::write(WriteBuffer & buf) const
 void Nesterov::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
 {
     const auto & nesterov_rhs = static_cast<const Nesterov &>(rhs);
-    if (accumulated_gradient.empty())
-        accumulated_gradient.resize(nesterov_rhs.accumulated_gradient.size(), Float64{0.0});
+    accumulated_gradient.resize(nesterov_rhs.accumulated_gradient.size(), Float64{0.0});
     for (size_t i = 0; i < accumulated_gradient.size(); ++i)
     {
@@ -339,10 +325,7 @@ void Nesterov::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
 void Nesterov::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient)
 {
-    if (accumulated_gradient.empty())
-    {
-        accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
-    }
+    accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
     for (size_t i = 0; i < batch_gradient.size(); ++i)
     {
@@ -402,10 +385,7 @@ void Momentum::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
 void Momentum::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient)
 {
     /// batch_size is already checked to be greater than 0
-    if (accumulated_gradient.empty())
-    {
-        accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
-    }
+    accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
     for (size_t i = 0; i < batch_gradient.size(); ++i)
     {
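All three updaters now rely on the same invariant, established in the header below: the state vectors are sized once, at construction, to num_params + 1 (presumably one slot per weight plus one for the bias term), so they are never empty by the time update() or merge() runs. A minimal sketch of that invariant, again with stand-in names rather than the ClickHouse code:

#include <cassert>
#include <cstddef>
#include <vector>

/// Sketch of the post-fix invariant: the gradient state is fixed at
/// construction to num_params + 1 slots, so update() can never index past
/// the end and the lazy-allocation branches become unnecessary.
struct EagerUpdater
{
    std::vector<double> accumulated_gradient;

    explicit EagerUpdater(size_t num_params)
        : accumulated_gradient(num_params + 1, 0.0)   /// weights + bias, never empty
    {
    }

    void update(const std::vector<double> & batch_gradient)
    {
        assert(batch_gradient.size() == accumulated_gradient.size());
        for (size_t i = 0; i < batch_gradient.size(); ++i)
            accumulated_gradient[i] += batch_gradient[i];
    }
};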

src/AggregateFunctions/AggregateFunctionMLMethod.h

@@ -149,9 +149,11 @@ public:
 class Momentum : public IWeightsUpdater
 {
 public:
-    Momentum() = default;
-    explicit Momentum(Float64 alpha_) : alpha(alpha_) {}
+    explicit Momentum(size_t num_params, Float64 alpha_ = 0.1) : alpha(alpha_)
+    {
+        accumulated_gradient.resize(num_params + 1, 0);
+    }
     void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
@@ -170,9 +172,10 @@ private:
 class Nesterov : public IWeightsUpdater
 {
 public:
-    Nesterov() = default;
-    explicit Nesterov(Float64 alpha_) : alpha(alpha_) {}
+    explicit Nesterov(size_t num_params, Float64 alpha_ = 0.9) : alpha(alpha_)
+    {
+        accumulated_gradient.resize(num_params + 1, 0);
+    }
     void addToBatch(
         std::vector<Float64> & batch_gradient,
@@ -201,10 +204,14 @@ private:
 class Adam : public IWeightsUpdater
 {
 public:
-    Adam()
+    Adam(size_t num_params)
     {
         beta1_powered = beta1;
         beta2_powered = beta2;
+
+        average_gradient.resize(num_params + 1, 0);
+        average_squared_gradient.resize(num_params + 1, 0);
     }
     void addToBatch(
@@ -338,11 +345,11 @@ public:
         if (weights_updater_name == "SGD")
             new_weights_updater = std::make_shared<StochasticGradientDescent>();
         else if (weights_updater_name == "Momentum")
-            new_weights_updater = std::make_shared<Momentum>();
+            new_weights_updater = std::make_shared<Momentum>(param_num);
         else if (weights_updater_name == "Nesterov")
-            new_weights_updater = std::make_shared<Nesterov>();
+            new_weights_updater = std::make_shared<Nesterov>(param_num);
         else if (weights_updater_name == "Adam")
-            new_weights_updater = std::make_shared<Adam>();
+            new_weights_updater = std::make_shared<Adam>(param_num);
         else
             throw Exception(ErrorCodes::LOGICAL_ERROR, "Illegal name of weights updater (should have been checked earlier)");
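Since the factory now threads param_num into every updater it creates, any two states that are later merged hold vectors of identical size, and the merge loop needs no lazy-allocation branch. A sketch of the resulting weighted merge, illustrative rather than the actual implementation:

#include <cassert>
#include <cstddef>
#include <vector>

/// Sketch (not the ClickHouse code) of why equal-size states make merging
/// safe: with every updater sized at construction from the same param_num,
/// the weighted average can walk both vectors without any size checks
/// beyond the invariant the constructors already guarantee.
void merge_states(std::vector<double> & lhs, const std::vector<double> & rhs,
                  double frac, double rhs_frac)
{
    assert(lhs.size() == rhs.size());   /// guaranteed by construction now
    for (size_t i = 0; i < lhs.size(); ++i)
        lhs[i] = lhs[i] * frac + rhs[i] * rhs_frac;
}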

File diff suppressed because one or more lines are too long