27 #ifndef MXNET_CPP_METRIC_H_ 28 #define MXNET_CPP_METRIC_H_ 35 #include "dmlc/logging.h" 59 bool strict =
false) {
63 CHECK_EQ(labels.
Size(), preds.
Size());
73 CHECK_EQ(labels.
GetShape().size(), 1);
75 std::vector<mx_float> pred_data(len);
76 std::vector<mx_float> label_data(len);
79 for (
mx_uint i = 0; i < len; ++i) {
80 sum_metric += (pred_data[i] == label_data[i]) ? 1 : 0;
91 static const float epsilon = 1e-15;
94 std::vector<mx_float> pred_data(len * m);
95 std::vector<mx_float> label_data(len);
98 for (
mx_uint i = 0; i < len; ++i) {
113 std::vector<mx_float> pred_data;
115 std::vector<mx_float> label_data;
118 size_t len = preds.
Size();
120 for (
size_t i = 0; i < len; ++i) {
121 sum +=
std::abs(pred_data[i] - label_data[i]);
135 std::vector<mx_float> pred_data;
137 std::vector<mx_float> label_data;
140 size_t len = preds.
Size();
142 for (
size_t i = 0; i < len; ++i) {
143 mx_float diff = pred_data[i] - label_data[i];
158 std::vector<mx_float> pred_data;
160 std::vector<mx_float> label_data;
163 size_t len = preds.
Size();
165 for (
size_t i = 0; i < len; ++i) {
166 mx_float diff = pred_data[i] - label_data[i];
182 std::vector<mx_float> pred_data;
184 std::vector<mx_float> label_data;
187 size_t len = preds.
Size();
189 for (
size_t i = 0; i < len; ++i) {
190 mx_float diff = pred_data[i] - label_data[i];
209 #endif // MXNET_CPP_METRIC_H_
Accuracy()
Definition: metric.h:70
float Get()
Definition: metric.h:49
namespace of mxnet
Definition: base.h:127
int num
Definition: metric.h:54
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:43
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1680
Symbol max(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2182
EvalMetric(const std::string &name, int num=0)
Definition: metric.h:42
MAE()
Definition: metric.h:108
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:110
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:72
void SyncCopyToCPU(mx_float *data, size_t size=0)
Do a synchronous copy to a contiguous CPU memory region.
NDArray interface.
Definition: ndarray.h:121
float sum_metric
Definition: metric.h:55
virtual void Update(NDArray labels, NDArray preds)=0
Symbol log(const std::string &symbol_name, Symbol data)
Definition: op.h:1801
LogLoss()
Definition: metric.h:88
int num_inst
Definition: metric.h:56
RMSE()
Definition: metric.h:153
unsigned int mx_uint
manually define unsigned int
Definition: c_api.h:58
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:1427
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:90
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:179
float mx_float
manually define float
Definition: c_api.h:60
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:132
std::string name
Definition: metric.h:53
static void CheckLabelShapes(NDArray labels, NDArray preds, bool strict=false)
Definition: metric.h:58
std::vector< mx_uint > GetShape() const
void Reset()
Definition: metric.h:45
PSNR()
Definition: metric.h:176
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:155
MSE()
Definition: metric.h:130
Symbol sum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:1993