Source code for mxnet.dlpack

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=protected-access
# pylint: disable=import-error, no-name-in-module, undefined-variable

"""DLPack API of MXNet."""

import ctypes
import enum

from mxnet.device import current_device
from .base import _LIB, c_str, check_call, NDArrayHandle, mx_int

DLPackHandle = ctypes.c_void_p

PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor')
_c_str_used_dltensor = c_str('used_dltensor')

def _dlpack_deleter(pycapsule):
    pycapsule = ctypes.c_void_p(pycapsule)
    if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
        ptr = ctypes.c_void_p(
            ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor))
        check_call(_LIB.MXNDArrayCallDLPackDeleter(ptr))

_c_dlpack_deleter = PyCapsuleDestructor(_dlpack_deleter)

class DLDeviceType(enum.IntEnum):
    DLCPU = 1,
    DLGPU = 2,
    DLCPUPINNED = 3,
    DLOPENCL = 4,
    DLVULKAN = 7,
    DLMETAL = 8,
    DLVPI = 9,
    DLROCM = 10,
    DLEXTDEV = 12,


class DLContext(ctypes.Structure):
    _fields_ = [("device_type", ctypes.c_int),
                ("device_id", ctypes.c_int)]

class DLDataType(ctypes.Structure):
    _fields_ = [("type_code", ctypes.c_uint8),
                ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]
    TYPE_MAP = {
        "int32": (0, 32, 1),
        "int64": (0, 64, 1),
        "bool": (1, 1, 1),
        "uint8": (1, 8, 1),
        "uint32": (1, 32, 1),
        "uint64": (1, 64, 1),
        'float16': (2, 16, 1),
        "float32": (2, 32, 1),
        "float64": (2, 64, 1),
    }


class DLTensor(ctypes.Structure):
    _fields_ = [("data", ctypes.c_void_p),
                ("ctx", DLContext),
                ("ndim", ctypes.c_int),
                ("dtype", DLDataType),
                ("shape", ctypes.POINTER(ctypes.c_int64)),
                ("strides", ctypes.POINTER(ctypes.c_int64)),
                ("byte_offset", ctypes.c_uint64)]

class DLManagedTensor(ctypes.Structure):
    pass


DeleterFunc = ctypes.CFUNCTYPE(None, ctypes.POINTER(DLManagedTensor))


DLManagedTensor._fields_ = [("dl_tensor", DLTensor),           # pylint: disable=protected-access
                            ("manager_ctx", ctypes.c_void_p),
                            ("deleter", DeleterFunc)]

@DeleterFunc
def dl_managed_tensor_deleter(dl_managed_tensor_handle):
    void_p = dl_managed_tensor_handle.contents.manager_ctx
    pyobj = ctypes.cast(void_p, ctypes.py_object)
    ctypes.pythonapi.Py_DecRef(pyobj)

def ndarray_from_dlpack(array_cls):
    """Returns a function that returns specified array_cls from dlpack.

    Returns
    -------
    fn : dlpack -> array_cls
    """
    def from_dlpack(dlpack):
        tp = type(dlpack)
        if tp.__module__ == "builtins" and tp.__name__ == "PyCapsule":
            dlpack = ctypes.py_object(dlpack)        
        elif hasattr(dlpack, "__dlpack__"):
            device, device_id = dlpack.__dlpack_device__()
            if device != DLDeviceType.DLGPU:
                dlpack = ctypes.py_object(dlpack.__dlpack__())
            else:
                s = mx_int()
                check_call(_LIB.MXGetCurrentStream(
                    ctypes.c_int(device_id), ctypes.byref(s)))
                dlpack = ctypes.py_object(dlpack.__dlpack__(stream=s.value))
        else:
            raise AttributeError("Required PyCapsule or object with __dlpack__")
        handle = NDArrayHandle()
        assert ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor), ValueError(
            'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.')
        dlpack_handle = ctypes.c_void_p(ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor))
        check_call(_LIB.MXNDArrayFromDLPack(dlpack_handle, False, ctypes.byref(handle)))
        # Rename PyCapsule (DLPack)
        ctypes.pythonapi.PyCapsule_SetName(dlpack, _c_str_used_dltensor)
        # delete the deleter of the old dlpack
        ctypes.pythonapi.PyCapsule_SetDestructor(dlpack, None)
        return array_cls(handle=handle)
    return from_dlpack


def ndarray_to_dlpack_for_read():
    """Returns a function that returns dlpack for reading from mxnet array.

    Returns
    -------
    fn : tensor -> dlpack
    """
    def to_dlpack_for_read(data):
        data.wait_to_read()
        dlpack = DLPackHandle()
        check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
        return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter)
    return to_dlpack_for_read

def ndarray_to_dlpack_for_write():
    """Returns a function that returns dlpack for writing from mxnet array.

    Returns
    -------
    fn : tensor -> dlpack
    """
    def to_dlpack_for_write(data):

        check_call(_LIB.MXNDArrayWaitToWrite(data.handle))
        dlpack = DLPackHandle()
        check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
        return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter)
    return to_dlpack_for_write

def ndarray_from_numpy(array_cls, array_create_fn):
    """Returns a function that creates array_cls from numpy array.

    Returns
    -------
    fn : tensor -> dlpack
    """
    def from_numpy(ndarray, zero_copy=True):
        def _make_manager_ctx(obj):
            pyobj = ctypes.py_object(obj)
            void_p = ctypes.c_void_p.from_buffer(pyobj)
            ctypes.pythonapi.Py_IncRef(pyobj)
            return void_p

        def _make_dl_tensor(array):
            if str(array.dtype) not in DLDataType.TYPE_MAP:
                raise ValueError(str(array.dtype) + " is not supported.")
            dl_tensor = DLTensor()
            dl_tensor.data = array.ctypes.data_as(ctypes.c_void_p)
            dl_tensor.ctx = DLContext(1, 0)
            dl_tensor.ndim = array.ndim
            dl_tensor.dtype = DLDataType.TYPE_MAP[str(array.dtype)]
            dl_tensor.shape = array.ctypes.shape_as(ctypes.c_int64)
            dl_tensor.strides = None
            dl_tensor.byte_offset = 0
            return dl_tensor

        def _make_dl_managed_tensor(array):
            c_obj = DLManagedTensor()
            c_obj.dl_tensor = _make_dl_tensor(array)
            c_obj.manager_ctx = _make_manager_ctx(array)
            c_obj.deleter = dl_managed_tensor_deleter
            return c_obj

        if not zero_copy:
            return array_create_fn(ndarray, dtype=ndarray.dtype)

        if not ndarray.flags['C_CONTIGUOUS']:
            raise ValueError("Only c-contiguous arrays are supported for zero-copy")

        ndarray.flags['WRITEABLE'] = False
        c_obj = _make_dl_managed_tensor(ndarray)
        handle = NDArrayHandle()
        check_call(_LIB.MXNDArrayFromDLPack(ctypes.byref(c_obj), True, ctypes.byref(handle)))
        return array_cls(handle=handle)
    return from_numpy