// NOTE(review): this excerpt was recovered from a rendered source listing.
// The listing's own line numbers had been fused into the code and several
// preprocessor directives shared a single physical line; both are undone
// below. The listing numbers had gaps, so lines are missing from this view:
// the enclosing `namespace` / `class MXRtc {` lines, the access specifiers,
// and possibly further members and includes must be restored from the
// upstream header before this compiles with the feature macros enabled.
#ifndef MXNET_MXRTC_H_
#define MXNET_MXRTC_H_

// Everything below is compiled only when both feature macros are non-zero.
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// NOTE(review): NDArray is expected to come from this header; index_t,
// CUmodule and CUfunction must also be declared by headers that are not
// visible in this excerpt.
#include "./ndarray.h"

/*!
 * \brief Construct an MXRtc object from a name, its named inputs/outputs
 *        and a kernel string.
 * \param name   identifier for this object (presumably the kernel function
 *               name -- TODO confirm against the implementation).
 * \param input  list of (name, NDArray) pairs describing the inputs.
 * \param output list of (name, NDArray) pairs describing the outputs.
 * \param kernel kernel text (presumably CUDA source to be compiled at
 *               runtime -- TODO confirm).
 */
MXRtc(const std::string& name,
      std::vector<std::pair<std::string, NDArray> > const& input,
      std::vector<std::pair<std::string, NDArray> > const& output,
      const std::string& kernel);

/*!
 * \brief Submit work on the given arrays using explicit grid and block
 *        dimensions.
 *
 * NOTE(review): presumably this launches the kernel built by the
 * constructor; whether the call is synchronous is not visible here.
 *
 * \param input       input arrays.
 * \param output      output arrays.
 * \param grid_dim_X  grid dimension along X.
 * \param grid_dim_Y  grid dimension along Y.
 * \param grid_dim_Z  grid dimension along Z.
 * \param block_dim_X block dimension along X.
 * \param block_dim_Y block dimension along Y.
 * \param block_dim_Z block dimension along Z.
 */
void push(std::vector<NDArray> const& input,
          std::vector<NDArray> const& output,
          unsigned int grid_dim_X,
          unsigned int grid_dim_Y,
          unsigned int grid_dim_Z,
          unsigned int block_dim_X,
          unsigned int block_dim_Y,
          unsigned int block_dim_Z);

/*! \brief Character array defined out of line; its contents are not visible here. */
static const char str_type[];
/*!
 * \brief Shared map from a string key to a raw character buffer.
 *
 * NOTE(review): presumably caches compile() results per kernel; the raw
 * char* does not express who owns or frees the buffers -- confirm.
 */
static std::unordered_map<std::string, char*> kernel_registry;

/*! \brief Input count (presumably the size of the constructor's input list). */
index_t num_input_;
/*! \brief Output count (presumably the size of the constructor's output list). */
index_t num_output_;

/*! \brief CUmodule handles keyed by an int (presumably a device id -- confirm). */
std::unordered_map<int, CUmodule> module_;
/*! \brief CUfunction handles keyed by an int (presumably a device id -- confirm). */
std::unordered_map<int, CUfunction> func_;

/*!
 * \brief Build a string from the kernel text and the named inputs/outputs.
 *
 * NOTE(review): `kernel` is taken by value here while every sibling string
 * parameter is a const reference. It is left as declared because the
 * out-of-line definition, which must match, is not visible in this excerpt.
 *
 * \param name   identifier, same role as in the constructor.
 * \param input  list of (name, NDArray) pairs describing the inputs.
 * \param output list of (name, NDArray) pairs describing the outputs.
 * \param kernel kernel text.
 * \return the resulting string (presumably the full source handed to
 *         compile() -- TODO confirm).
 */
std::string decorate(const std::string& name,
                     std::vector<std::pair<std::string, NDArray> > const& input,
                     std::vector<std::pair<std::string, NDArray> > const& output,
                     const std::string kernel);

/*!
 * \brief Compile \p code under the given \p name.
 * \param name identifier for the code being compiled.
 * \param code text to compile.
 * \return raw character buffer with the result (presumably PTX produced by
 *         NVRTC; ownership of the buffer is not visible here -- confirm).
 */
char* compile(const std::string& name,
              const std::string& code);

#endif  // MXNET_USE_CUDA && MXNET_USE_NVRTC
#endif  // MXNET_MXRTC_H_
// Residual listing annotation kept from the original text: "namespace of mxnet".
Definition: base.h:126
mshadow::index_t index_t
index type, usually unsigned
Definition: base.h:132