mxnet
lr_scheduler.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_LR_SCHEDULER_H_
27 #define MXNET_CPP_LR_SCHEDULER_H_
28 
#include <cstdint>
#include <stdexcept>

#include "dmlc/logging.h"
30 
31 namespace mxnet {
32 namespace cpp {
33 
/*!
 * \brief Abstract interface for learning rate schedulers.
 *
 * Holds a base learning rate and leaves the policy that maps an update
 * count to a learning rate to derived classes.
 */
class LRScheduler {
 public:
  /*!
   * \brief Construct a scheduler with its starting learning rate.
   * \param base_lr value stored as the base learning rate
   */
  explicit LRScheduler(float base_lr = 0.01) : base_lr_{base_lr} {}

  /*!
   * \brief Virtual destructor so derived schedulers can be destroyed
   *        through a pointer to this interface.
   */
  virtual ~LRScheduler() = default;

  /*!
   * \brief Overwrite the stored base learning rate.
   * \param lr new base learning rate
   */
  void SetLR(const float lr) {
    this->base_lr_ = lr;
  }

  /*!
   * \brief Get the learning rate for the given update count.
   * \param num_update number of updates, as supplied by the caller
   * \return learning rate chosen by the derived scheduler
   */
  virtual float GetLR(unsigned num_update) = 0;

 protected:
  /*! \brief Base learning rate, readable and writable by derived schedulers. */
  float base_lr_;
};
62 
63 class FactorScheduler : public LRScheduler {
64  public:
65  explicit FactorScheduler(int step, float factor = 1, float stop_factor_lr = 1e-8)
66  : LRScheduler() {
67  step_ = step;
68  factor_ = factor;
69  stop_factor_lr_ = stop_factor_lr;
70  }
71 
72  float GetLR(unsigned num_update) override {
73  while (num_update > unsigned(count_ + step_)) {
74  count_ += step_;
75  base_lr_ *= factor_;
76  if (base_lr_ < stop_factor_lr_) {
77  base_lr_ = stop_factor_lr_;
78  LG << "Update[" << num_update << "]: now learning rate arrived at " \
79  << base_lr_ << ", will not change in the future";
80  } else {
81  LG << "Update[" << num_update << "]: Change learning rate to " << base_lr_;
82  }
83  }
84  return base_lr_;
85  }
86 
87  private:
88  int count_ = 0;
89  int step_;
90  float factor_;
91  float stop_factor_lr_;
92 };
93 
94 } // namespace cpp
95 } // namespace mxnet
96 
97 #endif // MXNET_CPP_LR_SCHEDULER_H_
virtual float GetLR(unsigned num_update)=0
get a new learning rate
namespace of mxnet
Definition: base.h:127
lr scheduler interface
Definition: lr_scheduler.h:37
FactorScheduler(int step, float factor=1, float stop_factor_lr=1e-8)
Definition: lr_scheduler.h:65
LRScheduler(float base_lr=0.01)
constructor
Definition: lr_scheduler.h:43
float GetLR(unsigned num_update) override
get a new learning rate
Definition: lr_scheduler.h:72
float base_lr_
Definition: lr_scheduler.h:60
void SetLR(const float lr)
set base lr
Definition: lr_scheduler.h:49
Definition: lr_scheduler.h:63
virtual ~LRScheduler()
destructor
Definition: lr_scheduler.h:57