25 #ifndef MXNET_ENGINE_H_    26 #define MXNET_ENGINE_H_    48   virtual ~Var() = 
default;
    77  inline void operator()(
const dmlc::Error* error = 
nullptr)
 const {
    78     (*callback_)(engine_, param_, error);
    83   friend class ::mxnet::Engine;
    85   void (*callback_)(
Engine *, 
void *, 
const dmlc::Error *);
   120   typedef std::function<void(RunContext)> 
SyncFn;
   122   typedef std::function<void(RunContext, CallbackOnComplete)> 
AsyncFn;
   134   virtual void NotifyShutdown() = 0;
   139     LOG(FATAL) << 
"Engine cannot be stopped";
   145     LOG(FATAL) << 
"Engine cannot be restarted";
   153   virtual VarHandle NewVariable() = 0;
   166   virtual OprHandle NewOperator(AsyncFn fn,
   167                                 std::vector<VarHandle> 
const& const_vars,
   168                                 std::vector<VarHandle> 
const& mutable_vars,
   170                                 const char* opr_name = 
nullptr,
   171                                 bool wait = 
false) = 0;
   179   virtual void DeleteOperator(OprHandle op) = 0;
   187   virtual void Push(OprHandle op, 
Context exec_ctx, 
int priority = 0, 
bool profiling = 
false) = 0;
   202   virtual void PushAsync(AsyncFn exec_fun, 
Context exec_ctx,
   203                          std::vector<VarHandle> 
const& const_vars,
   204                          std::vector<VarHandle> 
const& mutable_vars,
   207                          const char* opr_name = 
nullptr,
   208                          bool wait = 
false) = 0;
   220   virtual void DeleteVariable(SyncFn delete_fn,
   228   virtual void WaitForVar(VarHandle var) = 0;
   232   virtual void WaitForAll() = 0;
   247   static std::shared_ptr<Engine> _GetSharedRef();
   261                         std::vector<VarHandle> 
const& const_vars,
   262                         std::vector<VarHandle> 
const& mutable_vars,
   265                         const char* opr_name = 
nullptr) {
   266     this->PushAsync([exec_fn](
RunContext ctx, CallbackOnComplete on_complete) {
   269       }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
   278       void (*callback)(
Engine *, 
void *, 
const dmlc::Error *), 
void *param) {
   279     CallbackOnComplete ret;
   280     ret.callback_ = callback;
   288                                    std::vector<engine::VarHandle> *write_vars) {
   289     std::sort(write_vars->begin(), write_vars->end());
   290     write_vars->resize(std::unique(write_vars->begin(), write_vars->end()) -
   291                       write_vars->begin());
   292     std::sort(read_vars->begin(), read_vars->end());
   293     read_vars->resize(std::unique(read_vars->begin(), read_vars->end()) -
   295     auto wit = write_vars->begin();
   296     auto rtop = read_vars->begin();
   297     for (
auto rit = read_vars->begin(); rit != read_vars->end(); ++rit) {
   298       while (wit != write_vars->end() && *wit < *rit) ++wit;
   299       if (wit == write_vars->end() || *wit != *rit) {
   304     read_vars->resize(rtop - read_vars->begin());
   315 #endif  // DMLC_USE_CXX11   317 #endif  // MXNET_ENGINE_H_ void DeduplicateVarHandle(std::vector< engine::VarHandle > *read_vars, std::vector< engine::VarHandle > *write_vars)
Definition: engine.h:287
 
FnProperty
Function property, used to hint what action is pushed to engine. 
Definition: engine.h:95
 
virtual void Stop()
Stop all workers in the engine. 
Definition: engine.h:138
 
std::function< void(RunContext)> SyncFn
Synchronous operation to pass to engine. 
Definition: engine.h:120
 
std::function< void(RunContext, CallbackOnComplete)> AsyncFn
Asynchronous operation to pass to engine. 
Definition: engine.h:122
 
size_t version_
version number of the var. Every time the object it is associated with is modified, the version number is incremented by 1. 
Definition: engine.h:60
 
namespace of mxnet 
Definition: base.h:118
 
virtual int bulk_size() const 
query current limit for bulk size 
Definition: engine.h:307
 
Asynchronous function call. 
 
execution time context. The information needed in runtime for actual execution. 
Definition: base.h:257
 
base class of engine variables. 
Definition: engine.h:44
 
virtual ~Engine() noexcept(false)
virtual destructor 
Definition: engine.h:234
 
CallbackOnComplete CreateCallback(void(*callback)(Engine *, void *, const dmlc::Error *), void *param)
factory function to create OnComplete callback. 
Definition: engine.h:277
 
Prioritized sync operation on GPU. 
 
void operator()(const dmlc::Error *error=nullptr) const 
invoke the callback 
Definition: engine.h:77
 
virtual void Start()
Restart all workers in the engine. 
Definition: engine.h:144
 
Copy operation from GPU to other devices. 
 
virtual int set_bulk_size(int)
set maximum limit for bulk size 
Definition: engine.h:311
 
engine::OprHandle OprHandle
Operator pointer. 
Definition: engine.h:126
 
engine::VarHandle VarHandle
Variable pointer. 
Definition: engine.h:124
 
Var * VarHandle
Variable pointer type, usually held by the user and used to specify dependencies. 
Definition: engine.h:64
 
Prioritized sync operation on CPU. 
 
engine::CallbackOnComplete CallbackOnComplete
callback on complete 
Definition: engine.h:118
 
virtual void PushSync(SyncFn exec_fn, Context exec_ctx, std::vector< VarHandle > const &const_vars, std::vector< VarHandle > const &mutable_vars, FnProperty prop=FnProperty::kNormal, int priority=0, const char *opr_name=nullptr)
Push a synchronous operation to the engine. 
Definition: engine.h:260
 
Dependency engine that schedules operations. 
Definition: engine.h:115
 
Symbol sort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:3107
 
OnComplete Callback to the engine, called by AsyncFn when action completes. 
Definition: engine.h:73
 
T * Cast()
cast variable to derived type T 
 
Context information about the execution environment. 
Definition: base.h:133
 
#define MXNET_API
Define compatible keywords in g++. Used to support g++-4.6 and g++-4.7. 
Definition: base.h:92
 
Copy operation from CPU to other devices. 
 
Opr * OprHandle
Operator pointer type, usually held by the user. 
Definition: engine.h:68
 
virtual size_t version()
Definition: engine.h:45