Source code for mxnet._ctypes.cached_op

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments
# pylint: disable=global-statement, unused-import
"""CachedOp API."""

import ctypes

from ..base import _LIB
from ..base import c_handle_array
from ..base import NDArrayHandle, CachedOpHandle, SymbolHandle
from ..base import check_call
from .. import _global_var
from ..ndarray._internal import NDArrayBase
from . import _api_internal

def _monitor_callback_wrapper(callback):
    """A wrapper for the user-defined handle."""
    def callback_handle(name, opr_name, array, _):
        """ ctypes function """
        callback(name, opr_name, array)
    return callback_handle

[docs]class CachedOp(object): """Cached operator handle.""" __slots__ = ["handle", "is_np_sym", "_monitor_callback"] def __init__(self, sym, flags=(), thread_safe=False): self._monitor_callback = None from ..symbol.numpy._symbol import _Symbol self.is_np_sym = bool(isinstance(sym, _Symbol)) flags = {key: str(value) for key, value in flags} self.handle = CachedOpHandle(_api_internal.create( sym.handle, flags, thread_safe )) def __del__(self):
[docs] def get_optimized_symbol(self): """Get an optimized version of the symbol from the cached op. Returns ------- symbol : Symbol Optimized symbol from the executor. """ from ..symbol import Symbol sym_handle = SymbolHandle(_api_internal.get_optimized_symbol(self.handle)) ret = Symbol(sym_handle) return ret
def __call__(self, *args, **kwargs): """ctypes implementation of imperative invoke wrapper""" # New FFI only supports numpy ndarray default_device = kwargs.pop('default_device', None) if not default_device: default_device = kwargs.pop('default_ctx', None) out = kwargs.pop('out', None) if kwargs: raise TypeError( "CachedOp.__call__ got unexpected keyword argument(s): " + \ ', '.join(kwargs.keys())) if self.is_np_sym: if len(args) == 1 and args[0] is None: args = [] type_id = default_device.device_typeid if default_device else None device_id = default_device.device_id if default_device else None out_arg = out if out is not None and not isinstance(out, NDArrayBase) else (out, ) output_vars = _api_internal.invoke( self.handle, len(args), *args, type_id, device_id, *out_arg ) if out is not None: return out if isinstance(output_vars, NDArrayBase): return output_vars else: return list(output_vars) else: if out is not None: original_output = out if isinstance(out, NDArrayBase): out = (out,) num_output = ctypes.c_int(len(out)) output_vars = c_handle_array(out) output_vars = ctypes.cast(output_vars, ctypes.POINTER(NDArrayHandle)) else: original_output = None output_vars = ctypes.POINTER(NDArrayHandle)() num_output = ctypes.c_int(0) # return output stypes to avoid the c_api call for checking # a handle's stype in _ndarray_cls out_stypes = ctypes.POINTER(ctypes.c_int)() # (None, ) -> [] if len(args) == 1 and args[0] is None: args = [] assert default_device is not None, 'default_device is required if no input is provided' else: default_device = args[0].device if default_device is None else default_device check_call(_LIB.MXInvokeCachedOp( self.handle, ctypes.c_int(len(args)), c_handle_array(args), ctypes.c_int(default_device.device_typeid), ctypes.c_int(default_device.device_id), ctypes.byref(num_output), ctypes.byref(output_vars), ctypes.byref(out_stypes))) if original_output is not None: return original_output create_ndarray_fn = _global_var._ndarray_cls if num_output.value == 1: return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle), stype=out_stypes[0]) else: return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i]) for i in range(num_output.value)] def _register_op_hook(self, callback, monitor_all=False): """Install callback for monitor. Parameters ---------- callback : function Takes a string for node_name, string for op_name and a NDArrayHandle. monitor_all : bool, default False If true, monitor both input _imperative_invoked output, otherwise monitor output only. """ cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p) if callback: self._monitor_callback = cb_type(_monitor_callback_wrapper(callback)) callback_ptr = ctypes.cast(self._monitor_callback, ctypes.c_void_p) _api_internal.register_op_hook(self.handle, callback_ptr, monitor_all)