// NOTE(review): the bare integers in this text (23, 31, 48, ...) appear to be line
// numbers carried over from a rendered source listing, not C++ tokens — confirm
// against the real header before compiling.
// The code that follows is guarded by MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC.
23 #if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC 31 #include <unordered_map> 32 #include <unordered_set> 33 #include "./ndarray.h" 48 Chunk(
// Chunk constructor declaration; the definition is not part of this excerpt.
// CudaModule's constructor forwards its own three arguments to this one unchanged.
// source: C string — presumably the CUDA source text to compile at runtime; confirm.
const char* source,
// options: list of strings — presumably compiler options; confirm against the definition.
49 const std::vector<std::string>& options,
// exports: list of strings — presumably the names kept in exports_; confirm.
50 const std::vector<std::string>& exports);
// Declares GetFunction: returns a CUfunction handle for the given mangled name
// and context. The definition is not part of this excerpt.
// NOTE(review): presumably resolves the function inside the module loaded for
// ctx's device (see the int-keyed CUmodule map) — confirm in the implementation.
59 CUfunction GetFunction(
// mangled_name: the (already mangled) symbol name to look up.
const std::string& mangled_name,
// ctx: context for which the function handle is requested.
const Context& ctx);
// CUmodule handles keyed by int.
// NOTE(review): the key is presumably a device id — confirm.
65 std::unordered_map<int, CUmodule> mod_;
// Set of exported names; presumably filled from the constructor's `exports` argument — confirm.
67 std::unordered_set<std::string> exports_;
// Shared ownership of the Chunk; CudaModule's constructor creates it with std::make_shared.
70 std::shared_ptr<Chunk> ptr_;
// Element data type, expressed as an mshadow::TypeFlag.
// NOTE(review): the enclosing struct is not visible here; presumably a field of
// ArgType describing one kernel argument — confirm.
80 mshadow::TypeFlag dtype;
// Declares Launch: takes a context, a list of type-erased arguments (dmlc::any),
// and three grid dimensions followed by three block dimensions, all uint32_t.
86 void Launch(
const Context& ctx,
const std::vector<dmlc::any>& args,
87 uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z,
88 uint32_t block_dim_x, uint32_t block_dim_y, uint32_t block_dim_z,
// NOTE(review): the parameter list ends on a comma and is never closed in this
// text — at least one trailing parameter and the closing ");" are missing here.
// Recover them from the real header; do not guess.
91 const std::vector<ArgType>& signature() {
return signature_; }
// Grants CudaModule access to this class's non-public members.
// NOTE(review): presumably so CudaModule::GetKernel can construct Kernel objects — confirm.
94 friend class CudaModule;
// Kernel constructor declaration; the definition is not part of this excerpt.
// NOTE(review): the arguments presumably initialise mod_, mangled_name_ and
// signature_ respectively — confirm in the implementation.
101 Kernel(
// mod: shared handle to the Chunk this kernel belongs to.
const std::shared_ptr<Chunk>& mod,
// mangled_name: mangled symbol name of the kernel.
102 const std::string& mangled_name,
// signature: per-argument descriptors (ArgType) of the kernel.
103 const std::vector<ArgType>& signature);
// Mangled kernel name (string).
105 std::string mangled_name_;
// Argument descriptors; this is what signature() returns by const reference.
107 std::vector<ArgType> signature_;
// Shared ownership of the Chunk associated with this kernel.
109 std::shared_ptr<Chunk> mod_;
// CUfunction handles keyed by int.
// NOTE(review): the key is presumably a device id, filled lazily — confirm.
111 std::unordered_map<int, CUfunction> func_;
118 CudaModule(
const char* source,
119 const std::vector<std::string>& options,
120 const std::vector<std::string>& exports)
121 : ptr_(std::make_shared<Chunk>(source, options, exports)) {}
// Declares GetKernel: returns a shared_ptr to a Kernel for the given name and
// argument signature. The definition is not part of this excerpt.
128 std::shared_ptr<Kernel> GetKernel(
// name: kernel name.
// NOTE(review): presumably the unmangled, exported name — confirm.
const std::string& name,
// signature: per-argument descriptors (ArgType) for the kernel.
129 const std::vector<ArgType>& signature);
135 #endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC 136 #endif // MXNET_RTC_H_ namespace of mxnet
Definition: base.h:127