mxnet
lr_scheduler.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_CPP_LR_SCHEDULER_H_
26 #define MXNET_CPP_LR_SCHEDULER_H_
27 
#include <stdexcept>

#include "dmlc/logging.h"
29 
30 namespace mxnet {
31 namespace cpp {
32 
/*!
 * \brief lr scheduler interface
 *
 * Stores a base learning rate in base_lr_; concrete schedulers implement
 * GetLR() to produce the learning rate for a given update counter.
 */
class LRScheduler {
 public:
  /*!
   * \brief constructor
   * \param base_lr initial value of base_lr_ (defaults to 0.01)
   */
  explicit LRScheduler(float base_lr = 0.01) : base_lr_{base_lr} {}

  /*!
   * \brief set base lr
   * \param lr new value written to base_lr_
   */
  void SetLR(const float lr) {
    base_lr_ = lr;
  }

  /*!
   * \brief get a new learning rate
   * \param num_update update counter supplied by the caller
   * \return the learning rate chosen by the concrete scheduler
   */
  virtual float GetLR(unsigned num_update) = 0;

  /*! \brief destructor; virtual so derived schedulers can be deleted via a base pointer */
  virtual ~LRScheduler() = default;

 protected:
  /*! \brief current base learning rate, readable and writable by subclasses */
  float base_lr_;
};
61 
62 class FactorScheduler : public LRScheduler {
63  public:
64  explicit FactorScheduler(int step, float factor = 1, float stop_factor_lr = 1e-8)
65  : LRScheduler() {
66  step_ = step;
67  factor_ = factor;
68  stop_factor_lr_ = stop_factor_lr;
69  }
70 
71  float GetLR(unsigned num_update) override {
72  while (num_update > unsigned(count_ + step_)) {
73  count_ += step_;
74  base_lr_ *= factor_;
75  if (base_lr_ < stop_factor_lr_) {
76  base_lr_ = stop_factor_lr_;
77  LG << "Update[" << num_update << "]: now learning rate arrived at " \
78  << base_lr_ << ", will not change in the future";
79  } else {
80  LG << "Update[" << num_update << "]: Change learning rate to " << base_lr_;
81  }
82  }
83  return base_lr_;
84  }
85 
86  private:
87  int count_ = 0;
88  int step_;
89  float factor_;
90  float stop_factor_lr_;
91 };
92 
93 } // namespace cpp
94 } // namespace mxnet
95 
96 #endif // MXNET_CPP_LR_SCHEDULER_H_
virtual float GetLR(unsigned num_update)=0
get a new learning rate
namespace of mxnet
Definition: base.h:126
lr scheduler interface
Definition: lr_scheduler.h:36
FactorScheduler(int step, float factor=1, float stop_factor_lr=1e-8)
Definition: lr_scheduler.h:64
LRScheduler(float base_lr=0.01)
constructor
Definition: lr_scheduler.h:42
float GetLR(unsigned num_update) override
get a new learning rate
Definition: lr_scheduler.h:71
float base_lr_
Definition: lr_scheduler.h:59
void SetLR(const float lr)
set base lr
Definition: lr_scheduler.h:48
Definition: lr_scheduler.h:62
virtual ~LRScheduler()
destructor
Definition: lr_scheduler.h:56