argmax(a, axis=None, out=None, keepdims=False)

Returns the indices of the maximum values along an axis.

  • a (ndarray) – Input array. Only ndarrays of dtype float16, float32, and float64 are supported.

  • axis (int, optional) – By default, the index is into the flattened array; otherwise, it is along the specified axis.

  • out (ndarray or None, optional) – If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.

  • keepdims (bool) – If True, the reduced axes (dimensions) must be included in the result as singleton dimensions, so that the result remains compatible (broadcastable) with the input array. Otherwise, if False, the reduced axes (dimensions) must not be included in the result. Default: False.


Returns:

  • index_array (ndarray of indices whose dtype is the same as the input ndarray) – Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.
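
The shape rules for axis, keepdims, and the returned index_array can be summarized in a few lines. The following is a minimal pure-Python sketch; argmax_result_shape is a hypothetical helper, not part of MXNet, and the all-ones shape for axis=None with keepdims=True is an assumption based on the keepdims description above.

```python
def argmax_result_shape(shape, axis=None, keepdims=False):
    """Hypothetical helper: shape of argmax's result for an input of `shape`.

    Illustrates the documented rules; it is not MXNet code.
    """
    shape = tuple(shape)
    if axis is None:
        # Index into the flattened array: a 0-d result, or (assumed)
        # all singleton dimensions when keepdims=True.
        return tuple(1 for _ in shape) if keepdims else ()
    ndim = len(shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for array of dimension {ndim}")
    axis %= ndim  # normalize a negative axis such as -1
    if keepdims:
        # The reduced axis is kept as a singleton dimension.
        return shape[:axis] + (1,) + shape[axis + 1:]
    # The reduced axis is removed.
    return shape[:axis] + shape[axis + 1:]


print(argmax_result_shape((2, 3), axis=1))                 # (2,)
print(argmax_result_shape((2, 3), axis=1, keepdims=True))  # (2, 1)
```

For the (2, 3) array used in the examples, reducing along axis=0 leaves one index per column and axis=1 leaves one index per row.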

Note:

    The keepdims parameter is part of a request in the data-api-standard and is not a parameter in official NumPy.

    In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.

    This function differs from the original numpy.argmax in the following aspects:

    • Input type does not support Python native iterables (list, tuple, …).

    • out param: cannot perform automatic broadcasting. The shape of the out ndarray must be the same as the expected output shape.

    • out param: cannot perform automatic type casting. The dtype of the out ndarray must be the same as the expected output dtype.

    • out param is not supported when the input is a scalar.
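
The first-occurrence rule for ties can be illustrated with a small reference implementation. This is a minimal pure-Python sketch for a 2-D nested list, not MXNet's implementation; argmax_2d is a hypothetical name, and the key point is that only a strictly greater value replaces the current best index.

```python
def argmax_2d(rows, axis=None):
    """Hypothetical reference argmax over a 2-D list of lists (not MXNet code)."""

    def first_argmax(values):
        if len(values) == 0:
            raise ValueError("argmax of an empty sequence")
        best = 0
        for i in range(1, len(values)):
            if values[i] > values[best]:  # strict '>' keeps the first maximum on ties
                best = i
        return best

    if axis is None:
        # Row-major flattening, matching the default "flattened array" behavior.
        return first_argmax([x for row in rows for x in row])
    if axis == 0:
        # One index per column: the row holding that column's maximum.
        return [first_argmax(list(col)) for col in zip(*rows)]
    if axis == 1:
        # One index per row: the column holding that row's maximum.
        return [first_argmax(row) for row in rows]
    raise ValueError("axis must be None, 0, or 1 for a 2-D input")


a = [[10, 11, 12], [13, 14, 15]]
print(argmax_2d(a), argmax_2d(a, axis=0), argmax_2d(a, axis=1))  # 5 [1, 1, 1] [2, 2]
print(argmax_2d([[0, 5, 2, 3, 4, 5]]))  # 1: the first of the two 5s wins
```

Note that this sketch returns plain Python ints, whereas the documented function returns an ndarray whose dtype matches the input.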


>>> a = np.arange(6).reshape(2,3) + 10
>>> a
array([[10., 11., 12.],
       [13., 14., 15.]])
>>> np.argmax(a)
array(5.)
>>> np.argmax(a, axis=0)
array([1., 1., 1.])
>>> np.argmax(a, axis=1)
array([2., 2.])
>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0., 5., 2., 3., 4., 5.])
>>> np.argmax(b)  # Only the first occurrence is returned.
array(1.)
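
With axis=None the result is an index into the flattened (row-major) array, so for the 2 x 3 array a the index 5 refers to row 1, column 2 (the value 15). The following minimal pure-Python sketch converts such a flat index back to per-axis coordinates; unravel_flat_index is a hypothetical helper that expects a plain Python int, so a float index such as the one shown above would need converting first. NumPy offers numpy.unravel_index for the same purpose.

```python
def unravel_flat_index(flat_index, shape):
    """Hypothetical helper: map a row-major flat index to per-axis coordinates."""
    size = 1
    for dim in shape:
        size *= dim
    if not 0 <= flat_index < size:
        raise IndexError(f"index {flat_index} is out of bounds for size {size}")
    coords = []
    # Peel off the last (fastest-varying) axis first.
    for dim in reversed(shape):
        flat_index, rem = divmod(flat_index, dim)
        coords.append(rem)
    return tuple(reversed(coords))


print(unravel_flat_index(5, (2, 3)))  # (1, 2)
```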

Specify out ndarray:

>>> a = np.arange(6).reshape(2,3) + 10
>>> b = np.zeros((2,))
>>> np.argmax(a, axis=1, out=b)
array([2., 2.])
>>> b
array([2., 2.])