mxnet
metric.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_METRIC_H_
28 #define MXNET_CPP_METRIC_H_
29 
30 #include <cmath>
31 #include <string>
32 #include <vector>
33 #include <algorithm>
34 #include "mxnet-cpp/ndarray.h"
35 #include "dmlc/logging.h"
36 
37 namespace mxnet {
38 namespace cpp {
39 
40 class EvalMetric {
41  public:
42  explicit EvalMetric(const std::string& name, int num = 0)
43  : name(name), num(num) {}
44  virtual void Update(NDArray labels, NDArray preds) = 0;
45  void Reset() {
46  num_inst = 0;
47  sum_metric = 0.0f;
48  }
49  float Get() { return sum_metric / num_inst; }
50  void GetNameValue();
51 
52  protected:
53  std::string name;
54  int num;
55  float sum_metric = 0.0f;
56  int num_inst = 0;
57 
58  static void CheckLabelShapes(NDArray labels, NDArray preds,
59  bool strict = false) {
60  if (strict) {
61  CHECK_EQ(Shape(labels.GetShape()), Shape(preds.GetShape()));
62  } else {
63  CHECK_EQ(labels.Size(), preds.Size());
64  }
65  }
66 };
67 
68 class Accuracy : public EvalMetric {
69  public:
70  Accuracy() : EvalMetric("accuracy") {}
71 
72  void Update(NDArray labels, NDArray preds) override {
73  CHECK_EQ(labels.GetShape().size(), 1);
74  mx_uint len = labels.GetShape()[0];
75  std::vector<mx_float> pred_data(len);
76  std::vector<mx_float> label_data(len);
77  preds.ArgmaxChannel().SyncCopyToCPU(&pred_data, len);
78  labels.SyncCopyToCPU(&label_data, len);
79  for (mx_uint i = 0; i < len; ++i) {
80  sum_metric += (pred_data[i] == label_data[i]) ? 1 : 0;
81  num_inst += 1;
82  }
83  }
84 };
85 
86 class LogLoss : public EvalMetric {
87  public:
88  LogLoss() : EvalMetric("logloss") {}
89 
90  void Update(NDArray labels, NDArray preds) override {
91  static const float epsilon = 1e-15;
92  mx_uint len = labels.GetShape()[0];
93  mx_uint m = preds.GetShape()[1];
94  std::vector<mx_float> pred_data(len * m);
95  std::vector<mx_float> label_data(len);
96  preds.SyncCopyToCPU(&pred_data, pred_data.size());
97  labels.SyncCopyToCPU(&label_data, len);
98  for (mx_uint i = 0; i < len; ++i) {
99  sum_metric +=
100  -std::log(std::max(pred_data[i * m + label_data[i]], epsilon));
101  num_inst += 1;
102  }
103  }
104 };
105 
106 class MAE : public EvalMetric {
107  public:
108  MAE() : EvalMetric("mae") {}
109 
110  void Update(NDArray labels, NDArray preds) override {
111  CheckLabelShapes(labels, preds);
112 
113  std::vector<mx_float> pred_data;
114  preds.SyncCopyToCPU(&pred_data);
115  std::vector<mx_float> label_data;
116  labels.SyncCopyToCPU(&label_data);
117 
118  size_t len = preds.Size();
119  mx_float sum = 0;
120  for (size_t i = 0; i < len; ++i) {
121  sum += std::abs(pred_data[i] - label_data[i]);
122  }
123  sum_metric += sum / len;
124  ++num_inst;
125  }
126 };
127 
128 class MSE : public EvalMetric {
129  public:
130  MSE() : EvalMetric("mse") {}
131 
132  void Update(NDArray labels, NDArray preds) override {
133  CheckLabelShapes(labels, preds);
134 
135  std::vector<mx_float> pred_data;
136  preds.SyncCopyToCPU(&pred_data);
137  std::vector<mx_float> label_data;
138  labels.SyncCopyToCPU(&label_data);
139 
140  size_t len = preds.Size();
141  mx_float sum = 0;
142  for (size_t i = 0; i < len; ++i) {
143  mx_float diff = pred_data[i] - label_data[i];
144  sum += diff * diff;
145  }
146  sum_metric += sum / len;
147  ++num_inst;
148  }
149 };
150 
151 class RMSE : public EvalMetric {
152  public:
153  RMSE() : EvalMetric("rmse") {}
154 
155  void Update(NDArray labels, NDArray preds) override {
156  CheckLabelShapes(labels, preds);
157 
158  std::vector<mx_float> pred_data;
159  preds.SyncCopyToCPU(&pred_data);
160  std::vector<mx_float> label_data;
161  labels.SyncCopyToCPU(&label_data);
162 
163  size_t len = preds.Size();
164  mx_float sum = 0;
165  for (size_t i = 0; i < len; ++i) {
166  mx_float diff = pred_data[i] - label_data[i];
167  sum += diff * diff;
168  }
169  sum_metric += std::sqrt(sum / len);
170  ++num_inst;
171  }
172 };
173 
174 class PSNR : public EvalMetric {
175  public:
176  PSNR() : EvalMetric("psnr") {
177  }
178 
179  void Update(NDArray labels, NDArray preds) override {
180  CheckLabelShapes(labels, preds);
181 
182  std::vector<mx_float> pred_data;
183  preds.SyncCopyToCPU(&pred_data);
184  std::vector<mx_float> label_data;
185  labels.SyncCopyToCPU(&label_data);
186 
187  size_t len = preds.Size();
188  mx_float sum = 0;
189  for (size_t i = 0; i < len; ++i) {
190  mx_float diff = pred_data[i] - label_data[i];
191  sum += diff * diff;
192  }
193  mx_float mse = sum / len;
194  if (mse > 0) {
195  sum_metric += 10 * std::log(255.0f / mse) / log10_;
196  } else {
197  sum_metric += 99.0f;
198  }
199  ++num_inst;
200  }
201 
202  private:
203  mx_float log10_ = std::log(10.0f);
204 };
205 
206 } // namespace cpp
207 } // namespace mxnet
208 
209 #endif // MXNET_CPP_METRIC_H_
210 
NDArray ArgmaxChannel()
Accuracy()
Definition: metric.h:70
float Get()
Definition: metric.h:49
namespace of mxnet
Definition: base.h:118
int num
Definition: metric.h:54
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:43
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:2207
EvalMetric(const std::string &name, int num=0)
Definition: metric.h:42
Definition: metric.h:68
Definition: metric.h:174
MAE()
Definition: metric.h:108
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:110
Symbol max(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2756
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:72
void SyncCopyToCPU(mx_float *data, size_t size=0)
Do a synchronous copy to a contiguous CPU memory region.
NDArray interface.
Definition: ndarray.h:121
float sum_metric
Definition: metric.h:55
virtual void Update(NDArray labels, NDArray preds)=0
Definition: metric.h:40
Symbol log(const std::string &symbol_name, Symbol data)
Definition: op.h:2355
size_t Size() const
LogLoss()
Definition: metric.h:88
Definition: metric.h:151
int num_inst
Definition: metric.h:56
RMSE()
Definition: metric.h:153
unsigned int mx_uint
manually define unsigned int
Definition: c_api.h:58
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:1946
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:90
Definition: metric.h:86
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:179
float mx_float
manually define float
Definition: c_api.h:60
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:132
std::string name
Definition: metric.h:53
static void CheckLabelShapes(NDArray labels, NDArray preds, bool strict=false)
Definition: metric.h:58
Definition: metric.h:128
std::vector< mx_uint > GetShape() const
void Reset()
Definition: metric.h:45
PSNR()
Definition: metric.h:176
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:155
Definition: metric.h:106
Symbol sum(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2567
MSE()
Definition: metric.h:130