# Source code for mxnet.dlpack
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=protected-access
# pylint: disable=import-error, no-name-in-module, undefined-variable
"""DLPack API of MXNet."""
import ctypes
import enum
from mxnet.device import current_device
from .base import _LIB, c_str, check_call, NDArrayHandle, mx_int
# Untyped pointer used to receive the DLPack handle written by MXNDArrayToDLPack.
DLPackHandle = ctypes.c_void_p
# C callback type for a capsule destructor: takes one void pointer, returns nothing.
PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
# Capsule names used by this module: a capsule is named 'dltensor' while it can
# still be consumed and is renamed to 'used_dltensor' once it has been imported.
_c_str_dltensor = c_str('dltensor')
_c_str_used_dltensor = c_str('used_dltensor')
def _dlpack_deleter(pycapsule):
    """Capsule destructor that frees a DLPack tensor nobody consumed.

    Parameters
    ----------
    pycapsule : int
        Address of the PyCapsule being destroyed, as delivered to the
        ``PyCapsuleDestructor`` callback.

    Notes
    -----
    Only capsules still named ``'dltensor'`` are released here; a capsule
    that has been renamed fails the validity check and is left alone.
    """
    pycapsule = ctypes.c_void_p(pycapsule)
    if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
        # NOTE(review): wrapping the result in c_void_p assumes
        # PyCapsule_GetPointer.restype is configured as a pointer type
        # elsewhere in the package -- confirm.
        ptr = ctypes.c_void_p(
            ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor))
        check_call(_LIB.MXNDArrayCallDLPackDeleter(ptr))


# Module-level callback object handed to PyCapsule_New by the export helpers;
# ctypes callbacks must stay referenced for as long as C code may call them.
_c_dlpack_deleter = PyCapsuleDestructor(_dlpack_deleter)
class DLDeviceType(enum.IntEnum):
    """Integer device-type codes compared against ``__dlpack_device__()[0]``.

    Members are plain integers, so they compare equal to the raw codes
    returned by other libraries.
    """
    # Values are written without trailing commas: ``DLCPU = 1,`` would define
    # a 1-tuple that only works because IntEnum unpacks it into ``int(...)``.
    DLCPU = 1
    DLGPU = 2
    DLCPUPINNED = 3
    DLOPENCL = 4
    DLVULKAN = 7
    DLMETAL = 8
    DLVPI = 9
    DLROCM = 10
    DLEXTDEV = 12
class DLContext(ctypes.Structure):
    """ctypes structure holding a device type code and a device id.

    Both fields are C ``int`` values; ``device_type`` carries one of the
    integer device-type codes and ``device_id`` the index of the device.
    """
    _fields_ = [("device_type", ctypes.c_int),
                ("device_id", ctypes.c_int)]
class DLDataType(ctypes.Structure):
    """ctypes structure describing an element type as (type_code, bits, lanes).

    ``TYPE_MAP`` translates a dtype name (the result of ``str(array.dtype)``)
    into the ``(type_code, bits, lanes)`` triple for that dtype. In this
    table signed integers use type code 0, unsigned integers and ``bool``
    use 1, and floating-point types use 2; every entry has a single lane.
    """
    _fields_ = [("type_code", ctypes.c_uint8),
                ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]
    TYPE_MAP = {
        "int32": (0, 32, 1),
        "int64": (0, 64, 1),
        "bool": (1, 1, 1),
        "uint8": (1, 8, 1),
        "uint32": (1, 32, 1),
        "uint64": (1, 64, 1),
        "float16": (2, 16, 1),
        "float32": (2, 32, 1),
        "float64": (2, 64, 1),
    }
class DLTensor(ctypes.Structure):
    """ctypes structure describing a tensor's memory.

    Fields, in declaration order:

    - ``data``: untyped pointer to the first element.
    - ``ctx``: :class:`DLContext` naming the device that holds the data.
    - ``ndim``: number of dimensions.
    - ``dtype``: :class:`DLDataType` of the elements.
    - ``shape``: pointer to ``int64`` extents.
    - ``strides``: pointer to ``int64`` strides; this module leaves it NULL
      when describing NumPy arrays.
    - ``byte_offset``: unsigned 64-bit offset applied to ``data``.
    """
    _fields_ = [("data", ctypes.c_void_p),
                ("ctx", DLContext),
                ("ndim", ctypes.c_int),
                ("dtype", DLDataType),
                ("shape", ctypes.POINTER(ctypes.c_int64)),
                ("strides", ctypes.POINTER(ctypes.c_int64)),
                ("byte_offset", ctypes.c_uint64)]
