mxnet
engine.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_ENGINE_H_
25 #define MXNET_ENGINE_H_
26 
27 #if DMLC_USE_CXX11
28 #include <algorithm>
29 #include <memory>
30 #include <functional>
31 #endif
32 #include <vector>
33 #include "./base.h"
34 
35 namespace mxnet {
36 
37 // forward declare engine
38 class Engine;
39 
41 namespace engine {
/*!
 * \brief Base class of engine variables.
 *
 * NOTE(review): concrete engines are expected to derive their own variable
 * types from Var (see Cast); those derived types are not visible here.
 */
struct Var {
  /*!
   * \brief Version of the variable.
   * \return the current value of version_.
   */
  virtual size_t version() { return this->version_; }

  /*! \brief Virtual destructor, so deleting through a Var* reaches the derived destructor. */
  virtual ~Var() = default;

  /*!
   * \brief Cast the variable to the derived type T.
   * \tparam T the derived variable type.
   * \return this variable viewed as a T*; the definition lives outside this listing.
   */
  template <typename T>
  inline T* Cast();

  /*!
   * \brief Version number of the variable. Every time the object it is
   *        associated with is modified, the version number is incremented by 1.
   */
  size_t version_{0};
};  // struct Var
61 
63 struct Opr;
65 typedef Var* VarHandle;
67 typedef Opr* OprHandle;
73  public:
74  // use implicit copy and assign
76  inline void operator()(const dmlc::Error* error = nullptr) const {
77  (*callback_)(engine_, param_, error);
78  }
79 
80  private:
82  friend class ::mxnet::Engine;
84  void (*callback_)(Engine *, void *, const dmlc::Error *);
86  Engine* engine_;
88  void* param_;
89 };
90 } // namespace engine
91 
92 #if DMLC_USE_CXX11
93 
/*! \brief Function property, used to hint what action is pushed to engine. */
enum class FnProperty {
  /*! \brief Normal operation */
  kNormal,
  /*! \brief Copy operation from GPU to other devices */
  kCopyFromGPU,
  /*! \brief Copy operation from CPU to other devices */
  kCopyToGPU,
  /*! \brief Prioritized sync operation on CPU */
  kCPUPrioritized,
  /*! \brief Asynchronous function call */
  kAsync,
  /*! \brief Delete variable call */
  kDeleteVar,
  /*! \brief Prioritized sync operation on GPU */
  kGPUPrioritized,
  /*! \brief Operation not to be skipped even with associated exception */
  kNoSkip
};  // enum class FnProperty
112 
117  public:
121  typedef std::function<void(RunContext)> SyncFn;
123  typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn;
135  virtual void NotifyShutdown() = 0;
139  virtual void Stop() {
140  LOG(FATAL) << "Engine cannot be stopped";
141  }
145  virtual void Start() {
146  LOG(FATAL) << "Engine cannot be restarted";
147  }
154  virtual VarHandle NewVariable() = 0;
167  virtual OprHandle NewOperator(AsyncFn fn,
168  std::vector<VarHandle> const& const_vars,
169  std::vector<VarHandle> const& mutable_vars,
171  const char* opr_name = nullptr,
172  bool wait = false) = 0;
180  virtual void DeleteOperator(OprHandle op) = 0;
188  virtual void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) = 0;
203  virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
204  std::vector<VarHandle> const& const_vars,
205  std::vector<VarHandle> const& mutable_vars,
207  int priority = 0,
208  const char* opr_name = nullptr,
209  bool wait = false) = 0;
221  virtual void DeleteVariable(SyncFn delete_fn,
222  Context exec_ctx,
223  VarHandle var) = 0;
229  virtual void WaitForVar(VarHandle var) = 0;
233  virtual void WaitForAll() = 0;
235  virtual void Throw(VarHandle var) = 0;
237  virtual ~Engine() noexcept(false) {}
241  static Engine* Get();
250  static std::shared_ptr<Engine> _GetSharedRef();
263  virtual void PushSync(SyncFn exec_fn, Context exec_ctx,
264  std::vector<VarHandle> const& const_vars,
265  std::vector<VarHandle> const& mutable_vars,
267  int priority = 0,
268  const char* opr_name = nullptr) {
269  this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
270  exec_fn(ctx);
271  on_complete();
272  }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
273  }
274 
280  inline CallbackOnComplete CreateCallback(
281  void (*callback)(Engine *, void *, const dmlc::Error *), void *param) {
282  CallbackOnComplete ret;
283  ret.callback_ = callback;
284  ret.engine_ = this;
285  ret.param_ = param;
286  return ret;
287  }
288  // For each var vector, sort it and remove the duplicated vars.
289  // Also remove vars from read_vars if it also appears in write_vars
290  inline void DeduplicateVarHandle(std::vector<engine::VarHandle> *read_vars,
291  std::vector<engine::VarHandle> *write_vars) {
292  std::sort(write_vars->begin(), write_vars->end());
293  write_vars->resize(std::unique(write_vars->begin(), write_vars->end()) -
294  write_vars->begin());
295  std::sort(read_vars->begin(), read_vars->end());
296  read_vars->resize(std::unique(read_vars->begin(), read_vars->end()) -
297  read_vars->begin());
298  auto wit = write_vars->begin();
299  auto rtop = read_vars->begin();
300  for (auto rit = read_vars->begin(); rit != read_vars->end(); ++rit) {
301  while (wit != write_vars->end() && *wit < *rit) ++wit;
302  if (wit == write_vars->end() || *wit != *rit) {
303  *rtop = *rit;
304  ++rtop;
305  }
306  }
307  read_vars->resize(rtop - read_vars->begin());
308  }
310  virtual int bulk_size() const {
311  return 0;
312  }
314  virtual int set_bulk_size(int) {
315  return 0;
316  }
317 }; // class Engine
318 #endif // DMLC_USE_CXX11
319 } // namespace mxnet
320 #endif // MXNET_ENGINE_H_
void DeduplicateVarHandle(std::vector< engine::VarHandle > *read_vars, std::vector< engine::VarHandle > *write_vars)
Definition: engine.h:290
FnProperty
Function property, used to hint what action is pushed to engine.
Definition: engine.h:94
virtual void Stop()
Stop all workers in the engine.
Definition: engine.h:139
std::function< void(RunContext)> SyncFn
Synchronous operation to pass to engine.
Definition: engine.h:121
virtual int bulk_size() const
query current limit for bulk size
Definition: engine.h:310
std::function< void(RunContext, CallbackOnComplete)> AsyncFn
Asynchronous operation to pass to engine.
Definition: engine.h:123
Operation not to be skipped even with associated exception.
size_t version_
version number of the var. Every time the object it is associated with is modified, the version number is incremented by 1.
Definition: engine.h:59
namespace of mxnet
Definition: api_registry.h:33
Asynchronous function call.
Execution-time context: the information needed at runtime for actual execution.
Definition: base.h:349
base class of engine variables.
Definition: engine.h:43
Delete variable call.
Normal operation.
virtual ~Var()=default
virtual ~Engine() noexcept(false)
virtual destructor
Definition: engine.h:237
CallbackOnComplete CreateCallback(void(*callback)(Engine *, void *, const dmlc::Error *), void *param)
factory function to create OnComplete callback.
Definition: engine.h:280
Prioritized sync operation on GPU.
virtual void Start()
Restart all workers in the engine.
Definition: engine.h:145
Copy operation from GPU to other devices.
virtual int set_bulk_size(int)
set maximum limit for bulk size
Definition: engine.h:314
engine::OprHandle OprHandle
Operator pointer.
Definition: engine.h:127
engine::VarHandle VarHandle
Variable pointer.
Definition: engine.h:125
Var * VarHandle
Variable pointer type, usually held by the user and used to specify dependencies.
Definition: engine.h:65
Prioritized sync operation on CPU.
engine::CallbackOnComplete CallbackOnComplete
callback on complete
Definition: engine.h:119
virtual void PushSync(SyncFn exec_fn, Context exec_ctx, std::vector< VarHandle > const &const_vars, std::vector< VarHandle > const &mutable_vars, FnProperty prop=FnProperty::kNormal, int priority=0, const char *opr_name=nullptr)
Push a synchronous operation to the engine.
Definition: engine.h:263
Dependency engine that schedules operations.
Definition: engine.h:116
void operator()(const dmlc::Error *error=nullptr) const
invoke the callback
Definition: engine.h:76
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:72
T * Cast()
cast variable to derived type T
Context information about the execution environment.
Definition: base.h:101
#define MXNET_API
define compatible keywords in g++. Used to support g++-4.6 and g++-4.7.
Definition: base.h:62
Copy operation from CPU to other devices.
Opr * OprHandle
Operator pointer type, usually held by the user.
Definition: engine.h:67
virtual size_t version()
Definition: engine.h:44