mxnet
optimizer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_OPTIMIZER_H_
28 #define MXNET_CPP_OPTIMIZER_H_
29 
#include <map>
#include <vector>
#include <string>
#include <sstream>
#include <memory>
#include <utility>
#include <functional>
#include "mxnet-cpp/base.h"
#include "dmlc/logging.h"
#include "mxnet-cpp/ndarray.h"
#include "mxnet-cpp/op_map.h"
#include "mxnet-cpp/lr_scheduler.h"
40 
41 namespace mxnet {
42 namespace cpp {
43 
47 class Optimizer {
48  public:
53  explicit Optimizer(unsigned begin_num_update);
58  virtual std::string GetType() const = 0;
62  virtual ~Optimizer();
69  template <typename T>
70  Optimizer *SetParam(const std::string &name, const T &value) {
71  std::string value_str;
72  std::stringstream ss;
73  ss << value;
74  ss >> value_str;
75 
76  params_[name] = value_str;
77  return this;
78  }
84  Optimizer *SetLRScheduler(std::unique_ptr<LRScheduler> lrScheduler) {
85  CHECK(lrScheduler);
86  lrScheduler_ = std::move(lrScheduler);
87  lrScheduler_->SetLR(std::stof(params_["lr"]));
88  return this;
89  }
96  virtual void Update(int index, NDArray weight, NDArray grad) = 0;
97  // TODO(zhangcheng-qinyinghua)
98  // implement Update a list of arrays, maybe in the form of map
99  // void Update(int index, std::vector<NDArray> weights, std::vector<NDArray>
100  // grad, mx_float lr);
101 
106  std::string Serialize() const;
107 
108  protected:
109  std::map<std::string, std::string> params_;
110  static OpMap*& op_map();
111  const std::vector<const char*> GetParamKeys_() const;
112  const std::vector<const char*> GetParamValues_() const;
113  std::map<int, unsigned> count_;
115  unsigned UpdateCount_(int index);
116  float GetLR_(int index);
117  float GetWD_(int index);
118  virtual void CreateState_(int index, NDArray weight);
119  std::unique_ptr<LRScheduler> lrScheduler_ = nullptr;
120 };
121 
122 typedef std::function<Optimizer*()> OptimizerCreator;
123 
125  public:
126  static Optimizer* Find(const std::string& name);
127  static int __REGISTER__(const std::string& name, OptimizerCreator creator);
128  private:
129  static std::map<std::string, OptimizerCreator>& cmap();
130  OptimizerRegistry() = delete;
131  ~OptimizerRegistry() = delete;
132 };
/*!
 * \brief Register OptimizerType under the stringified Name.
 *
 * Expands to a call to OptimizerRegistry::__REGISTER__ with a creator lambda
 * that default-constructs OptimizerType with new.
 */
#define MXNETCPP_REGISTER_OPTIMIZER(Name, OptimizerType) \
  OptimizerRegistry::__REGISTER__(                        \
      #Name, []() { return new OptimizerType(); })
135 
136 class SGDOptimizer : public Optimizer {
137  public:
138  explicit SGDOptimizer(unsigned begin_num_update = 0);
139  std::string GetType() const override;
140  void Update(int index, NDArray weight, NDArray grad) override;
141  private:
142  virtual ~SGDOptimizer();
143  void CreateState_(int index, NDArray weight) override;
144  std::map<int, NDArray*> states_;
145  AtomicSymbolCreator update_handle_;
146  AtomicSymbolCreator mom_update_handle_;
147 };
148 
149 class SignumOptimizer : public Optimizer {
150  public:
151  explicit SignumOptimizer(unsigned begin_num_update = 0);
152  std::string GetType() const override;
153  void Update(int index, NDArray weight, NDArray grad) override;
154  private:
155  virtual ~SignumOptimizer();
156  void CreateState_(int index, NDArray weight) override;
157  std::map<int, NDArray*> states_;
158  AtomicSymbolCreator update_handle_;
159  AtomicSymbolCreator mom_update_handle_;
160 };
161 
162 
163 class RMSPropOptimizer : public Optimizer {
164  public:
165  explicit RMSPropOptimizer(unsigned begin_num_update = 0);
166  std::string GetType() const override;
167  void Update(int index, NDArray weight, NDArray grad) override;
168  private:
169  virtual ~RMSPropOptimizer();
170  void CreateState_(int index, NDArray weight) override;
171  std::map<int, NDArray*> n_, g_, delta_;
172  AtomicSymbolCreator update_handle_;
173  AtomicSymbolCreator alex_update_handle_;
174 };
175 
176 class AdamOptimizer : public Optimizer {
177  public:
178  explicit AdamOptimizer(unsigned begin_num_update = 0);
179  std::string GetType() const override;
180  void Update(int index, NDArray weight, NDArray grad) override;
181  private:
182  virtual ~AdamOptimizer();
183  void CreateState_(int index, NDArray weight) override;
184  std::map<int, NDArray*> mean_;
185  std::map<int, NDArray*> var_;
186  AtomicSymbolCreator update_handle_;
187 };
188 
189 class AdaGradOptimizer : public Optimizer {
190  public:
191  explicit AdaGradOptimizer(unsigned begin_num_update = 0);
192  std::string GetType() const override;
193  void Update(int index, NDArray weight, NDArray grad) override;
194  private:
195  virtual ~AdaGradOptimizer();
196  void CreateState_(int index, NDArray weight) override;
197  std::map<int, NDArray*> history_;
198 };
199 
200 class AdaDeltaOptimizer : public Optimizer {
201  public:
202  explicit AdaDeltaOptimizer(unsigned begin_num_update = 0);
203  std::string GetType() const override;
204  void Update(int index, NDArray weight, NDArray grad) override;
205  private:
206  virtual ~AdaDeltaOptimizer();
207  void CreateState_(int index, NDArray weight) override;
208  std::map<int, NDArray*> acc_g_, acc_delta_;
209 };
210 
211 } // namespace cpp
212 } // namespace mxnet
213 
214 #endif // MXNET_CPP_OPTIMIZER_H_
mxnet::cpp::SGDOptimizer
Definition: optimizer.h:136
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name...
Definition: op_map.h:43
mxnet::cpp::AdamOptimizer
Definition: optimizer.h:176
definition of OpMap
unsigned UpdateCount_(int index)
namespace of mxnet
Definition: base.h:127
Optimizer(unsigned begin_num_update)
constructor
const std::vector< const char * > GetParamKeys_() const
virtual std::string GetType() const =0
get optimizer type
Scheduling learning rate.
unsigned begin_num_update_
Definition: optimizer.h:114
mxnet::cpp::SignumOptimizer
Definition: optimizer.h:149
unsigned num_update_
Definition: optimizer.h:114
Optimizer interface.
Definition: optimizer.h:47
std::map< int, unsigned > count_
Definition: optimizer.h:113
mxnet::cpp::AdaDeltaOptimizer
Definition: optimizer.h:200
mxnet::cpp::AdaGradOptimizer
Definition: optimizer.h:189
NDArray interface.
Definition: ndarray.h:121
float GetWD_(int index)
virtual void Update(int index, NDArray weight, NDArray grad)=0
Update a weight with gradient.
std::unique_ptr< LRScheduler > lrScheduler_
Definition: optimizer.h:119
mxnet::cpp::OptimizerRegistry
Definition: optimizer.h:124
mxnet::cpp::RMSPropOptimizer
Definition: optimizer.h:163
std::map< std::string, std::string > params_
Definition: optimizer.h:109
virtual ~Optimizer()
destructor
void * AtomicSymbolCreator
handle to a function that takes param and creates symbol
Definition: c_api.h:69
const std::vector< const char * > GetParamValues_() const
Optimizer * SetParam(const std::string &name, const T &value)
set config parameters
Definition: optimizer.h:70
virtual void CreateState_(int index, NDArray weight)
std::string Serialize() const
Serialize the optimizer parameters to a string.
static OpMap *& op_map()
std::function< Optimizer *()> OptimizerCreator
Definition: optimizer.h:122
Optimizer * SetLRScheduler(std::unique_ptr< LRScheduler > lrScheduler)
Definition: optimizer.h:84
float GetLR_(int index)