mxnet
metric.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_METRIC_H_
27 #define MXNET_CPP_METRIC_H_
28 
29 #include <cmath>
30 #include <string>
31 #include <vector>
32 #include <algorithm>
33 #include "mxnet-cpp/ndarray.h"
34 #include "dmlc/logging.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 
39 class EvalMetric {
40  public:
41  explicit EvalMetric(const std::string& name, int num = 0)
42  : name(name), num(num) {}
43  virtual void Update(NDArray labels, NDArray preds) = 0;
44  void Reset() {
45  num_inst = 0;
46  sum_metric = 0.0f;
47  }
48  float Get() { return sum_metric / num_inst; }
49  void GetNameValue();
50 
51  protected:
52  std::string name;
53  int num;
54  float sum_metric = 0.0f;
55  int num_inst = 0;
56 
57  static void CheckLabelShapes(NDArray labels, NDArray preds,
58  bool strict = false) {
59  if (strict) {
60  CHECK_EQ(Shape(labels.GetShape()), Shape(preds.GetShape()));
61  } else {
62  CHECK_EQ(labels.Size(), preds.Size());
63  }
64  }
65 };
66 
67 class Accuracy : public EvalMetric {
68  public:
69  Accuracy() : EvalMetric("accuracy") {}
70 
71  void Update(NDArray labels, NDArray preds) override {
72  CHECK_EQ(labels.GetShape().size(), 1);
73  mx_uint len = labels.GetShape()[0];
74  std::vector<mx_float> pred_data(len);
75  std::vector<mx_float> label_data(len);
76  preds.ArgmaxChannel().SyncCopyToCPU(&pred_data, len);
77  labels.SyncCopyToCPU(&label_data, len);
78  for (mx_uint i = 0; i < len; ++i) {
79  sum_metric += (pred_data[i] == label_data[i]) ? 1 : 0;
80  num_inst += 1;
81  }
82  }
83 };
84 
85 class LogLoss : public EvalMetric {
86  public:
87  LogLoss() : EvalMetric("logloss") {}
88 
89  void Update(NDArray labels, NDArray preds) override {
90  static const float epsilon = 1e-15;
91  mx_uint len = labels.GetShape()[0];
92  mx_uint m = preds.GetShape()[1];
93  std::vector<mx_float> pred_data(len * m);
94  std::vector<mx_float> label_data(len);
95  preds.SyncCopyToCPU(&pred_data, pred_data.size());
96  labels.SyncCopyToCPU(&label_data, len);
97  for (mx_uint i = 0; i < len; ++i) {
98  sum_metric +=
99  -std::log(std::max(pred_data[i * m + label_data[i]], epsilon));
100  num_inst += 1;
101  }
102  }
103 };
104 
105 class MAE : public EvalMetric {
106  public:
107  MAE() : EvalMetric("mae") {}
108 
109  void Update(NDArray labels, NDArray preds) override {
110  CheckLabelShapes(labels, preds);
111 
112  std::vector<mx_float> pred_data;
113  preds.SyncCopyToCPU(&pred_data);
114  std::vector<mx_float> label_data;
115  labels.SyncCopyToCPU(&label_data);
116 
117  size_t len = preds.Size();
118  mx_float sum = 0;
119  for (size_t i = 0; i < len; ++i) {
120  sum += std::abs(pred_data[i] - label_data[i]);
121  }
122  sum_metric += sum / len;
123  ++num_inst;
124  }
125 };
126 
127 class MSE : public EvalMetric {
128  public:
129  MSE() : EvalMetric("mse") {}
130 
131  void Update(NDArray labels, NDArray preds) override {
132  CheckLabelShapes(labels, preds);
133 
134  std::vector<mx_float> pred_data;
135  preds.SyncCopyToCPU(&pred_data);
136  std::vector<mx_float> label_data;
137  labels.SyncCopyToCPU(&label_data);
138 
139  size_t len = preds.Size();
140  mx_float sum = 0;
141  for (size_t i = 0; i < len; ++i) {
142  mx_float diff = pred_data[i] - label_data[i];
143  sum += diff * diff;
144  }
145  sum_metric += sum / len;
146  ++num_inst;
147  }
148 };
149 
150 class RMSE : public EvalMetric {
151  public:
152  RMSE() : EvalMetric("rmse") {}
153 
154  void Update(NDArray labels, NDArray preds) override {
155  CheckLabelShapes(labels, preds);
156 
157  std::vector<mx_float> pred_data;
158  preds.SyncCopyToCPU(&pred_data);
159  std::vector<mx_float> label_data;
160  labels.SyncCopyToCPU(&label_data);
161 
162  size_t len = preds.Size();
163  mx_float sum = 0;
164  for (size_t i = 0; i < len; ++i) {
165  mx_float diff = pred_data[i] - label_data[i];
166  sum += diff * diff;
167  }
168  sum_metric += std::sqrt(sum / len);
169  ++num_inst;
170  }
171 };
172 
173 class PSNR : public EvalMetric {
174  public:
175  PSNR() : EvalMetric("psnr") {
176  }
177 
178  void Update(NDArray labels, NDArray preds) override {
179  CheckLabelShapes(labels, preds);
180 
181  std::vector<mx_float> pred_data;
182  preds.SyncCopyToCPU(&pred_data);
183  std::vector<mx_float> label_data;
184  labels.SyncCopyToCPU(&label_data);
185 
186  size_t len = preds.Size();
187  mx_float sum = 0;
188  for (size_t i = 0; i < len; ++i) {
189  mx_float diff = pred_data[i] - label_data[i];
190  sum += diff * diff;
191  }
192  mx_float mse = sum / len;
193  if (mse > 0) {
194  sum_metric += 10 * std::log(255.0f / mse) / log10_;
195  } else {
196  sum_metric += 99.0f;
197  }
198  ++num_inst;
199  }
200 
201  private:
202  mx_float log10_ = std::log(10.0f);
203 };
204 
205 } // namespace cpp
206 } // namespace mxnet
207 
208 #endif // MXNET_CPP_METRIC_H_
209 
NDArray ArgmaxChannel()
Accuracy()
Definition: metric.h:69
float Get()
Definition: metric.h:48
namespace of mxnet
Definition: base.h:126
int num
Definition: metric.h:53
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:42
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:3007
Symbol max(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2282
EvalMetric(const std::string &name, int num=0)
Definition: metric.h:41
Definition: metric.h:67
Definition: metric.h:173
MAE()
Definition: metric.h:107
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:109
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:71
void SyncCopyToCPU(mx_float *data, size_t size=0)
Do a synchronous copy to a contiguous CPU memory region.
NDArray interface.
Definition: ndarray.h:120
float sum_metric
Definition: metric.h:54
virtual void Update(NDArray labels, NDArray preds)=0
Definition: metric.h:39
Symbol log(const std::string &symbol_name, Symbol data)
Definition: op.h:3074
size_t Size() const
LogLoss()
Definition: metric.h:87
Definition: metric.h:150
int num_inst
Definition: metric.h:55
RMSE()
Definition: metric.h:152
unsigned int mx_uint
manually define unsigned int
Definition: c_api.h:57
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:2801
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:89
Definition: metric.h:85
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:178
float mx_float
manually define float
Definition: c_api.h:59
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:131
std::string name
Definition: metric.h:52
static void CheckLabelShapes(NDArray labels, NDArray preds, bool strict=false)
Definition: metric.h:57
Definition: metric.h:127
std::vector< mx_uint > GetShape() const
void Reset()
Definition: metric.h:44
PSNR()
Definition: metric.h:175
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:154
Definition: metric.h:105
MSE()
Definition: metric.h:129
Symbol sum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2093