mxnet
initializer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_INITIALIZER_H_
27 #define MXNET_CPP_INITIALIZER_H_
28 
29 #include <cmath>
30 #include <string>
31 #include <vector>
32 #include <random>
33 #include "mxnet-cpp/ndarray.h"
34 
35 namespace mxnet {
36 namespace cpp {
37 
38 class Initializer {
39  public:
40  static bool StringStartWith(const std::string& name,
41  const std::string& check_str) {
42  return (name.size() >= check_str.size() &&
43  name.substr(0, check_str.size()) == check_str);
44  }
45  static bool StringEndWith(const std::string& name,
46  const std::string& check_str) {
47  return (name.size() >= check_str.size() &&
48  name.substr(name.size() - check_str.size(), check_str.size()) ==
49  check_str);
50  }
51  virtual void operator()(const std::string& name, NDArray* arr) {
52  if (StringStartWith(name, "upsampling")) {
53  InitBilinear(arr);
54  } else if (StringEndWith(name, "bias")) {
55  InitBias(arr);
56  } else if (StringEndWith(name, "gamma")) {
57  InitGamma(arr);
58  } else if (StringEndWith(name, "beta")) {
59  InitBeta(arr);
60  } else if (StringEndWith(name, "weight")) {
61  InitWeight(arr);
62  } else if (StringEndWith(name, "moving_mean")) {
63  InitZero(arr);
64  } else if (StringEndWith(name, "moving_var")) {
65  InitOne(arr);
66  } else if (StringEndWith(name, "moving_inv_var")) {
67  InitZero(arr);
68  } else if (StringEndWith(name, "moving_avg")) {
69  InitZero(arr);
70  } else if (StringEndWith(name, "min")) {
71  InitZero(arr);
72  } else if (StringEndWith(name, "max")) {
73  InitOne(arr);
74  } else if (StringEndWith(name, "weight_quantize")) {
76  } else if (StringEndWith(name, "bias_quantize")) {
77  InitQuantizedBias(arr);
78  } else {
79  InitDefault(arr);
80  }
81  }
82 
83  protected:
84  virtual void InitBilinear(NDArray* arr) {
85  Shape shape(arr->GetShape());
86  std::vector<float> weight(shape.Size(), 0);
87  int f = std::ceil(shape[3] / 2.0);
88  float c = (2 * f - 1 - f % 2) / (2. * f);
89  for (size_t i = 0; i < shape.Size(); ++i) {
90  int x = i % shape[3];
91  int y = (i / shape[3]) % shape[2];
92  weight[i] = (1 - std::abs(x / f - c)) * (1 - std::abs(y / f - c));
93  }
94  (*arr).SyncCopyFromCPU(weight);
95  }
96  virtual void InitZero(NDArray* arr) { (*arr) = 0.0f; }
97  virtual void InitOne(NDArray* arr) { (*arr) = 1.0f; }
98  virtual void InitBias(NDArray* arr) { (*arr) = 0.0f; }
99  virtual void InitGamma(NDArray* arr) { (*arr) = 1.0f; }
100  virtual void InitBeta(NDArray* arr) { (*arr) = 0.0f; }
101  virtual void InitWeight(NDArray* arr) {}
102  virtual void InitQuantizedWeight(NDArray* arr) {
103  std::default_random_engine generator;
104  std::uniform_int_distribution<int32_t> _val(-127, 127);
105  (*arr) = _val(generator);
106  }
107  virtual void InitQuantizedBias(NDArray* arr) {
108  (*arr) = 0;
109  }
110  virtual void InitDefault(NDArray* arr) {}
111 };
112 
113 class Constant : public Initializer {
114  public:
115  explicit Constant(float value)
116  : value(value) {}
117  void operator()(const std::string &name, NDArray *arr) override {
118  (*arr) = value;
119  }
120  protected:
121  float value;
122 };
123 
124 class Zero : public Constant {
125  public:
126  Zero(): Constant(0.0f) {}
127 };
128 
129 class One : public Constant {
130  public:
131  One(): Constant(1.0f) {}
132 };
133 
134 class Uniform : public Initializer {
135  public:
136  explicit Uniform(float scale)
137  : Uniform(-scale, scale) {}
138  Uniform(float begin, float end)
139  : begin(begin), end(end) {}
140  void operator()(const std::string &name, NDArray *arr) override {
141  if (StringEndWith(name, "weight_quantize")) {
142  InitQuantizedWeight(arr);
143  return;
144  }
145  if (StringEndWith(name, "bias_quantize")) {
146  InitQuantizedBias(arr);
147  return;
148  }
149  NDArray::SampleUniform(begin, end, arr);
150  }
151  protected:
152  float begin, end;
153 };
154 
155 class Normal : public Initializer {
156  public:
157  Normal(float mu, float sigma)
158  : mu(mu), sigma(sigma) {}
159  void operator()(const std::string &name, NDArray *arr) override {
160  if (StringEndWith(name, "weight_quantize")) {
161  InitQuantizedWeight(arr);
162  return;
163  }
164  if (StringEndWith(name, "bias_quantize")) {
165  InitQuantizedBias(arr);
166  return;
167  }
168  NDArray::SampleGaussian(mu, sigma, arr);
169  }
170  protected:
171  float mu, sigma;
172 };
173 
174 class Bilinear : public Initializer {
175  public:
176  Bilinear() {}
177  void operator()(const std::string &name, NDArray *arr) override {
178  if (StringEndWith(name, "weight_quantize")) {
179  InitQuantizedWeight(arr);
180  return;
181  }
182  if (StringEndWith(name, "bias_quantize")) {
183  InitQuantizedBias(arr);
184  return;
185  }
186  InitBilinear(arr);
187  }
188 };
189 
190 class Xavier : public Initializer {
191  public:
192  enum RandType {
194  uniform
195  } rand_type;
196  enum FactorType {
198  in,
199  out
200  } factor_type;
201  float magnitude;
202  Xavier(RandType rand_type = gaussian, FactorType factor_type = avg,
203  float magnitude = 3)
204  : rand_type(rand_type), factor_type(factor_type), magnitude(magnitude) {}
205 
206  void operator()(const std::string &name, NDArray* arr) override {
207  if (StringEndWith(name, "weight_quantize")) {
208  InitQuantizedWeight(arr);
209  return;
210  }
211  if (StringEndWith(name, "bias_quantize")) {
212  InitQuantizedBias(arr);
213  return;
214  }
215 
216  Shape shape(arr->GetShape());
217  float hw_scale = 1.0f;
218  if (shape.ndim() > 2) {
219  for (size_t i = 2; i < shape.ndim(); ++i) {
220  hw_scale *= shape[i];
221  }
222  }
223  float fan_in = shape[1] * hw_scale, fan_out = shape[0] * hw_scale;
224  float factor = 1.0f;
225  switch (factor_type) {
226  case avg:
227  factor = (fan_in + fan_out) / 2.0;
228  break;
229  case in:
230  factor = fan_in;
231  break;
232  case out:
233  factor = fan_out;
234  }
235  float scale = std::sqrt(magnitude / factor);
236  switch (rand_type) {
237  case uniform:
238  NDArray::SampleUniform(-scale, scale, arr);
239  break;
240  case gaussian:
241  NDArray::SampleGaussian(0, scale, arr);
242  break;
243  }
244  }
245 };
246 
247 class MSRAPrelu : public Xavier {
248  public:
249  explicit MSRAPrelu(FactorType factor_type = avg, float slope = 0.25f)
250  : Xavier(gaussian, factor_type, 2. / (1 + slope * slope)) {}
251 };
252 
253 } // namespace cpp
254 } // namespace mxnet
255 
256 #endif // MXNET_CPP_INITIALIZER_H_
static bool StringStartWith(const std::string &name, const std::string &check_str)
Definition: initializer.h:40
Uniform(float scale)
Definition: initializer.h:136
static bool StringEndWith(const std::string &name, const std::string &check_str)
Definition: initializer.h:45
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:177
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:159
Definition: initializer.h:174
Definition: initializer.h:190
namespace of mxnet
Definition: api_registry.h:33
virtual void InitQuantizedWeight(NDArray *arr)
Definition: initializer.h:102
std::vector< mx_uint > GetShape() const
Dynamic shape class that can hold a shape of arbitrary dimension.
Definition: shape.h:42
Definition: initializer.h:113
MSRAPrelu(FactorType factor_type=avg, float slope=0.25f)
Definition: initializer.h:249
Definition: initializer.h:134
FactorType
Definition: initializer.h:196
Xavier(RandType rand_type=gaussian, FactorType factor_type=avg, float magnitude=3)
Definition: initializer.h:202
float value
Definition: initializer.h:121
Uniform(float begin, float end)
Definition: initializer.h:138
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:140
virtual void InitOne(NDArray *arr)
Definition: initializer.h:97
virtual void InitGamma(NDArray *arr)
Definition: initializer.h:99
Definition: initializer.h:129
virtual void InitBias(NDArray *arr)
Definition: initializer.h:98
One()
Definition: initializer.h:131
virtual void InitZero(NDArray *arr)
Definition: initializer.h:96
NDArray interface.
Definition: ndarray.h:120
Definition: initializer.h:193
Definition: initializer.h:124
Definition: initializer.h:197
float sigma
Definition: initializer.h:171
virtual void InitBeta(NDArray *arr)
Definition: initializer.h:100
Definition: initializer.h:198
virtual void InitWeight(NDArray *arr)
Definition: initializer.h:101
Definition: initializer.h:155
Definition: initializer.h:247
float end
Definition: initializer.h:152
RandType
Definition: initializer.h:192
Zero()
Definition: initializer.h:126
static void SampleUniform(mx_float begin, mx_float end, NDArray *out)
Sample uniform distribution for each elements of out.
Bilinear()
Definition: initializer.h:176
virtual void InitBilinear(NDArray *arr)
Definition: initializer.h:84
Normal(float mu, float sigma)
Definition: initializer.h:157
Constant(float value)
Definition: initializer.h:115
virtual void InitDefault(NDArray *arr)
Definition: initializer.h:110
float magnitude
Definition: initializer.h:201
virtual void operator()(const std::string &name, NDArray *arr)
Definition: initializer.h:51
static void SampleGaussian(mx_float mu, mx_float sigma, NDArray *out)
Sample gaussian distribution for each elements of out.
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:206
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:117
Definition: initializer.h:38
virtual void InitQuantizedBias(NDArray *arr)
Definition: initializer.h:107