26 #ifndef MXNET_CPP_METRIC_H_ 27 #define MXNET_CPP_METRIC_H_ 34 #include "dmlc/logging.h" 58 bool strict =
false) {
62 CHECK_EQ(labels.
Size(), preds.
Size());
72 CHECK_EQ(labels.
GetShape().size(), 1);
74 std::vector<mx_float> pred_data(len);
75 std::vector<mx_float> label_data(len);
78 for (
mx_uint i = 0; i < len; ++i) {
79 sum_metric += (pred_data[i] == label_data[i]) ? 1 : 0;
90 static const float epsilon = 1e-15;
93 std::vector<mx_float> pred_data(len * m);
94 std::vector<mx_float> label_data(len);
97 for (
mx_uint i = 0; i < len; ++i) {
112 std::vector<mx_float> pred_data;
114 std::vector<mx_float> label_data;
117 size_t len = preds.
Size();
119 for (
size_t i = 0; i < len; ++i) {
120 sum +=
std::abs(pred_data[i] - label_data[i]);
134 std::vector<mx_float> pred_data;
136 std::vector<mx_float> label_data;
139 size_t len = preds.
Size();
141 for (
size_t i = 0; i < len; ++i) {
142 mx_float diff = pred_data[i] - label_data[i];
157 std::vector<mx_float> pred_data;
159 std::vector<mx_float> label_data;
162 size_t len = preds.
Size();
164 for (
size_t i = 0; i < len; ++i) {
165 mx_float diff = pred_data[i] - label_data[i];
181 std::vector<mx_float> pred_data;
183 std::vector<mx_float> label_data;
186 size_t len = preds.
Size();
188 for (
size_t i = 0; i < len; ++i) {
189 mx_float diff = pred_data[i] - label_data[i];
208 #endif // MXNET_CPP_METRIC_H_
Accuracy()
Definition: metric.h:69
float Get()
Definition: metric.h:48
namespace of mxnet
Definition: base.h:126
int num
Definition: metric.h:53
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:42
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:3007
Symbol max(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2282
EvalMetric(const std::string &name, int num=0)
Definition: metric.h:41
MAE()
Definition: metric.h:107
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:109
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:71
void SyncCopyToCPU(mx_float *data, size_t size=0)
Do a synchronous copy to a contiguous CPU memory region.
NDArray interface.
Definition: ndarray.h:120
float sum_metric
Definition: metric.h:54
virtual void Update(NDArray labels, NDArray preds)=0
Symbol log(const std::string &symbol_name, Symbol data)
Definition: op.h:3074
LogLoss()
Definition: metric.h:87
int num_inst
Definition: metric.h:55
RMSE()
Definition: metric.h:152
unsigned int mx_uint
manually define unsigned int
Definition: c_api.h:57
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:2801
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:89
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:178
float mx_float
manually define float
Definition: c_api.h:59
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:131
std::string name
Definition: metric.h:52
static void CheckLabelShapes(NDArray labels, NDArray preds, bool strict=false)
Definition: metric.h:57
std::vector< mx_uint > GetShape() const
void Reset()
Definition: metric.h:44
PSNR()
Definition: metric.h:175
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:154
MSE()
Definition: metric.h:129
Symbol sum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2093