# Source code for mxnet.ndarray.utils

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Utility functions for NDArray and BaseSparseNDArray."""
import ctypes

from ..base import _LIB, check_call, py_str, c_str, string_types, mx_uint, NDArrayHandle
from ..base import c_array, c_handle_array, c_str_array
from .ndarray import NDArray
from .ndarray import array as _array
from .ndarray import empty as _empty_ndarray
from .ndarray import zeros as _zeros_ndarray
from .sparse import zeros as _zeros_sparse_ndarray
from .sparse import empty as _empty_sparse_ndarray
from .sparse import array as _sparse_array
from .sparse import _ndarray_cls
try:
    import scipy.sparse as spsp
except ImportError:
    spsp = None

__all__ = ['zeros', 'empty', 'array', 'load', 'load_frombuffer', 'save']


def zeros(shape, ctx=None, dtype=None, stype=None, **kwargs):
    """Return a new array of given shape and type, filled with zeros.

    Parameters
    ----------
    shape : int or tuple of int
        The shape of the array.
    ctx : Context, optional
        An optional device context (default is the current default context).
    dtype : str or numpy.dtype, optional
        An optional value type (default is `float32`).
    stype : str, optional
        The storage type of the array, such as 'row_sparse', 'csr', etc.
        ``None`` or ``'default'`` selects a dense NDArray.
    **kwargs
        Forwarded unchanged to the underlying dense or sparse ``zeros``.

    Returns
    -------
    NDArray, CSRNDArray or RowSparseNDArray
        A created array.

    Examples
    --------
    >>> mx.nd.zeros((1,2), mx.cpu(), stype='csr')
    <CSRNDArray 1x2 @cpu(0)>
    >>> mx.nd.zeros((1,2), mx.cpu(), 'float16', stype='row_sparse').asnumpy()
    array([[ 0.,  0.]], dtype=float16)
    """
    # Any explicit non-default storage type goes to the sparse constructor,
    # which takes the storage type as its first positional argument.
    if stype is not None and stype != 'default':
        return _zeros_sparse_ndarray(stype, shape, ctx, dtype, **kwargs)
    return _zeros_ndarray(shape, ctx, dtype, **kwargs)
def empty(shape, ctx=None, dtype=None, stype=None):
    """Returns a new array of given shape and type, without initializing entries.

    Parameters
    ----------
    shape : int or tuple of int
        The shape of the empty array.
    ctx : Context, optional
        An optional device context (default is the current default context).
    dtype : str or numpy.dtype, optional
        An optional value type (default is `float32`).
    stype : str, optional
        An optional storage type (default is `default`). ``None`` and
        ``'default'`` both produce a dense NDArray.

    Returns
    -------
    NDArray, CSRNDArray or RowSparseNDArray
        A created array.

    Examples
    --------
    >>> mx.nd.empty(1)
    <NDArray 1 @cpu(0)>
    >>> mx.nd.empty((1,2), mx.gpu(0))
    <NDArray 1x2 @gpu(0)>
    >>> mx.nd.empty((1,2), mx.gpu(0), 'float16')
    <NDArray 1x2 @gpu(0)>
    >>> mx.nd.empty((1,2), stype='csr')
    <CSRNDArray 1x2 @cpu(0)>
    """
    # Sparse storage is requested only by an explicit, non-default stype.
    if stype is not None and stype != 'default':
        return _empty_sparse_ndarray(stype, shape, ctx, dtype)
    return _empty_ndarray(shape, ctx, dtype)
def array(source_array, ctx=None, dtype=None):
    """Creates an array from any object exposing the array interface.

    Parameters
    ----------
    source_array : array_like
        An object exposing the array interface, an object whose `__array__`
        method returns an array, or any (nested) sequence. A SciPy
        ``csr_matrix`` or a non-dense NDArray produces a sparse result.
    ctx : Context, optional
        Device context (default is the current default context).
    dtype : str or numpy.dtype, optional
        The data type of the output array. The default dtype is
        ``source_array.dtype`` if `source_array` is an `NDArray`,
        `float32` otherwise.

    Returns
    -------
    NDArray, RowSparseNDArray or CSRNDArray
        An array with the same contents as the `source_array`.

    Examples
    --------
    >>> import numpy as np
    >>> mx.nd.array([1, 2, 3])
    <NDArray 3 @cpu(0)>
    >>> mx.nd.array([[1, 2], [3, 4]])
    <NDArray 2x2 @cpu(0)>
    >>> mx.nd.array(np.zeros((3, 2)))
    <NDArray 3x2 @cpu(0)>
    >>> mx.nd.array(np.zeros((3, 2)), mx.gpu(0))
    <NDArray 3x2 @gpu(0)>
    >>> mx.nd.array(mx.nd.zeros((3, 2), stype='row_sparse'))
    <RowSparseNDArray 3x2 @cpu(0)>
    """
    # Use the public ``scipy.sparse.csr_matrix`` name. The
    # ``scipy.sparse.csr`` submodule namespace is deprecated (SciPy >= 1.8)
    # and warns on access. ``spsp`` is None when SciPy is not installed.
    if spsp is not None and isinstance(source_array, spsp.csr_matrix):
        return _sparse_array(source_array, ctx=ctx, dtype=dtype)
    # Sparse NDArrays (row_sparse / csr) keep their storage type.
    if isinstance(source_array, NDArray) and source_array.stype != 'default':
        return _sparse_array(source_array, ctx=ctx, dtype=dtype)
    return _array(source_array, ctx=ctx, dtype=dtype)
def load(fname):
    """Loads an array from file.

    See more details in ``save``.

    Parameters
    ----------
    fname : str
        The filename.

    Returns
    -------
    list of NDArray, RowSparseNDArray or CSRNDArray, or \
    dict of str to NDArray, RowSparseNDArray or CSRNDArray
        Loaded data. A list is returned when the file carries no names,
        otherwise a dict keyed by the stored names.

    Raises
    ------
    TypeError
        If `fname` is not a string.
    """
    if not isinstance(fname, string_types):
        raise TypeError('fname required to be a string')

    # Output slots filled in by the C API call below.
    num_arrays = mx_uint()
    num_names = mx_uint()
    handle_ptr = ctypes.POINTER(NDArrayHandle)()
    name_ptr = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXNDArrayLoad(c_str(fname),
                                  ctypes.byref(num_arrays),
                                  ctypes.byref(handle_ptr),
                                  ctypes.byref(num_names),
                                  ctypes.byref(name_ptr)))

    count = num_arrays.value
    if num_names.value == 0:
        return [_ndarray_cls(NDArrayHandle(handle_ptr[idx]))
                for idx in range(count)]

    # Named entries must pair one name with each array.
    assert num_names.value == count
    return {py_str(name_ptr[idx]): _ndarray_cls(NDArrayHandle(handle_ptr[idx]))
            for idx in range(count)}
def load_frombuffer(buf):
    """Loads an array dictionary or list from a buffer.

    See more details in ``save``.

    Parameters
    ----------
    buf : str
        Buffer containing contents of a file as a string or bytes.

    Returns
    -------
    list of NDArray, RowSparseNDArray or CSRNDArray, or \
    dict of str to NDArray, RowSparseNDArray or CSRNDArray
        Loaded data. A list is returned when the buffer carries no names,
        otherwise a dict keyed by the stored names.

    Raises
    ------
    TypeError
        If `buf` is neither a string nor bytes.
    """
    # NOTE(review): on Python 3 a ``str`` passes this check, but unless
    # argtypes are declared for MXNDArrayLoadFromBuffer, ctypes marshals a
    # ``str`` as a wide-char string and ``len(buf)`` counts characters rather
    # than bytes. Confirm whether ``str`` input is really supported.
    accepted = string_types + (bytes,)
    if not isinstance(buf, accepted):
        raise TypeError('buf required to be a string or bytes')

    # Output slots filled in by the C API call below.
    num_arrays = mx_uint()
    num_names = mx_uint()
    handle_ptr = ctypes.POINTER(NDArrayHandle)()
    name_ptr = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXNDArrayLoadFromBuffer(buf,
                                            mx_uint(len(buf)),
                                            ctypes.byref(num_arrays),
                                            ctypes.byref(handle_ptr),
                                            ctypes.byref(num_names),
                                            ctypes.byref(name_ptr)))

    count = num_arrays.value
    if num_names.value == 0:
        return [_ndarray_cls(NDArrayHandle(handle_ptr[idx]))
                for idx in range(count)]

    # Named entries must pair one name with each array.
    assert num_names.value == count
    return {py_str(name_ptr[idx]): _ndarray_cls(NDArrayHandle(handle_ptr[idx]))
            for idx in range(count)}
def save(fname, data):
    """Saves a list of arrays or a dict of str->array to file.

    Parameters
    ----------
    fname : str
        The filename.
    data : NDArray, RowSparseNDArray or CSRNDArray, \
           or list/tuple of NDArray, RowSparseNDArray or CSRNDArray, \
           or dict of str to NDArray, RowSparseNDArray or CSRNDArray
        The data to save. A single array is saved as a one-element list.

    Raises
    ------
    TypeError
        If `fname` is not a string, if a dict has non-string keys or
        non-NDArray values, if a list/tuple has non-NDArray items, or if any
        item is an ``mxnet.numpy.ndarray`` (use ``mxnet.numpy.save``).
    ValueError
        If `data` is not an NDArray, list, tuple or dict.

    Examples
    --------
    >>> x = mx.nd.zeros((2,3))
    >>> y = mx.nd.ones((1,4))
    >>> mx.nd.save('my_list', [x,y])
    >>> mx.nd.save('my_dict', {'x':x, 'y':y})
    >>> mx.nd.load('my_list')
    [<NDArray 2x3 @cpu(0)>, <NDArray 1x4 @cpu(0)>]
    >>> mx.nd.load('my_dict')
    {'y': <NDArray 1x4 @cpu(0)>, 'x': <NDArray 2x3 @cpu(0)>}
    """
    # Imported lazily, as in the original; mxnet.numpy.ndarray subclasses
    # NDArray, so it must be rejected explicitly below.
    from ..numpy import ndarray as np_ndarray

    # Validate up front (same check/message as ``load``) instead of failing
    # later inside ``c_str`` with an opaque AttributeError.
    if not isinstance(fname, string_types):
        raise TypeError('fname required to be a string')

    if isinstance(data, NDArray):
        data = [data]

    if isinstance(data, dict):
        str_keys = list(data.keys())
        nd_vals = list(data.values())
        if any(not isinstance(k, string_types) for k in str_keys) or \
           any(not isinstance(v, NDArray) for v in nd_vals):
            raise TypeError('save only accept dict str->NDArray or list of NDArray')
        if any(isinstance(v, np_ndarray) for v in nd_vals):
            raise TypeError('cannot save mxnet.numpy.ndarray using mxnet.ndarray.save;'
                            ' use mxnet.numpy.save instead.')
        keys = c_str_array(str_keys)
        handles = c_handle_array(nd_vals)
    elif isinstance(data, (list, tuple)):
        # Tuples are accepted as well as lists; both are plain sequences.
        if any(not isinstance(v, NDArray) for v in data):
            raise TypeError('save only accept dict str->NDArray or list of NDArray')
        if any(isinstance(v, np_ndarray) for v in data):
            raise TypeError('cannot save mxnet.numpy.ndarray using mxnet.ndarray.save;'
                            ' use mxnet.numpy.save instead.')
        # No names: the file is loaded back as a list.
        keys = None
        handles = c_handle_array(data)
    else:
        raise ValueError("data needs to either be a NDArray, dict of str, NDArray pairs "
                         "or a list of NDarrays.")

    check_call(_LIB.MXNDArrayLegacySave(c_str(fname),
                                        mx_uint(len(handles)),
                                        handles,
                                        keys))