mxnet
rtc.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_RTC_H_
21 #define MXNET_RTC_H_
22 #include "./base.h"
23 #if MXNET_USE_CUDA
24 #include <nvrtc.h>
25 #include <cuda.h>
26 
27 #include <vector>
28 #include <string>
29 #include <memory>
30 #include <utility>
31 #include <unordered_map>
32 #include <unordered_set>
33 #include "./ndarray.h"
34 
35 namespace mxnet {
36 namespace rtc {
37 
39 class CudaModule {
40  private:
42  struct Chunk {
48  Chunk(const char* source,
49  const std::vector<std::string>& options,
50  const std::vector<std::string>& exports);
52  ~Chunk();
59  CUfunction GetFunction(const std::string& mangled_name, const Context& ctx);
61  nvrtcProgram prog_;
63  char* ptx_;
65  std::unordered_map<int, CUmodule> mod_;
67  std::unordered_set<std::string> exports_;
68  };
70  std::shared_ptr<Chunk> ptr_;
71 
72  public:
74  struct ArgType {
76  bool is_ndarray;
78  bool is_const;
80  mshadow::TypeFlag dtype;
81  };
83  class Kernel {
84  public:
86  void Launch(const Context& ctx, const std::vector<dmlc::any>& args,
87  uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z,
88  uint32_t block_dim_x, uint32_t block_dim_y, uint32_t block_dim_z,
89  uint32_t shared_mem);
91  const std::vector<ArgType>& signature() { return signature_; }
92 
93  private:
94  friend class CudaModule;
101  Kernel(const std::shared_ptr<Chunk>& mod,
102  const std::string& mangled_name,
103  const std::vector<ArgType>& signature);
105  std::string mangled_name_;
107  std::vector<ArgType> signature_;
109  std::shared_ptr<Chunk> mod_;
111  std::unordered_map<int, CUfunction> func_;
112  };
118  CudaModule(const char* source,
119  const std::vector<std::string>& options,
120  const std::vector<std::string>& exports)
121  : ptr_(std::make_shared<Chunk>(source, options, exports)) {}
128  std::shared_ptr<Kernel> GetKernel(const std::string& name,
129  const std::vector<ArgType>& signature);
130 };
131 
132 } // namespace rtc
133 } // namespace mxnet
134 
135 #endif // MXNET_USE_CUDA
136 #endif // MXNET_RTC_H_
bool is_const
whether argument is constant (input)
Definition: rtc.h:78
Cuda runtime compile module.
Definition: rtc.h:39
namespace of mxnet
Definition: base.h:126
std::shared_ptr< Kernel > GetKernel(const std::string &name, const std::vector< ArgType > &signature)
Get cuda kernel from module by name.
CudaModule(const char *source, const std::vector< std::string > &options, const std::vector< std::string > &exports)
CudaModule constructor.
Definition: rtc.h:118
const std::vector< ArgType > & signature()
kernel interface signature
Definition: rtc.h:91
mshadow::TypeFlag dtype
data type of argument
Definition: rtc.h:80
Cuda kernel.
Definition: rtc.h:83
cuda kernel argument descriptor
Definition: rtc.h:74
Context information about the execution environment.
Definition: base.h:141
bool is_ndarray
whether argument is NDArray
Definition: rtc.h:76