#include <imperative.h>
mxnet::Imperative::CachedOp::CachedOp(const nnvm::Symbol &sym)  [explicit]
void mxnet::Imperative::CachedOp::Backward(const bool retain_graph,
                                           const OpStatePtr &state,
                                           const std::vector< NDArray * > &inputs,
                                           const std::vector< OpReqType > &reqs,
                                           const std::vector< NDArray * > &outputs)
OpStatePtr mxnet::Imperative::CachedOp::Forward(const std::vector< NDArray * > &inputs,
                                                const std::vector< NDArray * > &outputs)
nnvm::Graph mxnet::Imperative::CachedOp::GetBackwardGraph(const OpStatePtr &state,
                                                          const std::vector< OpReqType > &reqs,
                                                          const std::vector< NDArray * > &inputs)
nnvm::Graph mxnet::Imperative::CachedOp::GetForwardGraph(const bool recording,
                                                         const std::vector< NDArray * > &inputs)
std::vector<nnvm::NodeEntry> mxnet::Imperative::CachedOp::Gradient(const nnvm::NodePtr &node,
                                                                   const std::vector< nnvm::NodeEntry > &ograds)
const std::unordered_set<uint32_t>& mxnet::Imperative::CachedOp::mutable_input_nodes()  [inline]
uint32_t mxnet::Imperative::CachedOp::num_backward_inputs()  [inline]
uint32_t mxnet::Imperative::CachedOp::num_inputs()  [inline]
uint32_t mxnet::Imperative::CachedOp::num_outputs()  [inline]
std::vector<bool>& mxnet::Imperative::CachedOp::save_inputs()  [inline]
std::vector<bool>& mxnet::Imperative::CachedOp::save_outputs()  [inline]
The documentation for this class was generated from the following file: imperative.h