mxnet.npx.pick

pick(data, index, axis=-1, mode='clip', keepdims=False)

Picks elements from an input array according to the input indices along the given axis.

Given an input array of shape (d0, d1) and indices of shape (i0,), the result will be an output array of shape (i0,) with:

output[i] = input[i, indices[i]]

By default, if any index mentioned is too large, it is replaced by the index that addresses the last element along an axis (the clip mode).

This function supports n-dimensional input and (n-1)-dimensional indices arrays.
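The indexing rule above can be sketched in plain NumPy. This is only an illustrative emulation, not MXNet's implementation: `pick_reference` is a hypothetical helper name, it assumes an integer `axis` (the `axis=None` case is not handled), and it assumes that negative indices in clip mode are clipped to 0.

```python
import numpy as np


def pick_reference(data, index, axis=-1, mode="clip", keepdims=False):
    """Hypothetical NumPy emulation of the pick semantics described above."""
    data = np.asarray(data)
    # Index arrays may hold floats; convert to integers before indexing.
    index = np.asarray(index).astype(np.int64)
    size = data.shape[axis]

    if mode == "clip":
        # Out-of-range indices are moved to the nearest valid index (assumption
        # for negative values: they become 0).
        index = np.clip(index, 0, size - 1)
    elif mode == "wrap":
        # Out-of-range indices wrap around modulo the axis length.
        index = np.mod(index, size)
    else:
        raise ValueError("mode must be 'clip' or 'wrap'")

    # An (n-1)-dimensional index gets a size-one dimension at `axis`
    # so it lines up with the n-dimensional input.
    if index.ndim == data.ndim - 1:
        index = np.expand_dims(index, axis)

    picked = np.take_along_axis(data, index, axis=axis)
    return picked if keepdims else np.squeeze(picked, axis=axis)


x = np.array([[1., 2.], [3., 4.], [5., 6.]])
print(pick_reference(x, [0, 1, 0], axis=1))  # output[i] = x[i, index[i]]
```

With `axis=1` and a 2-D input, each output element is `x[i, index[i]]`, which matches the formula given above.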

Parameters
  • data (NDArray) – The input array

  • index (NDArray) – The index array

  • axis (int or None, optional, default='-1') – The axis along which to pick the elements. Negative values index from right to left. If None, the elements are picked from the flattened input using the indices.

  • keepdims (boolean, optional, default=0) – If true, the axis along which the elements are picked is kept in the result as a dimension of size one.

  • mode ({'clip', 'wrap'}, optional, default='clip') – Specifies how out-of-bound indices behave. “clip” means clip to the valid range, so any index that is too large is replaced by the index that addresses the last element along the axis. “wrap” means to wrap around.

  • out (NDArray, optional) – The output NDArray to hold the result.
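To make the two `mode` values concrete, the following sketch shows how a set of raw indices would map onto an axis of length 2 under each mode. It uses plain NumPy (`np.clip` and `np.mod`) as a stand-in, not MXNet itself, and it assumes that "clip to the range" sends negative indices to 0. The names `axis_size`, `raw`, `clipped`, and `wrapped` are illustrative only.

```python
import numpy as np

axis_size = 2                      # length of the axis being indexed
raw = np.array([2, -1, -2, 0, 1])  # some indices fall outside [0, axis_size - 1]

# 'clip': move each out-of-range index to the nearest valid index.
clipped = np.clip(raw, 0, axis_size - 1)

# 'wrap': take each index modulo the axis length.
wrapped = np.mod(raw, axis_size)

print(clipped)  # [1 0 0 0 1]
print(wrapped)  # [0 1 0 0 1]
```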

Returns

out – The output of this function.

Return type

NDArray or list of NDArrays

Example

>>> from mxnet import np, npx
>>> x = np.array([[1., 2.],[3., 4.],[5., 6.]])

picks elements with specified indices along axis 0

>>> npx.pick(x, np.array([0, 1]), 0)
array([1., 4.])

picks elements with specified indices along axis 1

>>> npx.pick(x, np.array([0, 1, 0]), 1)
array([1., 4., 5.])

picks elements with specified indices along axis 1 using ‘wrap’ mode to handle indices that would normally be out of bounds

>>> npx.pick(x, np.array([2, -1, -2]), 1, mode='wrap')
array([1., 4., 5.])

picks elements with specified indices along axis 1, keeping the picked axis as a dimension of size one (keepdims=True)

>>> npx.pick(x, np.array([[1.], [0.], [2.]]), 1, keepdims=True)
array([[2.],
       [3.],
       [6.]])