mxnet.np.dsplit

dsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays along the 3rd axis (depth). Please refer to the split documentation. dsplit is equivalent to split with axis=2; the array is always split along the third axis, provided the array has three or more dimensions.
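As a quick check of the stated equivalence, the sketch below runs the same calls against plain NumPy, imported as `onp`. This assumes `mxnet.np` behaves the same way for these inputs; replacing the import with `from mxnet import np` would exercise MXNet directly. It compares `dsplit(x, 2)` with `split(x, 2, axis=2)` and shows that NumPy rejects a 2-D input.

```python
import numpy as onp  # plain NumPy, used here as a reference implementation

x = onp.arange(16.0).reshape(2, 2, 4)

# dsplit(x, 2) should produce the same pieces as split(x, 2, axis=2).
by_dsplit = onp.dsplit(x, 2)
by_split = onp.split(x, 2, axis=2)
same = all(onp.array_equal(a, b) for a, b in zip(by_dsplit, by_split))
print(same)                          # True
print([p.shape for p in by_dsplit])  # [(2, 2, 2), (2, 2, 2)]

# A 2-D array has no third axis, so NumPy's dsplit raises ValueError.
try:
    onp.dsplit(onp.arange(4.0).reshape(2, 2), 2)
    ndim_error = None
except ValueError as err:
    ndim_error = err
```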

Parameters
  • ary (ndarray) – Array to be divided into sub-arrays.

  • indices_or_sections (int or 1-D Python tuple, list or set) –

    If indices_or_sections is an integer, N, the array will be divided into N equal arrays along axis 2. If such a split is not possible, an error is raised.

    If indices_or_sections is a 1-D sequence (tuple or list) of sorted integers, the entries indicate where along axis 2 the array is split. For example, [2, 3] would result in:

    • ary[:, :, :2]

    • ary[:, :, 2:3]

    • ary[:, :, 3:]

    If an index exceeds the size of the array along axis 2, an error is raised.
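Both forms of indices_or_sections can be sketched with plain NumPy (`onp`), assuming `mxnet.np` matches it for in-range inputs. `split_bounds` is a hypothetical helper, not part of either library; it lists the `(start, stop)` slice each piece covers along axis 2, so the [2, 3] example above maps to the three slices listed.

```python
import numpy as onp  # plain NumPy, used here as a reference implementation

x = onp.arange(16.0).reshape(2, 2, 4)  # axis 2 has length 4

# Integer N: 4 divides evenly into 2 pieces of depth 2.
halves = onp.dsplit(x, 2)

# Integer N that does not divide the axis length: no equal split exists.
try:
    onp.dsplit(x, 3)
    uneven_error = None
except ValueError as err:
    uneven_error = err

def split_bounds(length, indices):
    """Hypothetical helper: (start, stop) bounds of each piece along an
    axis of size `length`, given sorted split indices."""
    edges = [0, *indices, length]
    # Consecutive edges delimit one piece: ary[:, :, start:stop].
    return list(zip(edges[:-1], edges[1:]))

# Index list [2, 3]: pieces are x[:, :, 0:2], x[:, :, 2:3], x[:, :, 3:4].
bounds = split_bounds(x.shape[2], [2, 3])
pieces = onp.dsplit(x, [2, 3])
matches = [onp.array_equal(p, x[:, :, s:e]) for p, (s, e) in zip(pieces, bounds)]
print(bounds)   # [(0, 2), (2, 3), (3, 4)]
print(matches)  # [True, True, True]
```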

Returns

sub-arrays – A list of sub-arrays.

Return type

list of ndarrays

See also

split()

Split an array into multiple sub-arrays of equal size.


This function differs from the original numpy.dsplit in the following aspects:

  • Parameter indices_or_sections does not currently support ndarray; it supports a scalar, a tuple, or a list.

  • If an index in indices_or_sections exceeds the size of the array along axis 2, an error is raised.
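To make these differences concrete, the sketch below wraps NumPy's dsplit in a hypothetical `mx_style_dsplit` that enforces both restrictions. It is an illustration only, not MXNet's implementation; MXNet's actual exception types and messages may differ.

```python
import numpy as onp  # plain NumPy, used here as a reference implementation

def mx_style_dsplit(ary, indices_or_sections):
    """Hypothetical wrapper (not MXNet's code): NumPy's dsplit plus the
    two restrictions documented for mxnet.np.dsplit."""
    if isinstance(indices_or_sections, onp.ndarray):
        # Restriction 1: an ndarray of indices is rejected.
        raise TypeError("indices_or_sections must be a scalar, tuple or list")
    if ary.ndim >= 3 and not isinstance(indices_or_sections, (int, onp.integer)):
        depth = ary.shape[2]
        # Restriction 2: an index past the end of axis 2 is an error.
        if any(i > depth for i in indices_or_sections):
            raise ValueError(f"an index exceeds the axis-2 length {depth}")
    return onp.dsplit(ary, indices_or_sections)

x = onp.arange(16.0).reshape(2, 2, 4)

# Plain NumPy clamps the out-of-range index and returns an empty last piece.
numpy_pieces = onp.dsplit(x, [3, 6])
print([p.shape for p in numpy_pieces])  # [(2, 2, 3), (2, 2, 1), (2, 2, 0)]

# The wrapper raises instead.
try:
    mx_style_dsplit(x, [3, 6])
    mx_error = None
except ValueError as err:
    mx_error = err

# The wrapper also rejects an ndarray of indices.
try:
    mx_style_dsplit(x, onp.array([1, 3]))
    type_error = None
except TypeError as err:
    type_error = err
```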

Examples

>>> x = np.arange(16.0).reshape(2, 2, 4)
>>> x
array([[[ 0.,   1.,   2.,   3.],
        [ 4.,   5.,   6.,   7.]],
       [[ 8.,   9.,  10.,  11.],
        [12.,  13.,  14.,  15.]]])
>>> np.dsplit(x, 2)
[array([[[ 0.,  1.],
        [ 4.,  5.]],
       [[ 8.,  9.],
        [12., 13.]]]), array([[[ 2.,  3.],
        [ 6.,  7.]],
       [[10., 11.],
        [14., 15.]]])]
>>> np.dsplit(x, [3])
[array([[[ 0.,  1.,  2.],
        [ 4.,  5.,  6.]],
       [[ 8.,  9., 10.],
        [12., 13., 14.]]]),
 array([[[ 3.],
        [ 7.]],
       [[11.],
        [15.]]])]