mxnet
initializer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_INITIALIZER_H_
28 #define MXNET_CPP_INITIALIZER_H_
29 
#include <cmath>
#include <stdexcept>
#include <string>
#include <vector>
#include "mxnet-cpp/ndarray.h"
34 
35 namespace mxnet {
36 namespace cpp {
37 
38 class Initializer {
39  public:
40  static bool StringStartWith(const std::string& name,
41  const std::string& check_str) {
42  return (name.size() >= check_str.size() &&
43  name.substr(0, check_str.size()) == check_str);
44  }
45  static bool StringEndWith(const std::string& name,
46  const std::string& check_str) {
47  return (name.size() >= check_str.size() &&
48  name.substr(name.size() - check_str.size(), check_str.size()) ==
49  check_str);
50  }
51  virtual void operator()(const std::string& name, NDArray* arr) {
52  if (StringStartWith(name, "upsampling")) {
53  InitBilinear(arr);
54  } else if (StringEndWith(name, "bias")) {
55  InitBias(arr);
56  } else if (StringEndWith(name, "gamma")) {
57  InitGamma(arr);
58  } else if (StringEndWith(name, "beta")) {
59  InitBeta(arr);
60  } else if (StringEndWith(name, "weight")) {
61  InitWeight(arr);
62  } else if (StringEndWith(name, "moving_mean")) {
63  InitZero(arr);
64  } else if (StringEndWith(name, "moving_var")) {
65  InitOne(arr);
66  } else if (StringEndWith(name, "moving_inv_var")) {
67  InitZero(arr);
68  } else if (StringEndWith(name, "moving_avg")) {
69  InitZero(arr);
70  } else {
71  InitDefault(arr);
72  }
73  }
74 
75  protected:
76  virtual void InitBilinear(NDArray* arr) {
77  Shape shape(arr->GetShape());
78  std::vector<float> weight(shape.Size(), 0);
79  int f = std::ceil(shape[3] / 2.0);
80  float c = (2 * f - 1 - f % 2) / (2. * f);
81  for (size_t i = 0; i < shape.Size(); ++i) {
82  int x = i % shape[3];
83  int y = (i / shape[3]) % shape[2];
84  weight[i] = (1 - std::abs(x / f - c)) * (1 - std::abs(y / f - c));
85  }
86  (*arr).SyncCopyFromCPU(weight);
87  }
88  virtual void InitZero(NDArray* arr) { (*arr) = 0.0f; }
89  virtual void InitOne(NDArray* arr) { (*arr) = 1.0f; }
90  virtual void InitBias(NDArray* arr) { (*arr) = 0.0f; }
91  virtual void InitGamma(NDArray* arr) { (*arr) = 1.0f; }
92  virtual void InitBeta(NDArray* arr) { (*arr) = 0.0f; }
93  virtual void InitWeight(NDArray* arr) {}
94  virtual void InitDefault(NDArray* arr) {}
95 };
96 
97 class Constant : public Initializer {
98  public:
99  explicit Constant(float value)
100  : value(value) {}
101  void operator()(const std::string &name, NDArray *arr) override {
102  (*arr) = value;
103  }
104  protected:
105  float value;
106 };
107 
108 class Zero : public Constant {
109  public:
110  Zero(): Constant(0.0f) {}
111 };
112 
113 class One : public Constant {
114  public:
115  One(): Constant(1.0f) {}
116 };
117 
118 class Uniform : public Initializer {
119  public:
120  explicit Uniform(float scale)
121  : Uniform(-scale, scale) {}
122  Uniform(float begin, float end)
123  : begin(begin), end(end) {}
124  void operator()(const std::string &name, NDArray *arr) override {
125  NDArray::SampleUniform(begin, end, arr);
126  }
127  protected:
128  float begin, end;
129 };
130 
131 class Normal : public Initializer {
132  public:
133  Normal(float mu, float sigma)
134  : mu(mu), sigma(sigma) {}
135  void operator()(const std::string &name, NDArray *arr) override {
136  NDArray::SampleGaussian(mu, sigma, arr);
137  }
138  protected:
139  float mu, sigma;
140 };
141 
142 class Bilinear : public Initializer {
143  public:
144  Bilinear() {}
145  void operator()(const std::string &name, NDArray *arr) override {
146  InitBilinear(arr);
147  }
148 };
149 
150 class Xavier : public Initializer {
151  public:
152  enum RandType {
154  uniform
155  } rand_type;
156  enum FactorType {
158  in,
159  out
160  } factor_type;
161  float magnitude;
162  Xavier(RandType rand_type = gaussian, FactorType factor_type = avg,
163  float magnitude = 3)
164  : rand_type(rand_type), factor_type(factor_type), magnitude(magnitude) {}
165 
166  void operator()(const std::string &name, NDArray* arr) override {
167  Shape shape(arr->GetShape());
168  float hw_scale = 1.0f;
169  if (shape.ndim() > 2) {
170  for (size_t i = 2; i < shape.ndim(); ++i) {
171  hw_scale *= shape[i];
172  }
173  }
174  float fan_in = shape[1] * hw_scale, fan_out = shape[0] * hw_scale;
175  float factor = 1.0f;
176  switch (factor_type) {
177  case avg:
178  factor = (fan_in + fan_out) / 2.0;
179  break;
180  case in:
181  factor = fan_in;
182  break;
183  case out:
184  factor = fan_out;
185  }
186  float scale = std::sqrt(magnitude / factor);
187  switch (rand_type) {
188  case uniform:
189  NDArray::SampleUniform(-scale, scale, arr);
190  break;
191  case gaussian:
192  NDArray::SampleGaussian(0, scale, arr);
193  break;
194  }
195  }
196 };
197 
198 class MSRAPrelu : public Xavier {
199  public:
200  explicit MSRAPrelu(FactorType factor_type = avg, float slope = 0.25f)
201  : Xavier(gaussian, factor_type, 2. / (1 + slope * slope)) {}
202 };
203 
204 } // namespace cpp
205 } // namespace mxnet
206 
207 #endif // MXNET_CPP_INITIALIZER_H_
static bool StringStartWith(const std::string &name, const std::string &check_str)
Definition: initializer.h:40
Uniform(float scale)
Definition: initializer.h:120
static bool StringEndWith(const std::string &name, const std::string &check_str)
Definition: initializer.h:45
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:145
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:135
Definition: initializer.h:142
Definition: initializer.h:150
namespace of mxnet
Definition: base.h:127
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:43
Definition: initializer.h:97
MSRAPrelu(FactorType factor_type=avg, float slope=0.25f)
Definition: initializer.h:200
Definition: initializer.h:118
FactorType
Definition: initializer.h:156
Xavier(RandType rand_type=gaussian, FactorType factor_type=avg, float magnitude=3)
Definition: initializer.h:162
float value
Definition: initializer.h:105
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1751
Uniform(float begin, float end)
Definition: initializer.h:122
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:124
virtual void InitOne(NDArray *arr)
Definition: initializer.h:89
virtual void InitGamma(NDArray *arr)
Definition: initializer.h:91
Definition: initializer.h:113
virtual void InitBias(NDArray *arr)
Definition: initializer.h:90
One()
Definition: initializer.h:115
virtual void InitZero(NDArray *arr)
Definition: initializer.h:88
NDArray interface.
Definition: ndarray.h:121
Definition: initializer.h:153
Definition: initializer.h:108
Definition: initializer.h:157
float sigma
Definition: initializer.h:139
virtual void InitBeta(NDArray *arr)
Definition: initializer.h:92
Definition: initializer.h:158
virtual void InitWeight(NDArray *arr)
Definition: initializer.h:93
Definition: initializer.h:131
Definition: initializer.h:198
float end
Definition: initializer.h:128
RandType
Definition: initializer.h:152
Zero()
Definition: initializer.h:110
static void SampleUniform(mx_float begin, mx_float end, NDArray *out)
Sample from a uniform distribution for each element of out.
Bilinear()
Definition: initializer.h:144
virtual void InitBilinear(NDArray *arr)
Definition: initializer.h:76
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:1498
Normal(float mu, float sigma)
Definition: initializer.h:133
Symbol ceil(const std::string &symbol_name, Symbol data)
Definition: op.h:1608
Constant(float value)
Definition: initializer.h:99
virtual void InitDefault(NDArray *arr)
Definition: initializer.h:94
float magnitude
Definition: initializer.h:161
std::vector< mx_uint > GetShape() const
virtual void operator()(const std::string &name, NDArray *arr)
Definition: initializer.h:51
static void SampleGaussian(mx_float mu, mx_float sigma, NDArray *out)
Sample from a Gaussian distribution for each element of out.
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:166
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:101
Definition: initializer.h:38