#include <imperative.h>
|
| CachedOp (const nnvm::Symbol &sym, const std::vector< std::pair< std::string, std::string > > &kwargs) |
|
uint32_t | num_inputs () |
|
uint32_t | num_outputs () |
|
uint32_t | num_backward_inputs () |
|
std::vector< bool > & | save_inputs () |
|
std::vector< bool > & | save_outputs () |
|
const std::unordered_set< uint32_t > & | mutable_input_nodes () |
|
nnvm::Graph | GetForwardGraph (const bool recording, const std::vector< NDArray * > &inputs) |
|
nnvm::Graph | GetBackwardGraph (const OpStatePtr &state, const std::vector< OpReqType > &reqs, const std::vector< NDArray * > &inputs) |
|
std::vector< nnvm::NodeEntry > | Gradient (const nnvm::NodePtr &node, const std::vector< nnvm::NodeEntry > &ograds) |
|
void | Forward (const std::shared_ptr< CachedOp > &op_ptr, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs) |
|
void | Backward (const bool retain_graph, const OpStatePtr &state, const std::vector< NDArray * > &inputs, const std::vector< OpReqType > &reqs, const std::vector< NDArray * > &outputs) |
|
mxnet::Imperative::CachedOp::CachedOp (const nnvm::Symbol &sym, const std::vector< std::pair< std::string, std::string > > &kwargs)
void mxnet::Imperative::CachedOp::Backward (const bool retain_graph, const OpStatePtr &state, const std::vector< NDArray * > &inputs, const std::vector< OpReqType > &reqs, const std::vector< NDArray * > &outputs)
void mxnet::Imperative::CachedOp::Forward (const std::shared_ptr< CachedOp > &op_ptr, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
nnvm::Graph mxnet::Imperative::CachedOp::GetBackwardGraph (const OpStatePtr &state, const std::vector< OpReqType > &reqs, const std::vector< NDArray * > &inputs)
nnvm::Graph mxnet::Imperative::CachedOp::GetForwardGraph (const bool recording, const std::vector< NDArray * > &inputs)
std::vector<nnvm::NodeEntry> mxnet::Imperative::CachedOp::Gradient (const nnvm::NodePtr &node, const std::vector< nnvm::NodeEntry > &ograds)
const std::unordered_set<uint32_t>& mxnet::Imperative::CachedOp::mutable_input_nodes () [inline]
uint32_t mxnet::Imperative::CachedOp::num_backward_inputs () [inline]
uint32_t mxnet::Imperative::CachedOp::num_inputs () [inline]
uint32_t mxnet::Imperative::CachedOp::num_outputs () [inline]
std::vector<bool>& mxnet::Imperative::CachedOp::save_inputs () [inline]
std::vector<bool>& mxnet::Imperative::CachedOp::save_outputs () [inline]
The documentation for this class was generated from the following file: