mxnet.rtc
Interface to the runtime CUDA kernel compilation module.
Classes

CudaKernel(handle, name, is_ndarray, dtypes)
    Constructs CUDA kernel.
CudaModule(source[, options, exports])
    Compile and run CUDA code from Python.
class mxnet.rtc.CudaKernel(handle, name, is_ndarray, dtypes)
Bases: object
Constructs a CUDA kernel. Instances should be created via CudaModule.get_kernel; this class is not intended to be instantiated directly by users.
Methods

launch(args, ctx, grid_dims, block_dims[, shared_mem])
    Launch the CUDA kernel.
launch(args, ctx, grid_dims, block_dims, shared_mem=0)
Launch the CUDA kernel.
Parameters
    args (tuple of NDArray or numbers) – Arguments for the kernel. NDArrays are expected for pointer types (e.g. float*, double*), while numbers are expected for non-pointer types (e.g. int, float).
    ctx (Context) – The context to launch the kernel on. Must be a GPU context.
    grid_dims (tuple of 3 integers) – Grid dimensions for the CUDA kernel.
    block_dims (tuple of 3 integers) – Block dimensions for the CUDA kernel.
    shared_mem (integer, optional) – Size in bytes of dynamically allocated shared memory. Defaults to 0.
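For a 1-D problem, grid_dims is typically derived from the element count with a ceiling division so that every element is covered by a thread. The helper below is a hypothetical illustration (launch_dims is not part of mxnet.rtc), showing only the dimension arithmetic; it does not require a GPU:

```python
def launch_dims(n, block_x=256):
    """Compute (grid_dims, block_dims) for an n-element 1-D kernel.

    Hypothetical helper for illustration only; mxnet.rtc expects
    both tuples to contain exactly three integers.
    """
    grid_x = (n + block_x - 1) // block_x  # ceiling division
    return (grid_x, 1, 1), (block_x, 1, 1)

grid, block = launch_dims(1000)
print(grid, block)  # (4, 1, 1) (256, 1, 1)
```

The resulting tuples would be passed directly as the grid_dims and block_dims arguments of launch. Note that with this scheme the kernel body should guard against out-of-range indices when n is not a multiple of the block size.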
class mxnet.rtc.CudaModule(source, options=(), exports=())
Bases: object
Compile and run CUDA code from Python.
In CUDA 7.5, you need to prepend your kernel definitions with extern "C" to avoid name mangling:

source = r'''
extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    y[i] += alpha * x[i];
}
'''
module = mx.rtc.CudaModule(source)
func = module.get_kernel("axpy", "const float *x, float *y, float alpha")
x = mx.nd.ones((10,), ctx=mx.gpu(0))
y = mx.nd.zeros((10,), ctx=mx.gpu(0))
func.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
print(y)
Methods

get_kernel(name, signature)
    Get a CUDA kernel from the compiled module.
Starting from CUDA 8.0, you can instead export functions by name. This also allows you to use templates:
source = r'''
template<typename DType>
__global__ void axpy(const DType *x, DType *y, DType alpha) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    y[i] += alpha * x[i];
}
'''
module = mx.rtc.CudaModule(source, exports=['axpy<float>', 'axpy<double>'])
func32 = module.get_kernel("axpy<float>", "const float *x, float *y, float alpha")
x = mx.nd.ones((10,), dtype='float32', ctx=mx.gpu(0))
y = mx.nd.zeros((10,), dtype='float32', ctx=mx.gpu(0))
func32.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
print(y)
func64 = module.get_kernel("axpy<double>", "const double *x, double *y, double alpha")
x = mx.nd.ones((10,), dtype='float64', ctx=mx.gpu(0))
y = mx.nd.zeros((10,), dtype='float64', ctx=mx.gpu(0))
func64.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
print(y)
Parameters
    source (str) – Complete source code.
    options (tuple of str) – Compiler flags. For example, use "-I/usr/local/cuda/include" to add the CUDA headers to the include path.
    exports (tuple of str) – Names of kernels to export.
get_kernel(name, signature)
Get a CUDA kernel from the compiled module.
Parameters
    name (str) – Name of the kernel.
    signature (str) – Function signature for the kernel. For example, if a kernel is declared as:

        extern "C" __global__ void axpy(const float *x, double *y, int alpha)

    then its signature should be:

        const float *x, double *y, int alpha

    or:

        const float *, double *, int

    Note that * in the signature marks an argument as an array (pointer), and const marks it as a constant (input-only) array.
Returns
    A CUDA kernel that can be launched on GPUs.
Return type
    CudaKernel
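As the examples above show, the same kernel can be named in a signature string with or without parameter names. The sketch below illustrates that equivalence with a hypothetical helper (strip_names is not part of mxnet.rtc and only handles simple signatures like those in this document, not multi-word scalar types such as unsigned int); it requires no GPU:

```python
def strip_names(sig):
    """Drop parameter names from a simple kernel signature string.

    Hypothetical illustration of the two equivalent signature forms
    accepted by get_kernel; assumes single-word scalar types.
    """
    args = []
    for a in sig.split(','):
        a = a.strip()
        if '*' in a:
            # Pointer argument: keep everything up to the last '*'.
            a = a[:a.rindex('*') + 1]
        else:
            words = a.split()
            # Scalar argument: assume "type name" and keep the type.
            if len(words) > 1:
                a = ' '.join(words[:-1])
        args.append(a)
    return ', '.join(args)

named = "const float *x, double *y, int alpha"
print(strip_names(named))  # const float *, double *, int
```

Under these assumptions, both the named form and its stripped counterpart describe the same axpy kernel from the example above.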