mxnet.npx.gpu

gpu(device_id=0)

Returns a GPU device.

This function is a shortcut for Device('gpu', device_id). The K GPUs on a node are typically numbered 0, …, K-1.
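To make the shortcut concrete, the sketch below models it in plain Python. It assumes nothing about MXNet internals: `Device` here is a hypothetical stand-in for MXNet's device class (not the real one), and `all_gpus` is a hypothetical helper that builds one device per GPU using the 0, …, K-1 numbering described above.

```python
from dataclasses import dataclass


# Hypothetical stand-in for MXNet's Device class, for illustration only.
@dataclass(frozen=True)
class Device:
    device_type: str
    device_id: int = 0

    def __repr__(self):
        # Mimic the "gpu(1)" style shown in the examples.
        return f"{self.device_type}({self.device_id})"


def gpu(device_id=0):
    """Sketch of the documented shortcut: equivalent to Device('gpu', device_id)."""
    return Device('gpu', device_id)


def all_gpus(num_gpus):
    """Hypothetical helper: one device per GPU, numbered 0, ..., K-1."""
    return [gpu(i) for i in range(num_gpus)]


print(gpu())        # gpu(0)
print(gpu(1))       # gpu(1)
print(all_gpus(3))  # [gpu(0), gpu(1), gpu(2)]
```

Because the stand-in is a frozen dataclass, `gpu(1) == Device('gpu', 1)` holds, which captures the idea that the function is only a convenience wrapper around constructing the device directly.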

Examples

>>> cpu_array = mx.np.ones((2, 3))
>>> cpu_array.device
cpu(0)
>>> with mx.gpu(1):
...     gpu_array = mx.np.ones((2, 3))
>>> gpu_array.device
gpu(1)
>>> gpu_array = mx.np.ones((2, 3), device=mx.gpu(1))
>>> gpu_array.device
gpu(1)

Parameters

device_id (int, optional) – Index of the GPU device to use. Defaults to 0.

Returns

device – The corresponding GPU device.

Return type

Device