25 #ifndef MXNET_ENGINE_H_ 26 #define MXNET_ENGINE_H_ 48 virtual ~Var() =
default;
77 inline void operator()(
const dmlc::Error* error =
nullptr)
const {
78 (*callback_)(engine_, param_, error);
83 friend class ::mxnet::Engine;
85 void (*callback_)(
Engine *,
void *,
const dmlc::Error *);
122 typedef std::function<void(RunContext)>
SyncFn;
124 typedef std::function<void(RunContext, CallbackOnComplete)>
AsyncFn;
136 virtual void NotifyShutdown() = 0;
141 LOG(FATAL) <<
"Engine cannot be stopped";
147 LOG(FATAL) <<
"Engine cannot be restarted";
155 virtual VarHandle NewVariable() = 0;
168 virtual OprHandle NewOperator(AsyncFn fn,
169 std::vector<VarHandle>
const& const_vars,
170 std::vector<VarHandle>
const& mutable_vars,
172 const char* opr_name =
nullptr,
173 bool wait =
false) = 0;
181 virtual void DeleteOperator(OprHandle op) = 0;
189 virtual void Push(OprHandle op,
Context exec_ctx,
int priority = 0,
bool profiling =
false) = 0;
204 virtual void PushAsync(AsyncFn exec_fun,
Context exec_ctx,
205 std::vector<VarHandle>
const& const_vars,
206 std::vector<VarHandle>
const& mutable_vars,
209 const char* opr_name =
nullptr,
210 bool wait =
false) = 0;
222 virtual void DeleteVariable(SyncFn delete_fn,
230 virtual void WaitForVar(VarHandle var) = 0;
234 virtual void WaitForAll() = 0;
236 virtual void Throw(VarHandle var) = 0;
251 static std::shared_ptr<Engine> _GetSharedRef();
265 std::vector<VarHandle>
const& const_vars,
266 std::vector<VarHandle>
const& mutable_vars,
269 const char* opr_name =
nullptr) {
270 this->PushAsync([exec_fn](
RunContext ctx, CallbackOnComplete on_complete) {
273 }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
282 void (*callback)(
Engine *,
void *,
const dmlc::Error *),
void *param) {
283 CallbackOnComplete ret;
284 ret.callback_ = callback;
292 std::vector<engine::VarHandle> *write_vars) {
293 std::sort(write_vars->begin(), write_vars->end());
294 write_vars->resize(std::unique(write_vars->begin(), write_vars->end()) -
295 write_vars->begin());
296 std::sort(read_vars->begin(), read_vars->end());
297 read_vars->resize(std::unique(read_vars->begin(), read_vars->end()) -
299 auto wit = write_vars->begin();
300 auto rtop = read_vars->begin();
301 for (
auto rit = read_vars->begin(); rit != read_vars->end(); ++rit) {
302 while (wit != write_vars->end() && *wit < *rit) ++wit;
303 if (wit == write_vars->end() || *wit != *rit) {
308 read_vars->resize(rtop - read_vars->begin());
319 #endif // DMLC_USE_CXX11 321 #endif // MXNET_ENGINE_H_ void DeduplicateVarHandle(std::vector< engine::VarHandle > *read_vars, std::vector< engine::VarHandle > *write_vars)
Definition: engine.h:291
FnProperty
Function property, used to hint what kind of action is pushed to the engine.
Definition: engine.h:95
virtual void Stop()
Stop all workers in the engine.
Definition: engine.h:140
std::function< void(RunContext)> SyncFn
Synchronous operation to pass to engine.
Definition: engine.h:122
std::function< void(RunContext, CallbackOnComplete)> AsyncFn
Asynchronous operation to pass to engine.
Definition: engine.h:124
Operation that must not be skipped, even when it has an associated exception.
size_t version_
Version number of the variable. Every time the object it is associated with is modified, the version number is incremented by 1.
Definition: engine.h:60
namespace of mxnet
Definition: base.h:89
virtual int bulk_size() const
Query the current limit for bulk size.
Definition: engine.h:311
Asynchronous function call.
Execution-time context: the information needed at runtime for actual execution.
Definition: base.h:337
base class of engine variables.
Definition: engine.h:44
virtual ~Engine() noexcept(false)
virtual destructor
Definition: engine.h:238
CallbackOnComplete CreateCallback(void(*callback)(Engine *, void *, const dmlc::Error *), void *param)
Factory function to create an OnComplete callback.
Definition: engine.h:281
Prioritized sync operation on GPU.
void operator()(const dmlc::Error *error=nullptr) const
Invoke the callback.
Definition: engine.h:77
virtual void Start()
Restart all workers in the engine.
Definition: engine.h:146
Copy operation from GPU to other devices.
virtual int set_bulk_size(int)
Set the maximum limit for bulk size.
Definition: engine.h:315
engine::OprHandle OprHandle
Operator pointer.
Definition: engine.h:128
engine::VarHandle VarHandle
Variable pointer.
Definition: engine.h:126
Var * VarHandle
Variable pointer type, usually held by the user and used to specify dependencies.
Definition: engine.h:64
Prioritized sync operation on CPU.
engine::CallbackOnComplete CallbackOnComplete
Callback invoked on completion.
Definition: engine.h:120
virtual void PushSync(SyncFn exec_fn, Context exec_ctx, std::vector< VarHandle > const &const_vars, std::vector< VarHandle > const &mutable_vars, FnProperty prop=FnProperty::kNormal, int priority=0, const char *opr_name=nullptr)
Push a synchronous operation to the engine.
Definition: engine.h:264
Dependency engine that schedules operations.
Definition: engine.h:117
Symbol sort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Returns a sorted copy of an input array along the given axis.
Definition: op.h:2763
OnComplete callback to the engine, called by an AsyncFn when the action completes.
Definition: engine.h:73
T * Cast()
Cast the variable to the derived type T.
Context information about the execution environment.
Definition: base.h:102
#define MXNET_API
Define compatible keywords for g++; used to support g++-4.6 and g++-4.7.
Definition: base.h:63
Copy operation from CPU to other devices.
Opr * OprHandle
Operator pointer type, usually held by the user.
Definition: engine.h:68
virtual size_t version()
Definition: engine.h:45