31 #include <unordered_map> 32 #include <unordered_set> 48 Chunk(
// Chunk constructor parameters: the source text, a list of option strings,
// and a list of export name strings.
// NOTE(review): presumably CUDA source plus compiler options for runtime
// compilation — confirm against the implementation file.
const char* source,
49 const std::vector<std::string>& options,
50 const std::vector<std::string>& exports);
// Returns a CUfunction handle for the given mangled name and context.
// NOTE(review): presumably loads/caches the module for ctx's device before
// resolving the function — the body is not visible here; confirm.
59 CUfunction GetFunction(
const std::string& mangled_name,
const Context& ctx);
// CUmodule handles keyed by an int (presumably a device id — TODO confirm).
65 std::unordered_map<int, CUmodule> mod_;
// Set of export name strings (presumably the names passed to the Chunk
// constructor — TODO confirm).
67 std::unordered_set<std::string> exports_;
// Shared handle to the Chunk; the CudaModule constructor initializes it via
// std::make_shared<Chunk>(source, options, exports).
70 std::shared_ptr<Chunk> ptr_;
// Launch declaration: takes the execution context, the argument list, and the
// grid and block dimensions along x, y and z.
// NOTE(review): the remainder of this declaration (any further parameters and
// the closing parenthesis) is missing from this listing.
86 void Launch(
const Context& ctx,
const std::vector<dmlc::any>& args,
87 uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z,
88 uint32_t block_dim_x, uint32_t block_dim_y, uint32_t block_dim_z,
// Kernel interface signature: returns a const reference to the stored
// signature_ vector of ArgType descriptors (no copy is made).
91 const std::vector<ArgType>&
signature() {
return signature_; }
// Kernel constructor: takes a shared Chunk, the kernel's mangled name, and
// its argument signature.
// NOTE(review): presumably these are stored in mod_, mangled_name_ and
// signature_ respectively — the definition is not visible here; confirm.
101 Kernel(
const std::shared_ptr<Chunk>& mod,
102 const std::string& mangled_name,
103 const std::vector<ArgType>& signature);
// Mangled name string for this kernel.
105 std::string mangled_name_;
// Argument descriptors; exposed read-only through signature().
107 std::vector<ArgType> signature_;
// Shared handle to a Chunk (presumably the one this kernel was obtained
// from — TODO confirm).
109 std::shared_ptr<Chunk> mod_;
// CUfunction handles keyed by an int (presumably a device id — TODO confirm).
111 std::unordered_map<int, CUfunction> func_;
// Tail of the CudaModule constructor (its first line is missing from this
// listing). The constructor body is empty: all work is done by initializing
// ptr_ with a shared Chunk built from source, options and exports.
119 const std::vector<std::string>& options,
120 const std::vector<std::string>& exports)
121 : ptr_(std::make_shared<Chunk>(source, options, exports)) {}
// Get a CUDA kernel from the module by name.
// Takes the kernel name and its argument signature; returns a shared pointer
// to the Kernel.
// NOTE(review): behaviour when `name` is not found is not visible here.
128 std::shared_ptr<Kernel>
GetKernel(
const std::string& name,
129 const std::vector<ArgType>& signature);
135 #endif // MXNET_USE_CUDA 136 #endif // MXNET_RTC_H_ bool is_const
whether argument is constant (input)
Definition: rtc.h:78
Cuda runtime compile module.
Definition: rtc.h:39
namespace of mxnet
Definition: base.h:126
std::shared_ptr< Kernel > GetKernel(const std::string &name, const std::vector< ArgType > &signature)
Get cuda kernel from module by name.
CudaModule(const char *source, const std::vector< std::string > &options, const std::vector< std::string > &exports)
CudaModule constructor.
Definition: rtc.h:118
const std::vector< ArgType > & signature()
kernel interface signature
Definition: rtc.h:91
mshadow::TypeFlag dtype
data type of argument
Definition: rtc.h:80
Cuda kernel.
Definition: rtc.h:83
cuda kernel argument descriptor
Definition: rtc.h:74
Context information about the execution environment.
Definition: base.h:141
bool is_ndarray
whether argument is NDArray
Definition: rtc.h:76