class DLManagedTensor(ctypes.Structure):
    """ctypes structure pairing a :class:`DLTensor` with its owner and deleter.

    The fields are attached after the class statement because the deleter's
    callback type takes a pointer to this very structure, so the class must
    exist before ``DeleterFunc`` can be declared.
    """


# Callback type of the deleter: receives a pointer to the managed tensor and
# returns nothing.
DeleterFunc = ctypes.CFUNCTYPE(None, ctypes.POINTER(DLManagedTensor))

# dl_tensor   -- the tensor description itself
# manager_ctx -- opaque pointer identifying the owner of the memory
# deleter     -- callback invoked with a pointer to this structure
DLManagedTensor._fields_ = [("dl_tensor", DLTensor),  # pylint: disable=protected-access
                            ("manager_ctx", ctypes.c_void_p),
                            ("deleter", DeleterFunc)]
@DeleterFunc
def dl_managed_tensor_deleter(dl_managed_tensor_handle):
    """Deleter callback that drops one reference on the owning Python object.

    Parameters
    ----------
    dl_managed_tensor_handle : ctypes.POINTER(DLManagedTensor)
        Managed tensor whose ``manager_ctx`` holds the address of the Python
        object that owns the tensor's memory.
    """
    void_p = dl_managed_tensor_handle.contents.manager_ctx
    # Reinterpret the stored address as a Python object so that its reference
    # count can be decremented.
    pyobj = ctypes.cast(void_p, ctypes.py_object)
    ctypes.pythonapi.Py_DecRef(pyobj)
class _InvalidDLPackError(ValueError, AssertionError):
    """Raised when a capsule does not hold an unconsumed DLPack tensor.

    Derives from ``AssertionError`` as well as ``ValueError`` so that callers
    written against the former ``assert``-based check keep working.
    """


def ndarray_from_dlpack(array_cls):
    """Returns a function that returns specified array_cls from dlpack.

    Parameters
    ----------
    array_cls : callable
        Called as ``array_cls(handle=handle)`` to wrap the imported handle.

    Returns
    -------
    fn : dlpack -> array_cls
    """
    def from_dlpack(dlpack):
        """Import a DLPack capsule, or an object exposing ``__dlpack__``.

        Parameters
        ----------
        dlpack : PyCapsule or object with ``__dlpack__``/``__dlpack_device__``
            The tensor to import. A capsule can be consumed only once.

        Returns
        -------
        The result of ``array_cls(handle=handle)``.

        Raises
        ------
        AttributeError
            If `dlpack` is neither a PyCapsule nor provides ``__dlpack__``.
        ValueError
            If the capsule is not a valid, unconsumed ``'dltensor'`` capsule.
        """
        tp = type(dlpack)
        if tp.__module__ == "builtins" and tp.__name__ == "PyCapsule":
            dlpack = ctypes.py_object(dlpack)
        elif hasattr(dlpack, "__dlpack__"):
            device, device_id = dlpack.__dlpack_device__()
            if device != DLDeviceType.DLGPU:
                dlpack = ctypes.py_object(dlpack.__dlpack__())
            else:
                # For GPU producers, pass the stream reported by the backend
                # for this device to the producer's __dlpack__.
                s = mx_int()
                check_call(_LIB.MXGetCurrentStream(
                    ctypes.c_int(device_id), ctypes.byref(s)))
                dlpack = ctypes.py_object(dlpack.__dlpack__(stream=s.value))
        else:
            raise AttributeError("Required PyCapsule or object with __dlpack__")
        # Explicit raise instead of ``assert``: the check must survive
        # ``python -O`` and report a ValueError rather than wrap one.
        if not ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor):
            raise _InvalidDLPackError(
                'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.')
        handle = NDArrayHandle()
        dlpack_handle = ctypes.c_void_p(
            ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor))
        check_call(_LIB.MXNDArrayFromDLPack(dlpack_handle, False, ctypes.byref(handle)))
        # Rename PyCapsule (DLPack) so that it cannot be consumed again
        ctypes.pythonapi.PyCapsule_SetName(dlpack, _c_str_used_dltensor)
        # delete the deleter of the old dlpack
        ctypes.pythonapi.PyCapsule_SetDestructor(dlpack, None)
        return array_cls(handle=handle)
    return from_dlpack
