mxnet
monitor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_MONITOR_H_
28 #define MXNET_CPP_MONITOR_H_
29 
30 #include <regex>
31 #include <tuple>
32 #include <vector>
33 #include <map>
34 #include <set>
35 #include <string>
36 #include <functional>
37 #include "mxnet-cpp/base.h"
38 #include "mxnet-cpp/ndarray.h"
39 #include "mxnet-cpp/executor.h"
40 
41 namespace mxnet {
42 namespace cpp {
43 
50 NDArray _default_monitor_func(const NDArray &x);
51 
55 class Monitor {
56  public:
57  typedef std::function<NDArray(const NDArray&)> StatFunc;
58  typedef std::tuple<int, std::string, NDArray> Stat;
59 
67  Monitor(int interval, std::regex pattern = std::regex(".*"),
68  StatFunc stat_func = _default_monitor_func);
69 
74  void install(Executor *exe);
75 
79  void tic();
80 
85  std::vector<Stat> toc();
86 
90  void toc_print();
91 
92  protected:
93  int interval;
94  std::regex pattern;
95  StatFunc stat_func;
96  std::vector<Executor*> exes;
97 
98  int step;
99  bool activated;
100  std::vector<Stat> stats;
101 
102  static void executor_callback(const char *name, NDArrayHandle ndarray, void *monitor_ptr);
103 };
104 
105 } // namespace cpp
106 } // namespace mxnet
107 #endif // MXNET_CPP_MONITOR_H_
std::regex pattern
Definition: monitor.h:94
void toc_print()
End collecting and print results.
int interval
Definition: monitor.h:93
std::function< NDArray(const NDArray &)> StatFunc
Definition: monitor.h:57
namespace of mxnet
Definition: base.h:127
Executor interface.
Definition: executor.h:45
Monitor(int interval, std::regex pattern=std::regex(".*"), StatFunc stat_func=_default_monitor_func)
Monitor constructor.
NDArray _default_monitor_func(const NDArray &x)
Default function for monitor that computes statistics of the input tensor, which is the mean absolute...
void install(Executor *exe)
install callback to executor. Supports installing to multiple executors.
bool activated
Definition: monitor.h:99
void tic()
Start collecting stats for current batch. Call before calling forward.
void * NDArrayHandle
handle to NDArray
Definition: c_api.h:65
int step
Definition: monitor.h:98
std::vector< Stat > stats
Definition: monitor.h:100
Monitor interface.
Definition: monitor.h:55
StatFunc stat_func
Definition: monitor.h:95
std::vector< Executor * > exes
Definition: monitor.h:96
std::tuple< int, std::string, NDArray > Stat
Definition: monitor.h:58
std::vector< Stat > toc()
End collecting for current batch and return results. Call after computation of current batch...
static void executor_callback(const char *name, NDArrayHandle ndarray, void *monitor_ptr)