mxnet
Classes | Public Member Functions | List of all members
mxnet::Imperative::CachedOp Class Reference

#include <imperative.h>

Collaboration diagram for mxnet::Imperative::CachedOp:
Collaboration graph

Public Member Functions

 CachedOp (const nnvm::Symbol &sym, const std::vector< std::pair< std::string, std::string > > &kwargs)
 
uint32_t num_inputs ()
 
uint32_t num_outputs ()
 
uint32_t num_backward_inputs ()
 
std::vector< bool > & save_inputs ()
 
std::vector< bool > & save_outputs ()
 
const std::unordered_set< uint32_t > & mutable_input_nodes ()
 
nnvm::Graph GetForwardGraph (const bool recording, const std::vector< NDArray * > &inputs)
 
nnvm::Graph GetBackwardGraph (const OpStatePtr &state, const std::vector< OpReqType > &reqs, const std::vector< NDArray * > &inputs)
 
std::vector< nnvm::NodeEntry > Gradient (const nnvm::NodePtr &node, const std::vector< nnvm::NodeEntry > &ograds)
 
void Forward (const std::shared_ptr< CachedOp > &op_ptr, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
 
void Backward (const bool retain_graph, const OpStatePtr &state, const std::vector< NDArray * > &inputs, const std::vector< OpReqType > &reqs, const std::vector< NDArray * > &outputs)
 

Constructor & Destructor Documentation

mxnet::Imperative::CachedOp::CachedOp ( const nnvm::Symbol &  sym,
const std::vector< std::pair< std::string, std::string > > &  kwargs 
)

Member Function Documentation

void mxnet::Imperative::CachedOp::Backward ( const bool  retain_graph,
const OpStatePtr &  state,
const std::vector< NDArray * > &  inputs,
const std::vector< OpReqType > &  reqs,
const std::vector< NDArray * > &  outputs 
)
void mxnet::Imperative::CachedOp::Forward ( const std::shared_ptr< CachedOp > &  op_ptr,
const std::vector< NDArray * > &  inputs,
const std::vector< NDArray * > &  outputs 
)
nnvm::Graph mxnet::Imperative::CachedOp::GetBackwardGraph ( const OpStatePtr &  state,
const std::vector< OpReqType > &  reqs,
const std::vector< NDArray * > &  inputs 
)
nnvm::Graph mxnet::Imperative::CachedOp::GetForwardGraph ( const bool  recording,
const std::vector< NDArray * > &  inputs 
)
std::vector<nnvm::NodeEntry> mxnet::Imperative::CachedOp::Gradient ( const nnvm::NodePtr &  node,
const std::vector< nnvm::NodeEntry > &  ograds 
)
const std::unordered_set<uint32_t>& mxnet::Imperative::CachedOp::mutable_input_nodes ( )
inline
uint32_t mxnet::Imperative::CachedOp::num_backward_inputs ( )
inline
uint32_t mxnet::Imperative::CachedOp::num_inputs ( )
inline
uint32_t mxnet::Imperative::CachedOp::num_outputs ( )
inline
std::vector<bool>& mxnet::Imperative::CachedOp::save_inputs ( )
inline
std::vector<bool>& mxnet::Imperative::CachedOp::save_outputs ( )
inline

The documentation for this class was generated from the following file:
- imperative.h