def ndarray_to_dlpack_for_read():
    """Returns a function that returns dlpack for reading from mxnet array.

    Returns
    -------
    fn : tensor -> dlpack
    """
    def to_dlpack_for_read(data):
        """Export `data` as a ``'dltensor'`` capsule after waiting for reads.

        Calls ``data.wait_to_read()`` first, then asks the backend for a
        DLPack handle of ``data.handle`` and wraps it in a capsule whose
        destructor is ``_c_dlpack_deleter``.
        """
        data.wait_to_read()
        dlpack = DLPackHandle()
        check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
        # NOTE(review): returning a capsule object here assumes
        # PyCapsule_New.restype is set to py_object elsewhere -- confirm.
        return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter)
    return to_dlpack_for_read
def ndarray_to_dlpack_for_write():
    """Returns a function that returns dlpack for writing from mxnet array.

    Returns
    -------
    fn : tensor -> dlpack
    """
    def to_dlpack_for_write(data):
        """Export `data` as a ``'dltensor'`` capsule after waiting for writes.

        Calls ``MXNDArrayWaitToWrite`` on ``data.handle`` first, then asks
        the backend for a DLPack handle and wraps it in a capsule whose
        destructor is ``_c_dlpack_deleter``.
        """
        check_call(_LIB.MXNDArrayWaitToWrite(data.handle))
        dlpack = DLPackHandle()
        check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
        # NOTE(review): returning a capsule object here assumes
        # PyCapsule_New.restype is set to py_object elsewhere -- confirm.
        return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter)
    return to_dlpack_for_write
def ndarray_from_numpy(array_cls, array_create_fn):
    """Returns a function that creates array_cls from numpy array.

    Parameters
    ----------
    array_cls : callable
        Called as ``array_cls(handle=handle)`` for the zero-copy path.
    array_create_fn : callable
        Called as ``array_create_fn(ndarray, dtype=ndarray.dtype)`` when a
        copy is requested.

    Returns
    -------
    fn : numpy.ndarray -> array_cls
    """
    def from_numpy(ndarray, zero_copy=True):
        """Create an array from `ndarray`, sharing its memory by default.

        Parameters
        ----------
        ndarray : numpy.ndarray
            Source array.
        zero_copy : bool, default True
            When False, delegate to ``array_create_fn`` and leave `ndarray`
            untouched. When True, share memory; `ndarray` is made read-only
            as a side effect.

        Raises
        ------
        ValueError
            For the zero-copy path, if `ndarray` is not C-contiguous or its
            dtype is not listed in ``DLDataType.TYPE_MAP``. In both cases
            `ndarray` keeps its original ``WRITEABLE`` flag.
        """
        def _make_manager_ctx(obj):
            # Store the address of `obj` and take a reference on it; the
            # matching decrement happens in dl_managed_tensor_deleter.
            pyobj = ctypes.py_object(obj)
            void_p = ctypes.c_void_p.from_buffer(pyobj)
            ctypes.pythonapi.Py_IncRef(pyobj)
            return void_p

        def _make_dl_tensor(array):
            # Describe `array` as a CPU tensor (device type 1, device id 0).
            if str(array.dtype) not in DLDataType.TYPE_MAP:
                raise ValueError(str(array.dtype) + " is not supported.")
            dl_tensor = DLTensor()
            dl_tensor.data = array.ctypes.data_as(ctypes.c_void_p)
            dl_tensor.ctx = DLContext(1, 0)
            dl_tensor.ndim = array.ndim
            dl_tensor.dtype = DLDataType.TYPE_MAP[str(array.dtype)]
            dl_tensor.shape = array.ctypes.shape_as(ctypes.c_int64)
            # No explicit strides: only C-contiguous arrays reach this point.
            dl_tensor.strides = None
            dl_tensor.byte_offset = 0
            return dl_tensor

        def _make_dl_managed_tensor(array):
            # The tensor description is built before the reference is taken,
            # so an unsupported dtype raises without leaking a reference.
            c_obj = DLManagedTensor()
            c_obj.dl_tensor = _make_dl_tensor(array)
            c_obj.manager_ctx = _make_manager_ctx(array)
            c_obj.deleter = dl_managed_tensor_deleter
            return c_obj

        if not zero_copy:
            return array_create_fn(ndarray, dtype=ndarray.dtype)

        if not ndarray.flags['C_CONTIGUOUS']:
            raise ValueError("Only c-contiguous arrays are supported for zero-copy")

        # Build the descriptor first: it rejects unsupported dtypes, and the
        # caller's array must not be left read-only when that happens.
        c_obj = _make_dl_managed_tensor(ndarray)
        ndarray.flags['WRITEABLE'] = False
        handle = NDArrayHandle()
        check_call(_LIB.MXNDArrayFromDLPack(ctypes.byref(c_obj), True, ctypes.byref(handle)))
        return array_cls(handle=handle)
    return from_numpy
# Did this page help you?
# Yes
# No
# Thanks for your feedback!