NDArray API

Arithmetic Operations

In the following examples, y can be either a Real value or another NDArray.

API  Example  Description
+    x .+ y   Elementwise summation
-    x .- y   Elementwise subtraction
*    x .* y   Elementwise multiplication
/    x ./ y   Elementwise division
^    x .^ y   Elementwise power
%    x .% y   Elementwise modulo

Trigonometric Functions

API   Example   Description
sin   sin.(x)   Elementwise sine
cos   cos.(x)   Elementwise cosine
tan   tan.(x)   Elementwise tangent
asin  asin.(x)  Elementwise inverse sine
acos  acos.(x)  Elementwise inverse cosine
atan  atan.(x)  Elementwise inverse tangent

Hyperbolic Functions

API    Example    Description
sinh   sinh.(x)   Elementwise hyperbolic sine
cosh   cosh.(x)   Elementwise hyperbolic cosine
tanh   tanh.(x)   Elementwise hyperbolic tangent
asinh  asinh.(x)  Elementwise inverse hyperbolic sine
acosh  acosh.(x)  Elementwise inverse hyperbolic cosine
atanh  atanh.(x)  Elementwise inverse hyperbolic tangent

Activation Functions

API          Example          Description
σ            σ.(x)            Sigmoid function
sigmoid      sigmoid.(x)      Sigmoid function
relu         relu.(x)         ReLU function
softmax      softmax.(x)      Softmax function
log_softmax  log_softmax.(x)  Softmax followed by log

Reference

# MXNet.mx.log_softmax (Function)

log_softmax.(x::NDArray, [dim = ndims(x)])

Computes the log softmax of the input. This is equivalent to computing softmax followed by log.

julia> x
2×3 mx.NDArray{Float64,2} @ CPU0:
 1.0  2.0  0.1
 0.1  2.0  1.0

julia> mx.log_softmax.(x)
2×3 mx.NDArray{Float64,2} @ CPU0:
 -1.41703  -0.41703  -2.31703
 -2.31703  -0.41703  -1.41703
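
The numbers above can be cross-checked with a minimal plain-Julia sketch of log-softmax along the last dimension. It uses Base arrays only, not MXNet.jl. The helper name `log_softmax_ref` is hypothetical. Subtracting the maximum before exponentiating is a standard stability trick and is not meant to describe MXNet's kernel.

```julia
# Plain-Julia reference for log-softmax (a sketch, not part of MXNet.jl).
# `log_softmax_ref` is a hypothetical helper; `dims` mirrors the `dim`
# argument above and defaults to the last dimension.
function log_softmax_ref(x::AbstractArray; dims::Integer = ndims(x))
    m = maximum(x; dims = dims)                        # shift by the max for stability
    lse = m .+ log.(sum(exp.(x .- m); dims = dims))    # log-sum-exp along `dims`
    return x .- lse                                    # log(softmax(x))
end

x = [1.0 2.0 0.1;
     0.1 2.0 1.0]
y = log_softmax_ref(x)   # should match the mx.log_softmax output above
```

Each row of `exp.(y)` sums to 1, as a softmax should.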

# MXNet.mx.relu (Function)

relu.(x::NDArray)

Computes the rectified linear unit, max(x, 0), element-wise.

# MXNet.mx.softmax (Function)

softmax.(x::NDArray, [dim = ndims(x)])

Applies the softmax function.

The resulting array contains elements in the range (0, 1) and the elements along the given axis sum up to 1.
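
A matching plain-Julia sketch of softmax makes the two stated properties easy to check: every entry lies in (0, 1), and each slice along `dims` sums to 1. It uses Base arrays only, and `softmax_ref` is a hypothetical helper, not an MXNet.jl function.

```julia
# Plain-Julia reference for softmax (a sketch, not part of MXNet.jl);
# `softmax_ref` is a hypothetical helper that normalizes along `dims`.
function softmax_ref(x::AbstractArray; dims::Integer = ndims(x))
    e = exp.(x .- maximum(x; dims = dims))   # subtracting the max avoids overflow
    return e ./ sum(e; dims = dims)          # each slice along `dims` sums to 1
end

x = [1.0 2.0 0.1;
     0.1 2.0 1.0]
p = softmax_ref(x)
```

The maximum is subtracted before `exp` is applied, so the sketch stays finite for large inputs such as `[1000.0 1001.0]`. A naive `exp.(x)` would overflow to `Inf` on that input.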

# MXNet.mx.σ (Function)

σ.(x::NDArray)
sigmoid.(x::NDArray)

Computes sigmoid of x element-wise.

The storage type of sigmoid output is always dense.

# Base.cat (Method)

cat(xs::NDArray...; dims)

Concatenate NDArrays that have the same element type along the dimension(s) dims. Building a diagonal matrix is not supported yet.

# MXNet.mx.@inplace (Macro)

@inplace

Julia does not support redefining the += operator (as __iadd__ does in Python). When one writes a += b, it is translated into a = a + b. The expression a + b allocates new memory for the result, and the newly allocated NDArray object is then assigned back to a, while the original contents of a are discarded. This is very inefficient when we want an in-place update.

This macro is a simple utility that implements in-place updating. Writing

  @mx.inplace a += b

translates into

  mx.add_to!(a, b)

which adds the contents of b into a in place.
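
The lowering described above is ordinary Julia semantics, so it can be observed with plain Julia Arrays. This Base-only sketch involves no NDArrays. It shows that `a += b` rebinds `a` to a newly allocated array, whereas Base's broadcasted `c .+= b` mutates the existing array. The sketch makes no claim about `.+=` on NDArrays.

```julia
# Plain Julia Arrays (not NDArrays), illustrating the lowering described above.
b = [10.0, 20.0, 30.0]

a = [1.0, 2.0, 3.0]
a_alias = a        # a second name bound to the same array object
a += b             # lowered to `a = a + b`: allocates a new array, rebinds `a`
# `a_alias` still refers to the original, untouched array

c = [1.0, 2.0, 3.0]
c_alias = c
c .+= b            # Base's broadcasted in-place update mutates the existing array
# `c_alias` sees the update because no new array was allocated
```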

# Base.Broadcast.broadcast_axes (Method)

broadcast_axis(x::NDArray, dim, size)
broadcast_axes(x::NDArray, dim, size)

Broadcasts the input array over one or more axes. The parameters dim and size can each be a scalar, a Tuple, or an Array.

broadcast_axes is just an alias.

julia> x
1×2×1 mx.NDArray{Int64,3} @ CPU0:
[:, :, 1] =
 1  2

julia> mx.broadcast_axis(x, 1, 2)
2×2×1 mx.NDArray{Int64,3} @ CPU0:
[:, :, 1] =
 1  2
 1  2

julia> mx.broadcast_axis(x, 3, 2)
1×2×2 mx.NDArray{Int64,3} @ CPU0:
[:, :, 1] =
 1  2

[:, :, 2] =
 1  2

Defined in src/operator/tensor/broadcast_reduce_op_value.cc:L92
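
For intuition, here is a plain-Julia analogue that applies Base's `repeat` to an ordinary Array rather than an NDArray. Broadcasting a size-1 axis to length n yields the same values as repeating the array n times along that axis. The sketch mirrors the two calls above but does not describe how MXNet implements the operator.

```julia
# Plain-Julia analogue (a sketch, not MXNet.jl): broadcasting a size-1 axis
# to length n gives the same values as repeating along that axis n times.
x = reshape([1, 2], 1, 2, 1)     # 1×2×1, like the `x` above

bx1 = repeat(x, 2, 1, 1)         # like mx.broadcast_axis(x, 1, 2) -> 2×2×1
bx3 = repeat(x, 1, 1, 2)         # like mx.broadcast_axis(x, 3, 2) -> 1×2×2
```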

# Base.cos (Function)

cos.(x::NDArray)

Defined in `src/operator/tensor/elemwise_unary_op_trig.cc:L90`

# Base.cosh (Function)

cosh.(x::NDArray)

Defined in `src/operator/tensor/elemwise_unary_op_trig.cc:L409`

# Base.reshape (Method)

Base.reshape(x::NDArray, dim; reverse=false)

Defined in src/operator/tensor/matrix_op.cc:L174

# Base.reshape (Method)

Base.reshape(x::NDArray, dim...; reverse=false)

Defined in src/operator/tensor/matrix_op.cc:L174

# Base.sin (Function)

sin.(x::NDArray)

Defined in `src/operator/tensor/elemwise_unary_op_trig.cc:L47`

# Base.sinh (Function)

sinh.(x::NDArray)

Defined in `src/operator/tensor/elemwise_unary_op_trig.cc:L371`

# Base.tan (Function)

tan.(x::NDArray)

Defined in `src/operator/tensor/elemwise_unary_op_trig.cc:L140`

# Base.tanh (Function)

tanh.(x::NDArray)

Defined in `src/operator/tensor/elemwise_unary_op_trig.cc:L451`

# MXNet.mx.broadcast_axis (Method)

broadcast_axis(x::NDArray, dim, size)
broadcast_axes(x::NDArray, dim, size)

Broadcasts the input array over one or more axes. The parameters dim and size can each be a scalar, a Tuple, or an Array.

broadcast_axes is just an alias.

julia> x
1×2×1 mx.NDArray{Int64,3} @ CPU0:
[:, :, 1] =
 1  2

julia> mx.broadcast_axis(x, 1, 2)
2×2×1 mx.NDArray{Int64,3} @ CPU0:
[:, :, 1] =
 1  2
 1  2

julia> mx.broadcast_axis(x, 3, 2)
1×2×2 mx.NDArray{Int64,3} @ CPU0:
[:, :, 1] =
 1  2

[:, :, 2] =
 1  2

Defined in src/operator/tensor/broadcast_reduce_op_value.cc:L92

# MXNet.mx.broadcast_to (Method)

broadcast_to(x::NDArray, dims)
broadcast_to(x::NDArray, dims...)

Broadcasts the input array to a new shape.

If broadcasting does not work out of the box, you can expand the NDArray explicitly first.

julia> x = mx.ones(2, 3, 4);

julia> y = mx.ones(1, 1, 4);

julia> x .+ mx.broadcast_to(y, 2, 3, 4)
2×3×4 mx.NDArray{Float32,3} @ CPU0:
[:, :, 1] =
 2.0  2.0  2.0
 2.0  2.0  2.0

[:, :, 2] =
 2.0  2.0  2.0
 2.0  2.0  2.0

[:, :, 3] =
 2.0  2.0  2.0
 2.0  2.0  2.0

[:, :, 4] =
 2.0  2.0  2.0
 2.0  2.0  2.0

Defined in src/operator/tensor/broadcast_reduce_op_value.cc:L116
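
Plain Julia Arrays already broadcast size-1 axes implicitly, which makes it easy to sanity-check the example above. In this Base-only sketch, `repeat` stands in for `mx.broadcast_to`. It is an analogy, not MXNet code, and it shows that the explicit and implicit results agree.

```julia
# Plain Julia Arrays (a sketch, not MXNet.jl): explicitly expanding `y`
# to the full shape gives the same result as implicit broadcasting.
x = ones(Float32, 2, 3, 4)
y = ones(Float32, 1, 1, 4)

y_full = repeat(y, 2, 3, 1)     # expand 1×1×4 to 2×3×4, like broadcast_to(y, 2, 3, 4)
explicit = x .+ y_full
implicit = x .+ y               # Julia's own broadcasting expands size-1 axes
```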

# MXNet.mx.expand_dims (Method)

expand_dims(x::NDArray, dim)

Inserts a new axis of size 1 at position dim.

julia> x
4 mx.NDArray{Float64,1} @ CPU0:
 1.0
 2.0
 3.0
 4.0

julia> mx.expand_dims(x, 1)
1×4 mx.NDArray{Float64,2} @ CPU0:
 1.0  2.0  3.0  4.0

julia> mx.expand_dims(x, 2)
4×1 mx.NDArray{Float64,2} @ CPU0:
 1.0
 2.0
 3.0
 4.0

Defined in src/operator/tensor/matrix_op.cc:L394
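
Inserting a size-1 axis changes only the shape, not the order of the data. A plain-Julia analogue that uses Base `reshape` instead of MXNet.jl reproduces both shapes from the example above:

```julia
# Plain-Julia analogue (a sketch, not MXNet.jl): inserting a size-1 axis
# is a reshape that leaves the data order unchanged.
x = [1.0, 2.0, 3.0, 4.0]

row = reshape(x, 1, :)   # like mx.expand_dims(x, 1) -> 1×4
col = reshape(x, :, 1)   # like mx.expand_dims(x, 2) -> 4×1
```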

# MXNet.mx.NDArray (Type)

NDArray{T,N}

Wrapper of the NDArray type in libmxnet. This is the basic building block of tensor-based computation.

Note

C/C++ use row-major ordering for arrays, while Julia uses column-major ordering. To keep things consistent, we keep the underlying data in its original layout but use the language-native convention when talking about shapes. For example, a mini-batch of 100 MNIST images is a tensor of C/C++/Python shape (100,1,28,28), while in Julia the same piece of memory has shape (28,28,1,100).
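
The following Base-only sketch shows why the shapes come out reversed. It takes one flat buffer and addresses it in two ways: with a row-major (C) index for shape (3, 4), and with Julia's column-major `reshape` to shape (4, 3). `c_index` is a hypothetical helper written only for this illustration.

```julia
# A sketch of why shapes are reversed: one flat buffer, two index conventions.
buf = collect(1:12)                  # the shared underlying memory

# Row-major (C) view with C-shape (3, 4): element [i, j] (0-based) is at offset i*4 + j.
c_index(i, j) = buf[i * 4 + j + 1]

# Julia's column-major view of the same memory has the reversed shape (4, 3),
# and jl[j+1, i+1] addresses the same element as the C index [i, j].
jl = reshape(buf, 4, 3)
```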

# Base.:% (Method)

.%(x::NDArray, y::NDArray)
.%(x::NDArray, y::Real)
.%(x::Real, y::NDArray)

Elementwise modulo for NDArray.

# Base.:* (Method)

.*(x, y)

Elementwise multiplication for NDArray.

# Base.:* (Method)

*(A::NDArray, B::NDArray)

Matrix/tensor multiplication.

# Base.:+ (Method)

+(args...)
.+(args...)

Summation. Multiple arguments, each either a scalar or an NDArray, can be added together. Note that at least the first or second argument must be an NDArray to avoid ambiguity with the built-in summation.

# Base.:- (Method)

-(x::NDArray)
-(x, y)
.-(x, y)

Subtraction x - y of scalars or NDArrays, or, with a single argument, the negation of x.

# Base.:/ (Method)

./(x::NDArray, y::NDArray)
./(x::NDArray, y::Real)
./(x::Real, y::NDArray)
  • Elementwise division of an NDArray by a scalar or by another NDArray of the same shape.
  • Elementwise division of a scalar by an NDArray.
  • Matrix division (solving linear systems) is not implemented yet.

# Base.Math.clamp! (Method)

clamp!(x::NDArray, lo, hi)

See also clamp.

# Base.Math.clamp (Method)

clamp(x::NDArray, lo, hi)

Clamps (limits) the values in an NDArray. Given an interval, values outside the interval are clipped to the interval edges. For lo <= hi, clamping x between low lo and high hi is:

clamp(x, lo, hi) = max(min(x, hi), lo)

The storage type of the clamp output depends on the storage types of the inputs and on the lo and hi parameter values:

  • clamp(default) -> default
  • clamp(row_sparse, lo <= 0, hi >= 0) -> row_sparse
  • clamp(csr, lo <= 0, hi >= 0) -> csr
  • clamp(row_sparse, lo < 0, hi < 0) -> default
  • clamp(row_sparse, lo > 0, hi > 0) -> default
  • clamp(csr, lo < 0, hi < 0) -> csr
  • clamp(csr, lo > 0, hi > 0) -> csr

Examples

```jldoctest
julia> x = NDArray(1:9);

julia> clamp(x, 2, 8)'
1×9 mx.NDArray{Int64,2} @ CPU0:
 2  2  3  4  5  6  7  8  8

julia> clamp(x, 8, 2)'
1×9 NDArray{Int64,2} @ CPU0:
 8  8  2  2  2  2  2  2  2
```
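
The formula can be checked with Base Julia's scalar `clamp` broadcast over a plain vector, with no MXNet involved. The sketch covers only the case lo <= hi, which is the case the formula describes.

```julia
# Plain Julia (Base.clamp on scalars, broadcast over a vector); a sketch that
# checks the formula max(min(x, hi), lo) for the case lo <= hi.
x = collect(1:9)
lo, hi = 2, 8

clamped = clamp.(x, lo, hi)
by_formula = max.(min.(x, hi), lo)
```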

# MXNet.mx.div_from! (Method)

div_from!(dst::NDArray, arg::NDArrayOrReal)

Elementwise division of dst by a scalar or by an NDArray of the same shape, updating dst in place.

# MXNet.mx.mod_from! (Method)

mod_from!(x::NDArray, y::NDArray)
mod_from!(x::NDArray, y::Real)

Elementwise modulo for NDArray, updating x in place.

# MXNet.mx.mul_to! (Method)

mul_to!(dst::NDArray, arg::NDArrayOrReal)

Elementwise multiplication of dst by a scalar or by an NDArray of the same shape, updating dst in place.

# MXNet.mx.rdiv_from! (Method)

rdiv_from!(x::Real, y::NDArray)

Elementwise division of the scalar x by the NDArray y, updating y in place.

# MXNet.mx.rmod_from! (Method)

rmod_from!(y::Real, x::NDArray)

Elementwise modulo for NDArray, updating x in place.

# MXNet.mx.sub_from! (Method)

sub_from!(dst::NDArray, args::NDArrayOrReal...)

Subtracts one or more arguments from dst, updating dst in place.

# Base.convert (Method)

convert(::Type{Array{<:Real}}, x::NDArray)

Convert an NDArray into a Julia Array of specific type. Data will be copied.

# Base.copy (Function)

copy(arr :: NDArray)
copy(arr :: NDArray, ctx :: Context)
copy(arr :: Array, ctx :: Context)

Create a copy of an array. When no Context is given, create a Julia Array. Otherwise, create an NDArray on the specified context.

# Base.copy! (Method)

copy!(dst::Union{NDArray, Array}, src::Union{NDArray, Array})

Copy contents of src into dst.

# Base.deepcopy (Method)

deepcopy(arr::NDArray)

Get a deep copy of the data blob in the form of an NDArray of default storage type. This function blocks. Do not use it in performance critical code.

# Base.eltype (Method)

eltype(x::NDArray)

Get the element type of an NDArray.

# Base.fill! (Method)

fill!(arr::NDArray, x)

Fill arr with the value x in place, like Base.fill!.

# Base.getindex (Method)

getindex(arr::NDArray, idx)

Shortcut for slice. A typical use is to write

  arr[:] += 5

which translates into

  arr[:] = arr[:] + 5

which further translates into

  setindex!(arr, getindex(arr, Colon()) + 5, Colon())

Note

The behavior is quite different from indexing into Julia's Array. For example, arr[2:5] creates a copy of the sub-array for a Julia Array, while for an NDArray it is a slice that shares the memory.
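
The difference is easy to demonstrate with plain Julia Arrays. Indexing with a range copies, while Base's `view` shares memory, which is the behavior an NDArray slice has. This is a Base-only sketch, not NDArray code.

```julia
# Plain Julia Arrays (not NDArrays): `arr[2:5]` copies, `view` shares memory.
arr = collect(1.0:6.0)

sub_copy = arr[2:5]        # a new array; editing it leaves `arr` untouched
sub_copy[1] = -1.0

sub_view = view(arr, 2:5)  # shares memory with `arr`, like an NDArray slice
sub_view[1] = 100.0        # writes through to arr[2]
```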

# Base.getindex (Method)

Shortcut for slice. Note that Julia's built-in index slicing creates a copy of the sub-array, while here we simply call slice, which shares the underlying memory.

# Base.hcat (Method)

hcat(x::NDArray...)

# Base.length (Method)

length(x::NDArray)

Get the number of elements in an NDArray.

# Base.ndims (Method)

ndims(x::NDArray)

Get the number of dimensions of an NDArray. It is equivalent to length(size(x)).

# Base.setindex! (Method)

setindex!(arr::NDArray, val, idx)

Assign values to an NDArray. The following scenarios are supported:

  • single value assignment via linear indexing: arr[42] = 24
  • arr[:] = val: whole array assignment, where val can be a scalar or an array (Julia Array or NDArray) of the same shape.
  • arr[start:stop] = val: assignment to a slice, where val can be a scalar or an array of the same shape as the slice. See also slice.

# Base.similar (Method)

similar(x::NDArray; writable, ctx)

Create an NDArray with a shape, data type, and context similar to the given one. Note that the returned NDArray is uninitialized.

# Base.size (Method)

size(x::NDArray)
size(x::NDArray, dims)

Get the shape of an NDArray. The shape is in Julia's column-major convention. See also the notes on shapes in the NDArray entry.

# Base.vcat (Method)

vcat(x::NDArray...)

# MXNet.mx.add_to! (Method)

add_to!(dst::NDArray, args::NDArrayOrReal...)

Adds one or more arguments into dst, updating dst in place.

# MXNet.mx.fill (Method)

fill(x, dims, ctx = current_context())
fill(x, dims...)

Create an NDArray filled with the value x, like Base.fill.

# MXNet.mx.ones (Method)

ones([DType], dims, ctx::Context = current_context())
ones([DType], dims...)
ones(x::NDArray)

Create an NDArray with the given shape and type, initialized with ones.

# MXNet.mx.slice (Method)

slice(arr :: NDArray, start:stop)

Create a view into a sub-slice of an NDArray. Note that only slicing along the slowest-changing dimension is supported. From Julia's column-major perspective, this is the last dimension. For example, given an NDArray of shape (2,3,4), slice(array, 2:3) creates an NDArray of shape (2,3,2) that shares data with the original array. This operation is used in data parallelization to split a mini-batch into sub-batches for different devices.
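
A plain-Julia analogue that uses Base `view` rather than `mx.slice` shows why only the last dimension is supported. In column-major order, a range over the last axis selects one contiguous block of memory, which can be shared without copying. The sub-batch split at the end mimics the data-parallel use case, with hypothetical batch sizes.

```julia
# Plain-Julia analogue (a sketch, not MXNet.jl): slicing the last
# (slowest-changing) dimension of a column-major array selects one
# contiguous block of memory, so it can be shared without copying.
a = reshape(collect(1.0:24.0), 2, 3, 4)

s = view(a, :, :, 2:3)          # like slice(a, 2:3): shape (2, 3, 2)

# Contiguity: the selected block is exactly elements 7..18 of the flat buffer.
flat_block = vec(a)[7:18]

# Splitting a mini-batch of 4 along the last axis into 2 sub-batches of 2.
sub_batches = [view(a, :, :, r) for r in (1:2, 3:4)]
```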

# MXNet.mx.zeros (Method)

zeros([DType], dims, ctx::Context = current_context())
zeros([DType], dims...)
zeros(x::NDArray)

Create an NDArray of the given shape and type, filled with zeros.

# MXNet.mx.@nd_as_jl (Macro)

Manipulating as Julia Arrays

@nd_as_jl(captures..., statement)

A convenience macro that lets you operate on NDArrays as if they were Julia Arrays. For example,

  x = mx.zeros(3,4)
  y = mx.ones(3,4)
  z = mx.zeros((3,4), mx.gpu())

  @mx.nd_as_jl ro=(x,y) rw=z begin
    # now x, y, z are just ordinary Julia Arrays
    z[:,1] = y[:,2]
    z[:,2] .= 5   # broadcast assignment: a scalar into a slice of a Julia Array
  end

Under the hood, the macro converts all the declared captures from NDArray into Julia Arrays by using try_get_shared, and automatically commits the modifications back into the NDArrays declared as rw. This is useful for fast prototyping and for implementing non-critical computations, such as AbstractEvalMetric.

Note

  • Multiple rw and/or ro capture declarations can be made.
  • The macro does not check that ro captures are left unmodified. If the original NDArray lives in CPU memory, the corresponding Julia Array very likely shares data with the NDArray, so modifying the Julia Array also modifies the underlying NDArray.
  • More importantly, since NDArray operations are asynchronous, the macro waits until rw variables are ready for writing, but only until ro variables are ready for reading. If we write into those ro variables and the memory is shared, a race condition might occur, and the behavior is undefined.
  • When an NDArray is captured as rw, its contents are always synced back at the end.
  • The expanded macro always evaluates to nothing.
  • The statements are wrapped in a let, so newly introduced local variables are not available after the statements. If needed, declare the variables before calling the macro.

# Base.Iterators.Flatten (Method)

Flatten(data)

Flattens the input array into a 2-D array by collapsing the higher dimensions.

.. note:: Flatten is deprecated. Use flatten instead.

For an input array with shape $(d1, d2, ..., dk)$, the flatten operation reshapes the input array into an output array of shape $(d1, d2*...*dk)$.

Note that the behavior of this function is different from numpy.ndarray.flatten, which behaves similarly to mxnet.ndarray.reshape((-1,)).

Example::

  x = [[[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]],
       [[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]]

  flatten(x) = [[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
                [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.]]

Defined in src/operator/tensor/matrix_op.cc:L249

Arguments

  • data::NDArray-or-SymbolicNode: Input array.

# Base.Math.cbrt (Method)

cbrt(data)

Returns element-wise cube-root value of the input.

.. math:: cbrt(x) = \sqrt[3]{x}

Example::

cbrt([1, 8, -125]) = [1, 2, -5]

The storage type of $cbrt$ output depends upon the input storage type:

  • cbrt(default) = default
  • cbrt(row_sparse) = row_sparse
  • cbrt(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_pow.cc:L270

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
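
Base Julia's `cbrt` reproduces the example above on a plain vector. It is only an analogue of the MXNet operator. It is shown because it handles negative inputs, whereas `(-125.0)^(1/3)` throws a `DomainError` in Julia.

```julia
# Plain Julia (Base.cbrt), a sketch mirroring the example above: unlike
# x^(1/3), cbrt is defined for negative inputs.
x = [1.0, 8.0, -125.0]
y = cbrt.(x)
```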

# Base._div (Method)

_div(lhs, rhs)

_div is an alias of elemwise_div.

Divides arguments element-wise.

The storage type of $elemwise_div$ output is always dense.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

# Base._linspace (Method)

_linspace(start, stop, step, repeat, infer_range, ctx, dtype)

Return evenly spaced numbers over a specified interval. Similar to NumPy.

Arguments

  • start::double, required: Start of interval. The interval includes this value. The default start value is 0.
  • stop::double or None, optional, default=None: End of interval. The interval does not include this value, except in some cases where step is not an integer and floating point round-off affects the length of out.
  • step::double, optional, default=1: Spacing between values.
  • repeat::int, optional, default='1': The number of times each element is repeated. E.g. with repeat=3, the element a is repeated three times: a, a, a.
  • infer_range::boolean, optional, default=0: When set to True, infer the stop position from the start, step, repeat, and output tensor size.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Target data type.

# Base._maximum (Method)

_maximum(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

# Base._minimum (Method)

_minimum(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

# Base._sub (Method)

_sub(lhs, rhs)

_sub is an alias of elemwise_sub.

Subtracts arguments element-wise.

The storage type of $elemwise_sub$ output depends on the storage types of the inputs:

  • elemwise_sub(row_sparse, row_sparse) = row_sparse
  • elemwise_sub(csr, csr) = csr
  • elemwise_sub(default, csr) = default
  • elemwise_sub(csr, default) = default
  • elemwise_sub(default, rsp) = default
  • elemwise_sub(rsp, default) = default
  • otherwise, $elemwise_sub$ generates output with default storage

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

# Base.abs (Method)

abs(data)

Returns element-wise absolute value of the input.

Example::

abs([-2, 0, 3]) = [2, 0, 3]

The storage type of $abs$ output depends upon the input storage type:

  • abs(default) = default
  • abs(row_sparse) = row_sparse
  • abs(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L720

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# Base.ceil (Method)

ceil(data)

Returns element-wise ceiling of the input.

The ceil of the scalar x is the smallest integer i, such that i >= x.

Example::

ceil([-2.1, -1.9, 1.5, 1.9, 2.1]) = [-2., -1., 2., 2., 3.]

The storage type of $ceil$ output depends upon the input storage type:

  • ceil(default) = default
  • ceil(rowsparse) = rowsparse
  • ceil(csr) = csr

Defined in src/operator/tensor/elemwiseunaryop_basic.cc:L817

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# Base.cumsum (Method)

cumsum(a, axis, dtype)

cumsum is an alias of _np_cumsum.

Return the cumulative sum of the elements along a given axis.

Defined in src/operator/numpy/np_cumsum.cc:L70

Arguments

  • a::NDArray-or-SymbolicNode: Input ndarray
  • axis::int or None, optional, default='None': Axis along which the cumulative sum is computed. The default (None) is to compute the cumsum over the flattened array.
  • dtype::{None, 'float16', 'float32', 'float64', 'int32', 'int64', 'int8'}, optional, default='None': Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used.

# Base.exp (Method)

exp(data)

Returns element-wise exponential value of the input.

.. math:: exp(x) = e^x \approx 2.718^x

Example::

exp([0, 1, 2]) = [1., 2.71828175, 7.38905621]

The storage type of $exp$ output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_logexp.cc:L64

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# Base.expm1 (Method)

expm1(data)

Returns $exp(x) - 1$ computed element-wise on the input.

This function provides greater precision than $exp(x) - 1$ for small values of $x$.

The storage type of $expm1$ output depends upon the input storage type:

  • expm1(default) = default
  • expm1(row_sparse) = row_sparse
  • expm1(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_logexp.cc:L244

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
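
The precision claim can be illustrated with Base Julia's `expm1` on a scalar. This is an analogue, not the MXNet kernel. For x = 1e-10, `exp(x)` rounds to a double just above 1.0, so subtracting 1 keeps only a few correct digits, while `expm1` stays accurate. The reference value uses the first two Taylor terms, which are exact to double precision at this x.

```julia
# Plain Julia (Base.expm1), a sketch of the precision claim above:
# for small x, computing exp(x) - 1 directly loses most significant digits.
x = 1e-10
naive = exp(x) - 1     # cancellation: exp(x) rounds to a double near 1.0
accurate = expm1(x)    # evaluates exp(x) - 1 without forming exp(x) first
exact = x + x^2 / 2    # leading Taylor terms; the next term is negligible here

rel_err(v) = abs(v - exact) / exact
```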

# Base.floor (Method)

floor(data)

Returns element-wise floor of the input.

The floor of the scalar x is the largest integer i, such that i <= x.

Example::

floor([-2.1, -1.9, 1.5, 1.9, 2.1]) = [-3., -2., 1., 1., 2.]

The storage type of $floor$ output depends upon the input storage type:

  • floor(default) = default
  • floor(rowsparse) = rowsparse
  • floor(csr) = csr

Defined in src/operator/tensor/elemwiseunaryop_basic.cc:L836

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# Base.identity (Method)

identity(data)

identity is an alias of _copy.

Returns a copy of the input.

From: src/operator/tensor/elemwise_unary_op_basic.cc:244

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# Base.log (Method)

log(data)

Returns element-wise natural logarithmic value of the input.

The natural logarithm is the logarithm in base e, so that $log(exp(x)) = x$.

The storage type of $log$ output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_logexp.cc:L77

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# Base.log10 (Method)

log10(data)

Returns element-wise base-10 logarithmic value of the input.

$10**log10(x) = x$

The storage type of $log10$ output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_logexp.cc:L94

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# Base.log1p (Method)

log1p(data)

Returns element-wise $log(1 + x)$ value of the input.

This function is more accurate than computing $log(1 + x)$ directly when $x$ is so small that $1+x\approx 1$.

The storage type of $log1p$ output depends upon the input storage type:

  • log1p(default) = default
  • log1p(row_sparse) = row_sparse
  • log1p(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_logexp.cc:L199

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# Base.log2 (Method)

log2(data)

Returns element-wise base-2 logarithmic value of the input.

$2**log2(x) = x$

The storage type of $log2$ output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_logexp.cc:L106

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# Base.repeat (Method)

repeat(data, repeats, axis)

Repeats elements of an array.

By default, $repeat$ flattens the input array into 1-D and then repeats the elements::

  x = [[ 1,  2],
       [ 3,  4]]

  repeat(x, repeats=2) = [ 1.,  1.,  2.,  2.,  3.,  3.,  4.,  4.]

The parameter $axis$ specifies the axis along which to perform repeat::

  repeat(x, repeats=2, axis=1) = [[ 1.,  1.,  2.,  2.],
                                  [ 3.,  3.,  4.,  4.]]

  repeat(x, repeats=2, axis=0) = [[ 1.,  2.],
                                  [ 1.,  2.],
                                  [ 3.,  4.],
                                  [ 3.,  4.]]

  repeat(x, repeats=2, axis=-1) = [[ 1.,  1.,  2.,  2.],
                                   [ 3.,  3.,  4.,  4.]]

Defined in src/operator/tensor/matrix_op.cc:L743
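
Base Julia's `repeat` with the `inner` keyword reproduces several of the examples above on a plain matrix. The operator's examples use row-major (NumPy-style) layout, so this sketch first flattens with `permutedims`. The mapping of `axis` onto Julia's `inner` tuple is only an illustration and is not how MXNet.jl translates its arguments.

```julia
# Plain Julia (Base.repeat with `inner`), a sketch of the examples above.
x = [1 2;
     3 4]

# Default behavior: flatten in row-major order, then repeat each element.
flat_rowmajor = vec(permutedims(x))          # [1, 2, 3, 4]
repeated = repeat(flat_rowmajor; inner = 2)

cols = repeat(x; inner = (1, 2))   # like repeat(x, repeats=2, axis=1) above
rows = repeat(x; inner = (2, 1))   # like repeat(x, repeats=2, axis=0) above
```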

Arguments

  • data::NDArray-or-SymbolicNode: Input data array
  • repeats::int, required: The number of repetitions for each element.
  • axis::int or None, optional, default='None': The axis along which to repeat values. Negative numbers are interpreted as counting backward from the last axis. By default, use the flattened input array, and return a flat output array.

# Base.reverse (Method)

reverse(data, axis)

Reverses the order of elements along the given axis while preserving the array shape.

Note: reverse and flip are equivalent. We use reverse in the following examples.

Examples::

  x = [[ 0.,  1.,  2.,  3.,  4.],
       [ 5.,  6.,  7.,  8.,  9.]]

  reverse(x, axis=0) = [[ 5.,  6.,  7.,  8.,  9.],
                        [ 0.,  1.,  2.,  3.,  4.]]

  reverse(x, axis=1) = [[ 4.,  3.,  2.,  1.,  0.],
                        [ 9.,  8.,  7.,  6.,  5.]]

Defined in src/operator/tensor/matrix_op.cc:L831

Arguments

  • data::NDArray-or-SymbolicNode: Input data array
  • axis::Shape(tuple), required: The axis along which to reverse elements.

# Base.round (Method)

round(data)

Returns element-wise rounded value to the nearest integer of the input.

Example::

round([-1.5, 1.5, -1.9, 1.9, 2.1]) = [-2., 2., -2., 2., 2.]

The storage type of $round$ output depends upon the input storage type:

  • round(default) = default
  • round(row_sparse) = row_sparse
  • round(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L777

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# Base.sign (Method)

sign(data)

Returns element-wise sign of the input.

Example::

sign([-2, 0, 3]) = [-1, 0, 1]

The storage type of $sign$ output depends upon the input storage type:

  • sign(default) = default
  • sign(row_sparse) = row_sparse
  • sign(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L758

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# Base.sort (Method)

sort(data, axis, is_ascend)

Returns a sorted copy of an input array along the given axis.

Examples::

  x = [[ 1, 4],
       [ 3, 1]]

  // sorts along the last axis
  sort(x) = [[ 1.,  4.],
             [ 1.,  3.]]

  // flattens and then sorts
  sort(x, axis=None) = [ 1.,  1.,  3.,  4.]

  // sorts along the first axis
  sort(x, axis=0) = [[ 1.,  1.],
                     [ 3.,  4.]]

  // in descending order
  sort(x, is_ascend=0) = [[ 4.,  1.],
                          [ 3.,  1.]]

Defined in src/operator/tensor/ordering_op.cc:L132

Arguments

  • data::NDArray-or-SymbolicNode: The input array
  • axis::int or None, optional, default='-1': Axis along which to sort the input tensor. If None, the flattened array is used. Default is -1.
  • is_ascend::boolean, optional, default=1: Whether to sort in ascending or descending order.

# Base.split (Method)

split(data, num_outputs, axis, squeeze_axis)

split is an alias of SliceChannel.

Splits an array along a particular axis into multiple sub-arrays.

.. note:: $SliceChannel$ is deprecated. Use $split$ instead.

Note that num_outputs should evenly divide the length of the axis along which to split the array.

Example::

  x = [[[ 1.]
        [ 2.]]
       [[ 3.]
        [ 4.]]
       [[ 5.]
        [ 6.]]]
  x.shape = (3, 2, 1)

  y = split(x, axis=1, num_outputs=2) // a list of 2 arrays with shape (3, 1, 1)
  y = [[[ 1.]]
       [[ 3.]]
       [[ 5.]]]

      [[[ 2.]]
       [[ 4.]]
       [[ 6.]]]

  y[0].shape = (3, 1, 1)

  z = split(x, axis=0, num_outputs=3) // a list of 3 arrays with shape (1, 2, 1)
  z = [[[ 1.]
        [ 2.]]]

      [[[ 3.]
        [ 4.]]]

      [[[ 5.]
        [ 6.]]]

  z[0].shape = (1, 2, 1)

squeeze_axis=1 removes the axis with length 1 from the shapes of the output arrays. Note that setting squeeze_axis to $1$ removes the axis with length 1 only along the axis on which the array is split. Also, squeeze_axis can be set to true only if $input.shape[axis] == num_outputs$.

Example::

  z = split(x, axis=0, num_outputs=3, squeeze_axis=1) // a list of 3 arrays with shape (2, 1)
  z = [[ 1.]
       [ 2.]]

      [[ 3.]
       [ 4.]]

      [[ 5.]
       [ 6.]]

  z[0].shape = (2, 1)

Defined in src/operator/slice_channel.cc:L106

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • num_outputs::int, required: Number of splits. Note that this should evenly divide the length of the axis.
  • axis::int, optional, default='1': Axis along which to split.
  • squeeze_axis::boolean, optional, default=0: If true, removes the axis with length 1 from the shapes of the output arrays. Note that setting squeeze_axis to $true$ removes the axis with length 1 only along the axis on which the array is split. Also, squeeze_axis can be set to $true$ only if $input.shape[axis] == num_outputs$.

# Base.sqrt (Method)

sqrt(data)

Returns element-wise square-root value of the input.

.. math:: \textrm{sqrt}(x) = \sqrt{x}

Example::

sqrt([4, 9, 16]) = [2, 3, 4]

The storage type of $sqrt$ output depends upon the input storage type:

  • sqrt(default) = default
  • sqrt(row_sparse) = row_sparse
  • sqrt(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_pow.cc:L170

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# Base.trunc (Method)

trunc(data)

Return the element-wise truncated value of the input.

The truncated value of the scalar x is the nearest integer i which is closer to zero than x is. In short, the fractional part of the signed number x is discarded.

Example::

trunc([-2.1, -1.9, 1.5, 1.9, 2.1]) = [-2., -1., 1., 1., 2.]

The storage type of $trunc$ output depends upon the input storage type:

  • trunc(default) = default
  • trunc(row_sparse) = row_sparse
  • trunc(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L856

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
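
The rounding directions of floor, ceil, and trunc can be compared side by side with Base Julia on the example input used above. This is a plain-Julia analogue, not MXNet code.

```julia
# Plain Julia (Base), a sketch comparing three rounding directions on
# the example input used in the floor, ceil, and trunc entries.
v = [-2.1, -1.9, 1.5, 1.9, 2.1]

f = floor.(v)   # toward -Inf
c = ceil.(v)    # toward +Inf
t = trunc.(v)   # toward zero: drops the fractional part
```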

# LinearAlgebra.diag (Method)

diag(data, k, axis1, axis2)

Extracts a diagonal or constructs a diagonal array.

$diag$'s behavior depends on the input array dimensions:

  • 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero.
  • N-D arrays: extracts the diagonals of the sub-arrays with axes specified by $axis1$ and $axis2$. The output shape would be decided by removing the axes numbered $axis1$ and $axis2$ from the input shape and appending to the result a new axis with the size of the diagonals in question.

    For example, when the input shape is (2, 3, 4, 5), $axis1$ and $axis2$ are 0 and 2 respectively and $k$ is 0, the resulting shape would be (3, 5, 2).

Examples::

x = [[1, 2, 3], [4, 5, 6]]

diag(x) = [1, 5]

diag(x, k=1) = [2, 6]

diag(x, k=-1) = [4]

x = [1, 2, 3]

diag(x) = [[1, 0, 0], [0, 2, 0], [0, 0, 3]]

diag(x, k=1) = [[0, 1, 0], [0, 0, 2], [0, 0, 0]]

diag(x, k=-1) = [[0, 0, 0], [1, 0, 0], [0, 2, 0]]

x = [[[1, 2],
      [3, 4]],

     [[5, 6],
      [7, 8]]]

diag(x) = [[1, 7], [2, 8]]

diag(x, k=1) = [[3], [4]]

diag(x, axis1=-2, axis2=-1) = [[1, 4], [5, 8]]

Defined in src/operator/tensor/diag_op.cc:L86

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • k::int, optional, default='0': Diagonal in question. The default is 0. Use k>0 for diagonals above the main diagonal, and k<0 for diagonals below the main diagonal. If the input has shape (S0, S1), k must be between -S0 and S1.
  • axis1::int, optional, default='0': The first axis of the sub-arrays of interest. Ignored when the input is a 1-D array.
  • axis2::int, optional, default='1': The second axis of the sub-arrays of interest. Ignored when the input is a 1-D array.
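
For the 2-D extraction cases, Base's `LinearAlgebra.diag(A, k)` uses the same sign convention for k on a plain matrix. This sketch reproduces the first three examples above. It does not cover the N-D `axis1`/`axis2` behavior or diagonal construction from a 1-D input.

```julia
# Plain Julia (LinearAlgebra.diag), a sketch of the 2-D extraction cases.
using LinearAlgebra

x = [1 2 3;
     4 5 6]
d0  = diag(x)       # main diagonal
dp1 = diag(x, 1)    # first diagonal above the main one
dm1 = diag(x, -1)   # first diagonal below the main one
```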

# LinearAlgebra.norm (Method)

norm(data, ord, axis, out_dtype, keepdims)

Computes the norm on an NDArray.

This operator computes the norm on an NDArray with the specified axis, depending on the value of the ord parameter. By default, it computes the L2 norm on the entire array. Currently only ord=2 supports sparse ndarrays.

Examples::

x = [[[1, 2], [3, 4]], [[2, 2], [5, 6]]]

norm(x, ord=2, axis=1) = [[3.1622777 4.472136 ] [5.3851647 6.3245554]]

norm(x, ord=1, axis=1) = [[4., 6.], [7., 8.]]

rsp = x.cast_storage('row_sparse')

norm(rsp) = [5.47722578]

csr = x.cast_storage('csr')

norm(csr) = [5.47722578]
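The dense results for the 3-D x above can be checked with a plain-Python sketch (no MXNet; the function name norm_axis1 is hypothetical) that reduces over axis 1 with ord=1 (sum of absolute values) or ord=2 (square root of the sum of squares).

```python
import math


def norm_axis1(x, ord=2):
    """L1 or L2 norm of a 3-D nested list, reducing over axis 1 (the middle axis)."""
    out = []
    for block in x:                      # block has shape (d1, d2)
        row = []
        for j in range(len(block[0])):   # iterate over the last axis
            column = [block[i][j] for i in range(len(block))]
            if ord == 1:
                row.append(sum(abs(v) for v in column))
            elif ord == 2:
                row.append(math.sqrt(sum(v * v for v in column)))
            else:
                raise ValueError("only ord=1 and ord=2 are sketched here")
        out.append(row)
    return out


x = [[[1, 2], [3, 4]], [[2, 2], [5, 6]]]
print(norm_axis1(x, ord=1))  # [[4, 6], [7, 8]]
print(norm_axis1(x, ord=2))  # approximately [[3.1623, 4.4721], [5.3852, 6.3246]]
```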

Defined in src/operator/tensor/broadcast_reduce_norm_value.cc:L88

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • ord::int, optional, default='2': Order of the norm. Currently ord=1 and ord=2 are supported.
  • axis::Shape or None, optional, default=None: The axis or axes along which to perform the reduction. The default, axis=(), will compute over all elements into a scalar array with shape (1,). If axis is int, a reduction is performed on a particular axis. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed.
  • out_dtype::{None, 'float16', 'float32', 'float64', 'int32', 'int64', 'int8'},optional, default='None': The data type of the output.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axis is left in the result as dimension with size one.

source

# MXNet.mx.ActivationMethod.

Activation(data, act_type)

Applies an activation function element-wise to the input.

The following activation functions are supported:

  • relu: Rectified Linear Unit, $y = \max(x, 0)$
  • sigmoid: $y = \frac{1}{1 + \exp(-x)}$
  • tanh: Hyperbolic tangent, $y = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}$
  • softrelu: Soft ReLU, or SoftPlus, $y = \log(1 + \exp(x))$
  • softsign: $y = \frac{x}{1 + |x|}$
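Written out directly, the five formulas look like this in plain Python. This is an illustrative sketch, not the operator's implementation: the activation helper and its dispatch table are hypothetical, and the formulas are transcribed as-is, so math.exp can overflow for large |x| (numerically stable rewrites are left out).

```python
import math


def relu(x):
    return max(x, 0.0)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def tanh(x):
    return (math.exp(x) - math.exp(-x)) / (math.exp(x) + math.exp(-x))


def softrelu(x):  # also known as SoftPlus
    return math.log(1.0 + math.exp(x))


def softsign(x):
    return x / (1.0 + abs(x))


# Hypothetical lookup mirroring the act_type values listed above.
ACTIVATIONS = {"relu": relu, "sigmoid": sigmoid, "tanh": tanh,
               "softrelu": softrelu, "softsign": softsign}


def activation(data, act_type):
    """Apply the named activation element-wise to a flat list."""
    f = ACTIVATIONS[act_type]
    return [f(v) for v in data]


print(activation([-2.0, 0.0, 3.0], "relu"))      # [0.0, 0.0, 3.0]
print(activation([-1.0, 0.0, 1.0], "softsign"))  # [-0.5, 0.0, 0.5]
```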

Defined in src/operator/nn/activation.cc:L164

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
  • act_type::{'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required: Activation function to be applied.

source

# MXNet.mx.BatchNormMethod.

BatchNorm(data, gamma, beta, moving_mean, moving_var, eps, momentum, fix_gamma, use_global_stats, output_mean_var, axis, cudnn_off, min_calib_range, max_calib_range)

Batch normalization.

Normalizes a data batch by mean and variance, and applies a scale $gamma$ as well as offset $beta$.

Assume the input has more than one dimension and we normalize along axis 1. We first compute the mean and variance along this axis:

$$\mathrm{data\_mean}[i] = \mathrm{mean}(\mathrm{data}[:, i, :, \ldots]), \qquad \mathrm{data\_var}[i] = \mathrm{var}(\mathrm{data}[:, i, :, \ldots])$$

Then compute the normalized output, which has the same shape as input, as following:

$$\mathrm{out}[:, i, :, \ldots] = \frac{\mathrm{data}[:, i, :, \ldots] - \mathrm{data\_mean}[i]}{\sqrt{\mathrm{data\_var}[i] + \epsilon}} \cdot \mathrm{gamma}[i] + \mathrm{beta}[i]$$

Both mean and var return a scalar by treating the input as a vector.

Assume the input has size k on axis 1; then both $gamma$ and $beta$ have shape (k,). If $output_mean_var$ is set to true, the operator also outputs $data_mean$ and the inverse of $data_var$, which are needed for the backward pass. Note that the gradients of these two outputs are blocked.

Besides the inputs and the outputs, this operator accepts two auxiliary states, $moving_mean$ and $moving_var$, which are k-length vectors. They are global statistics for the whole dataset, which are updated as follows:

moving_mean = moving_mean * momentum + data_mean * (1 - momentum)

moving_var = moving_var * momentum + data_var * (1 - momentum)

If $use_global_stats$ is set to be true, then $moving_mean$ and $moving_var$ are used instead of $data_mean$ and $data_var$ to compute the output. It is often used during inference.
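The computation described so far (per-channel mean and variance, normalization, and the moving-average update) can be sketched in plain Python for a 2-D (batch, channel) input. This is a minimal sketch under stated assumptions, not MXNet's kernel: the function name is hypothetical, the variance is the biased (divide-by-n) estimate, fix_gamma and the remaining options are not modeled, and with use_global_stats the moving statistics are used for normalization and returned unchanged.

```python
import math


def batch_norm_2d(data, gamma, beta, moving_mean, moving_var,
                  eps=1e-3, momentum=0.9, use_global_stats=False):
    """Batch-normalize a (batch, channel) nested list along axis 1.

    Returns (out, new_moving_mean, new_moving_var).
    """
    n, k = len(data), len(data[0])
    if use_global_stats:
        # Inference-style: normalize with the stored statistics, leave them unchanged.
        mean, var = list(moving_mean), list(moving_var)
        new_mean, new_var = list(moving_mean), list(moving_var)
    else:
        mean = [sum(row[i] for row in data) / n for i in range(k)]
        # Biased variance (divide by n); an assumption of this sketch.
        var = [sum((row[i] - mean[i]) ** 2 for row in data) / n for i in range(k)]
        new_mean = [m * momentum + dm * (1 - momentum) for m, dm in zip(moving_mean, mean)]
        new_var = [v * momentum + dv * (1 - momentum) for v, dv in zip(moving_var, var)]
    out = [[(row[i] - mean[i]) / math.sqrt(var[i] + eps) * gamma[i] + beta[i]
            for i in range(k)]
           for row in data]
    return out, new_mean, new_var


data = [[1.0, 10.0], [3.0, 30.0]]  # batch of 2, k = 2 channels
out, mm, mv = batch_norm_2d(data, gamma=[1.0, 1.0], beta=[0.0, 0.0],
                            moving_mean=[0.0, 0.0], moving_var=[1.0, 1.0], eps=0.0)
print(out)  # [[-1.0, -1.0], [1.0, 1.0]]
print(mm)   # approximately [0.2, 2.0]
print(mv)   # approximately [1.0, 10.9]
```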

The parameter $axis$ specifies which axis of the input shape denotes the 'channel' (separately normalized groups). The default is 1. Specifying -1 sets the channel axis to be the last item in the input shape.

Both $gamma$ and $beta$ are learnable parameters. However, if $fix_gamma$ is true, $gamma$ is set to 1 and its gradient to 0.

.. note:: When $fix_gamma$ is set to True, no sparse support is provided. If $fix_gamma$ is set to False, sparse tensors fall back to the default (dense) storage.

Defined in src/operator/nn/batch_norm.cc:L608

Arguments

  • data::NDArray-or-SymbolicNode: Input data to batch normalization
  • gamma::NDArray-or-SymbolicNode: gamma array
  • beta::NDArray-or-SymbolicNode: beta array
  • moving_mean::NDArray-or-SymbolicNode: running mean of input
  • moving_var::NDArray-or-SymbolicNode: running variance of input
  • eps::double, optional, default=0.0010000000474974513: Epsilon to prevent division by zero. Must be no less than CUDNN_BN_MIN_EPSILON defined in cudnn.h when using cuDNN (usually 1e-5).
  • momentum::float, optional, default=0.899999976: Momentum for moving average
  • fix_gamma::boolean, optional, default=1: Fix gamma while training
  • use_global_stats::boolean, optional, default=0: Whether to use global moving statistics instead of local batch statistics. This turns batch-norm into a scale-and-shift operator.
  • output_mean_var::boolean, optional, default=0: Output the mean and inverse std
  • axis::int, optional, default='1': Specifies which axis of the input shape is the channel axis.
  • cudnn_off::boolean, optional, default=0: Do not select CUDNN operator, if available
  • min_calib_range::float or None, optional, default=None: The minimum scalar value in the form of float32 obtained through calibration. If present, it will be used by the quantized batch norm op to calculate the primitive scale. Note: this calib_range calibrates the batch norm output.
  • max_calib_range::float or None, optional, default=None: The maximum scalar value in the form of float32 obtained through calibration. If present, it will be used by the quantized batch norm op to calculate the primitive scale. Note: this calib_range calibrates the batch norm output.

source

# MXNet.mx.BatchNorm_v1Method.

BatchNorm_v1(data, gamma, beta, eps, momentum, fix_gamma, use_global_stats, output_mean_var)

Batch normalization.

This operator is DEPRECATED. Perform BatchNorm on the input.

Normalizes a data batch by mean and variance, and applies a scale $gamma$ as well as offset $beta$.

Assume the input has more than one dimension and we normalize along axis 1. We first compute the mean and variance along this axis:

$$\mathrm{data\_mean}[i] = \mathrm{mean}(\mathrm{data}[:, i, :, \ldots]), \qquad \mathrm{data\_var}[i] = \mathrm{var}(\mathrm{data}[:, i, :, \ldots])$$

Then compute the normalized output, which has the same shape as input, as following:

$$\mathrm{out}[:, i, :, \ldots] = \frac{\mathrm{data}[:, i, :, \ldots] - \mathrm{data\_mean}[i]}{\sqrt{\mathrm{data\_var}[i] + \epsilon}} \cdot \mathrm{gamma}[i] + \mathrm{beta}[i]$$

Both mean and var return a scalar by treating the input as a vector.

Assume the input has size k on axis 1; then both $gamma$ and $beta$ have shape (k,). If $output_mean_var$ is set to true, the operator also outputs $data_mean$ and $data_var$, which are needed for the backward pass.

Besides the inputs and the outputs, this operator accepts two auxiliary states, $moving_mean$ and $moving_var$, which are k-length vectors. They are global statistics for the whole dataset, which are updated as follows:

moving_mean = moving_mean * momentum + data_mean * (1 - momentum)

moving_var = moving_var * momentum + data_var * (1 - momentum)

If $use_global_stats$ is set to be true, then $moving_mean$ and $moving_var$ are used instead of $data_mean$ and $data_var$ to compute the output. It is often used during inference.

Both $gamma$ and $beta$ are learnable parameters. However, if $fix_gamma$ is true, $gamma$ is set to 1 and its gradient to 0.

There's no sparse support for this operator, and it will exhibit problematic behavior if used with sparse tensors.

Defined in src/operator/batch_norm_v1.cc:L94

Arguments

  • data::NDArray-or-SymbolicNode: Input data to batch normalization
  • gamma::NDArray-or-SymbolicNode: gamma array
  • beta::NDArray-or-SymbolicNode: beta array
  • eps::float, optional, default=0.00100000005: Epsilon to prevent division by zero.
  • momentum::float, optional, default=0.899999976: Momentum for moving average
  • fix_gamma::boolean, optional, default=1: Fix gamma while training
  • use_global_stats::boolean, optional, default=0: Whether to use global moving statistics instead of local batch statistics. This turns batch-norm into a scale-and-shift operator.
  • output_mean_var::boolean, optional, default=0: Output the mean and variance in addition to the normalized data.

source

# MXNet.mx.BilinearSamplerMethod.

BilinearSampler(data, grid, cudnn_off)

Applies bilinear sampling to input feature map.

Bilinear sampling is the key component of [NIPS2015] "Spatial Transformer Networks". The operator is used much like the remap function in OpenCV, except that it also has a backward pass.

Given $data$ and $grid$, the output is computed by

$$x_{src} = \mathrm{grid}[batch, 0, y_{dst}, x_{dst}]$$

$$y_{src} = \mathrm{grid}[batch, 1, y_{dst}, x_{dst}]$$

$$\mathrm{output}[batch, channel, y_{dst}, x_{dst}] = G(\mathrm{data}[batch, channel, y_{src}, x_{src}])$$

$x_{dst}$ and $y_{dst}$ enumerate all spatial locations in $output$, and $G()$ denotes the bilinear interpolation kernel. Points outside the input boundary are padded with zeros. The shape of the output is (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).

The operator assumes that $data$ has 'NCHW' layout and $grid$ has been normalized to [-1, 1].

BilinearSampler often cooperates with GridGenerator, which generates sampling grids for BilinearSampler. GridGenerator supports two kinds of transformation: $affine$ and $warp$. Users who want to write a CustomOp that manipulates $grid$ should first refer to the code of GridGenerator.

Example 1::

Zoom out data two times

data = array([[[[1, 4, 3, 6], [1, 8, 8, 9], [0, 4, 1, 5], [1, 0, 1, 3]]]])

affine_matrix = array([[2, 0, 0], [0, 2, 0]])

affine_matrix = reshape(affine_matrix, shape=(1, 6))

grid = GridGenerator(data=affine_matrix, transform_type='affine', target_shape=(4, 4))

out = BilinearSampler(data, grid)

out
[[[[0, 0, 0, 0], [0, 3.5, 6.5, 0], [0, 1.25, 2.5, 0], [0, 0, 0, 0]]]]

Example 2::

shift data horizontally by -1 pixel

data = array([[[[1, 4, 3, 6], [1, 8, 8, 9], [0, 4, 1, 5], [1, 0, 1, 3]]]])

warp_matrix = array([[[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]])

grid = GridGenerator(data=warp_matrix, transform_type='warp')

out = BilinearSampler(data, grid)

out
[[[[4, 3, 6, 0], [8, 8, 9, 0], [4, 1, 5, 0], [0, 1, 3, 0]]]]
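Both examples can be reproduced in plain Python. The sketch below is illustrative only (the names bilinear_sample and affine_grid are hypothetical, and no MXNet call is made). It assumes that a normalized coordinate u in [-1, 1] maps to pixel position (u + 1)(W - 1)/2, which reproduces Example 1, and it zero-pads each out-of-bounds neighbour individually. For Example 2 it skips GridGenerator and builds the pixel-space sampling positions directly, one pixel to the right of each output location.

```python
import math


def bilinear_sample(img, xs, ys):
    """Sample a 2-D image (list of rows) at pixel positions (xs[r][c], ys[r][c]).

    Each of the four neighbouring pixels contributes only if it lies inside the
    image, so out-of-bounds neighbours are effectively zero-padded.
    """
    h, w = len(img), len(img[0])

    def pixel(y, x):
        return img[y][x] if 0 <= y < h and 0 <= x < w else 0.0

    out = []
    for r in range(len(xs)):
        row = []
        for c in range(len(xs[0])):
            x, y = xs[r][c], ys[r][c]
            x0, y0 = math.floor(x), math.floor(y)
            dx, dy = x - x0, y - y0
            row.append(pixel(y0, x0) * (1 - dx) * (1 - dy)
                       + pixel(y0, x0 + 1) * dx * (1 - dy)
                       + pixel(y0 + 1, x0) * (1 - dx) * dy
                       + pixel(y0 + 1, x0 + 1) * dx * dy)
        out.append(row)
    return out


def affine_grid(theta, out_h, out_w, in_h, in_w):
    """Map every output pixel through a 2x3 affine matrix in normalized [-1, 1]
    coordinates, then convert the source position back to input pixel units."""
    xs, ys = [], []
    for r in range(out_h):
        xr, yr = [], []
        for c in range(out_w):
            xn = -1 + 2 * c / (out_w - 1)
            yn = -1 + 2 * r / (out_h - 1)
            x_src = theta[0][0] * xn + theta[0][1] * yn + theta[0][2]
            y_src = theta[1][0] * xn + theta[1][1] * yn + theta[1][2]
            xr.append((x_src + 1) * (in_w - 1) / 2)
            yr.append((y_src + 1) * (in_h - 1) / 2)
        xs.append(xr)
        ys.append(yr)
    return xs, ys


data = [[1.0, 4.0, 3.0, 6.0], [1.0, 8.0, 8.0, 9.0], [0.0, 4.0, 1.0, 5.0], [1.0, 0.0, 1.0, 3.0]]

# Example 1: zoom out by a factor of two.
xs, ys = affine_grid([[2, 0, 0], [0, 2, 0]], 4, 4, 4, 4)
zoomed = bilinear_sample(data, xs, ys)
print([[round(v, 6) for v in row] for row in zoomed])
# [[0.0, 0.0, 0.0, 0.0], [0.0, 3.5, 6.5, 0.0], [0.0, 1.25, 2.5, 0.0], [0.0, 0.0, 0.0, 0.0]]

# Example 2: read each output pixel from one pixel to its right (a shift by -1).
xs = [[c + 1 for c in range(4)] for _ in range(4)]
ys = [[r for _ in range(4)] for r in range(4)]
print(bilinear_sample(data, xs, ys))
# [[4.0, 3.0, 6.0, 0.0], [8.0, 8.0, 9.0, 0.0], [4.0, 1.0, 5.0, 0.0], [0.0, 1.0, 3.0, 0.0]]
```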

Defined in src/operator/bilinear_sampler.cc:L255

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the BilinearsamplerOp.
  • grid::NDArray-or-SymbolicNode: Input grid to the BilinearsamplerOp. grid has two channels: x_src, y_src
  • cudnn_off::boolean or None, optional, default=None: whether to turn cudnn off

source

# MXNet.mx.BlockGradMethod.

BlockGrad(data)

Stops gradient computation.

Stops the accumulated gradient of the inputs from flowing through this operator in the backward direction. In other words, this operator prevents its inputs from contributing to the computed gradients.

Example::

v1 = [1, 2]
v2 = [0, 1]
a = Variable('a')
b = Variable('b')
b_stop_grad = stop_gradient(3 * b)
loss = MakeLoss(b_stop_grad + a)

executor = loss.simple_bind(ctx=cpu(), a=(1,2), b=(1,2))
executor.forward(is_train=True, a=v1, b=v2)
executor.outputs
[ 1. 5.]

executor.backward()
executor.grad_arrays
[ 0. 0.]
[ 1. 1.]
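The effect can be mimicked with a toy scalar reverse-mode differentiator in plain Python. This is a hypothetical sketch (Node, stop_gradient and backward are defined here and are not MXNet APIs): stop_gradient copies the forward value into a new node with no parents, so the chain rule never reaches b, and the loss gradient is 1 for a and 0 for b, matching the example.

```python
class Node:
    """A scalar in a tiny reverse-mode autodiff graph."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # sequence of (parent_node, local_derivative)
        self.grad = 0.0

    def __add__(self, other):
        return Node(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def scale(self, c):
        return Node(c * self.value, ((self, c),))


def stop_gradient(node):
    """Same forward value, but no parents, so backward() stops here."""
    return Node(node.value)


def backward(out):
    """Accumulate d(out)/d(node) into node.grad by walking every path (fine for tiny graphs)."""
    def visit(node, upstream):
        node.grad += upstream
        for parent, local in node.parents:
            visit(parent, upstream * local)
    visit(out, 1.0)


for a_val, b_val in zip([1.0, 2.0], [0.0, 1.0]):  # v1 and v2, element by element
    a, b = Node(a_val), Node(b_val)
    loss = stop_gradient(b.scale(3.0)) + a
    backward(loss)
    print(loss.value, a.grad, b.grad)
# 1.0 1.0 0.0
# 5.0 1.0 0.0
```

Without stop_gradient, the same graph would give b a gradient of 3.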

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L325

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.CTCLossMethod.

CTCLoss(data, label, data_lengths, label_lengths, use_data_lengths, use_label_lengths, blank_label)

Connectionist Temporal Classification Loss.

.. note:: The existing alias $contrib_CTCLoss$ is deprecated.

The shapes of the inputs and outputs:

  • data: (sequence_length, batch_size, alphabet_size)
  • label: (batch_size, label_sequence_length)
  • out: (batch_size)

The data tensor consists of sequences of activation vectors (without applying softmax), with the i-th channel in the last dimension corresponding to the i-th label, for i between 0 and alphabet_size-1 (i.e., always 0-indexed). The alphabet size should include one additional value reserved for the blank label. When blank_label is "first", the 0-th channel is reserved for the activation of the blank label; otherwise, if it is "last", the (alphabet_size-1)-th channel is reserved for the blank label.

$label$ is an index matrix of integers. When blank_label is $"first"$, the value 0 is then reserved for blank label, and should not be passed in this matrix. Otherwise, when blank_label is $"last"$, the value (alphabet_size-1) is reserved for blank label.

If a sequence of labels is shorter than label_sequence_length, use the special padding value at the end of the sequence to conform it to the correct length. The padding value is 0 when blank_label is $"first"$, and -1 otherwise.

For example, suppose the vocabulary is [a, b, c], and in one batch we have three sequences 'ba', 'cbb', and 'abac'. When blank_label is $"first"$, we can index the labels as {'a': 1, 'b': 2, 'c': 3}, and we reserve the 0-th channel for blank label in data tensor. The resulting label tensor should be padded to be::

[[2, 1, 0, 0], [3, 2, 2, 0], [1, 2, 1, 3]]

When blank_label is $"last"$, we can index the labels as {'a': 0, 'b': 1, 'c': 2}, and we reserve the channel index 3 for blank label in data tensor. The resulting label tensor should be padded to be::

[[1, 0, -1, -1], [2, 1, 1, -1], [0, 1, 0, 2]]
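The indexing and padding rules can be sketched in plain Python. The helper ctc_pad_labels is hypothetical (not an MXNet function); it takes the vocabulary without the blank, so alphabet_size is len(vocab) + 1, and it reproduces both label tensors above.

```python
def ctc_pad_labels(sequences, vocab, blank_label="first"):
    """Index label strings and right-pad them to a common length.

    vocab lists the non-blank tokens in order, e.g. ['a', 'b', 'c'].
    """
    if blank_label == "first":
        # Channel 0 is the blank, so tokens start at 1 and the padding value is 0.
        index = {tok: i + 1 for i, tok in enumerate(vocab)}
        pad = 0
    elif blank_label == "last":
        # Channel len(vocab) is the blank, so tokens start at 0 and the padding value is -1.
        index = {tok: i for i, tok in enumerate(vocab)}
        pad = -1
    else:
        raise ValueError("blank_label must be 'first' or 'last'")
    width = max(len(s) for s in sequences)
    return [[index[t] for t in s] + [pad] * (width - len(s)) for s in sequences]


seqs = ["ba", "cbb", "abac"]
print(ctc_pad_labels(seqs, ["a", "b", "c"], "first"))  # [[2, 1, 0, 0], [3, 2, 2, 0], [1, 2, 1, 3]]
print(ctc_pad_labels(seqs, ["a", "b", "c"], "last"))   # [[1, 0, -1, -1], [2, 1, 1, -1], [0, 1, 0, 2]]
```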

$out$ is a list of CTC loss values, one per example in the batch.

See Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks, A. Graves et al. for more information on the definition and the algorithm.

Defined in src/operator/nn/ctc_loss.cc:L100

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • label::NDArray-or-SymbolicNode: Ground-truth labels for the loss.
  • data_lengths::NDArray-or-SymbolicNode: Lengths of data for each of the samples. Only required when use_data_lengths is true.
  • label_lengths::NDArray-or-SymbolicNode: Lengths of labels for each of the samples. Only required when use_label_lengths is true.
  • use_data_lengths::boolean, optional, default=0: Whether the data lengths are decided by data_lengths. If false, the lengths are equal to the max sequence length.
  • use_label_lengths::boolean, optional, default=0: Whether the label lengths are decided by label_lengths, or derived from padding_mask. If false, the lengths are derived from the first occurrence of the value of padding_mask. The value of padding_mask is $0$ when the first CTC label is reserved for blank, and $-1$ when the last label is reserved for blank. See blank_label.
  • blank_label::{'first', 'last'},optional, default='first': Set the label that is reserved for blank label. If "first", the 0-th label is reserved, label values for tokens in the vocabulary are between $1$ and $alphabet_size-1$, and the padding mask is $0$. If "last", the last label value $alphabet_size-1$ is reserved for blank label instead, label values for tokens in the vocabulary are between $0$ and $alphabet_size-2$, and the padding mask is $-1$.

source

# MXNet.mx.CastMethod.

Cast(data, dtype)

Casts all elements of the input to a new type.

.. note:: $Cast$ is deprecated. Use $cast$ instead.

Example::

cast([0.9, 1.3], dtype='int32') = [0, 1]

cast([1e20, 11.1], dtype='float16') = [inf, 11.09375]

cast([300, 11.1, 10.9, -1, -3], dtype='uint8') = [44, 11, 10, 255, 253]
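The integer results above are consistent with truncating toward zero and then, for uint8, wrapping modulo 256. The plain-Python sketch below models only that pattern (the helper names are hypothetical) and does not cover the float16 case. Treat the wrap-around as illustrative: in C++, converting an out-of-range floating-point value to an unsigned integer type is undefined behavior, so the exact result may depend on the platform.

```python
import math


def cast_to_int32(values):
    """Truncate toward zero, as in the int32 example."""
    return [math.trunc(v) for v in values]


def cast_to_uint8(values):
    """Truncate toward zero, then wrap into [0, 255] modulo 256."""
    return [math.trunc(v) % 256 for v in values]


print(cast_to_int32([0.9, 1.3]))                 # [0, 1]
print(cast_to_uint8([300, 11.1, 10.9, -1, -3]))  # [44, 11, 10, 255, 253]
```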

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L664

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • dtype::{'bfloat16', 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, required: Output data type.

source

# MXNet.mx.ConcatMethod.

Concat(data, num_args, dim)

Note: Concat takes a variable number of positional inputs. So instead of calling Concat([x, y, z], num_args=3), call Concat(x, y, z); num_args is determined automatically.

Joins input arrays along a given axis.

.. note:: Concat is deprecated. Use concat instead.

The dimensions of the input arrays should be the same except the axis along which they will be concatenated. The dimension of the output array along the concatenated axis will be equal to the sum of the corresponding dimensions of the input arrays.

The storage type of the $concat$ output depends on the storage types of the inputs:

  • concat(csr, csr, ..., csr, dim=0) = csr
  • otherwise, $concat$ generates output with default storage

Example::

x = [[1,1],[2,2]]

y = [[3,3],[4,4],[5,5]]

z = [[6,6], [7,7],[8,8]]

concat(x,y,z,dim=0) = [[ 1., 1.], [ 2., 2.], [ 3., 3.], [ 4., 4.], [ 5., 5.], [ 6., 6.], [ 7., 7.], [ 8., 8.]]

Note that you cannot concat x,y,z along dimension 1 since dimension 0 is not the same for all the input arrays.

concat(y,z,dim=1) = [[ 3., 3., 6., 6.], [ 4., 4., 7., 7.], [ 5., 5., 8., 8.]]
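The shape rule can be sketched for 2-D nested lists in plain Python. Here concat is a hypothetical helper, not the MXNet operator, and it handles only dim=0 and dim=1; it rejects the x, y, z case along dimension 1 for the reason given above.

```python
def concat(arrays, dim=0):
    """Concatenate 2-D nested lists along dim 0 (stack rows) or dim 1 (join columns)."""
    if dim == 0:
        if len({len(a[0]) for a in arrays}) != 1:
            raise ValueError("all inputs must have the same size along dimension 1")
        return [list(row) for a in arrays for row in a]
    if dim == 1:
        if len({len(a) for a in arrays}) != 1:
            raise ValueError("all inputs must have the same size along dimension 0")
        return [sum((list(a[i]) for a in arrays), []) for i in range(len(arrays[0]))]
    raise ValueError("this sketch handles only dim=0 and dim=1")


x = [[1, 1], [2, 2]]
y = [[3, 3], [4, 4], [5, 5]]
z = [[6, 6], [7, 7], [8, 8]]
print(concat([x, y, z], dim=0))  # [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8]]
print(concat([y, z], dim=1))     # [[3, 3, 6, 6], [4, 4, 7, 7], [5, 5, 8, 8]]
```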

Defined in src/operator/nn/concat.cc:L384

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to concatenate
  • num_args::int, required: Number of inputs to be concatenated.
  • dim::int, optional, default='1': The dimension along which to concatenate.

source

# MXNet.mx.ConvolutionMethod.

Convolution(data, weight, bias, kernel, stride, dilate, pad, num_filter, num_group, workspace, no_bias, cudnn_tune, cudnn_off, layout)

Compute N-D convolution on (N+2)-D input.

In the 2-D convolution, given input data with shape (batch_size, channel, height, width), the output is computed by

$$\mathrm{out}[n, i, :, :] = \mathrm{bias}[i] + \sum_{j=0}^{\mathrm{channel}-1} \mathrm{data}[n, j, :, :] \star \mathrm{weight}[i, j, :, :]$$

where $\star$ is the 2-D cross-correlation operator.

For general 2-D convolution, the shapes are

  • data: (batch_size, channel, height, width)
  • weight: (num_filter, channel, kernel[0], kernel[1])
  • bias: (num_filter,)
  • out: (batch_size, num_filter, out_height, out_width).

Define::

f(x,k,p,s,d) = floor((x+2p-d(k-1)-1)/s)+1

then we have::

out_height = f(height, kernel[0], pad[0], stride[0], dilate[0])

out_width = f(width, kernel[1], pad[1], stride[1], dilate[1])
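As a worked check of f, here is the formula in plain Python; conv_out_size and conv_out_hw are hypothetical helper names, and integer floor division implements the floor.

```python
def conv_out_size(x, k, p=0, s=1, d=1):
    """f(x, k, p, s, d) = floor((x + 2p - d(k-1) - 1) / s) + 1 for one spatial axis."""
    return (x + 2 * p - d * (k - 1) - 1) // s + 1


def conv_out_hw(height, width, kernel, pad=(0, 0), stride=(1, 1), dilate=(1, 1)):
    """Apply f to the height (index 0) and width (index 1) axes of a 2-D convolution."""
    out_height = conv_out_size(height, kernel[0], pad[0], stride[0], dilate[0])
    out_width = conv_out_size(width, kernel[1], pad[1], stride[1], dilate[1])
    return out_height, out_width


print(conv_out_hw(32, 32, kernel=(3, 3), pad=(1, 1)))                 # (32, 32): size preserved
print(conv_out_hw(32, 32, kernel=(3, 3), pad=(1, 1), stride=(2, 2)))  # (16, 16)
print(conv_out_hw(7, 7, kernel=(3, 3), dilate=(2, 2)))                # (3, 3)
```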

If $no_bias$ is set to be true, then the $bias$ term is ignored.

The default data $layout$ is NCHW, namely (batch_size, channel, height, width). Other layouts, such as NHWC, can be chosen.

If $num_group$ is larger than 1, denoted by g, then split the input $data$ evenly into g parts along the channel axis, and also evenly split $weight$ along the first dimension. Next compute the convolution on the i-th part of the data with the i-th weight part. The output is obtained by concatenating all the g results.

1-D convolution has no height dimension, only width in space. The shapes are

  • data: (batch_size, channel, width)
  • weight: (num_filter, channel, kernel[0])
  • bias: (num_filter,)
  • out: (batch_size, num_filter, out_width).

3-D convolution adds an additional depth dimension besides height and width. The shapes are

  • data: (batch_size, channel, depth, height, width)
  • weight: (num_filter, channel, kernel[0], kernel[1], kernel[2])
  • bias: (num_filter,)
  • out: (batch_size, num_filter, out_depth, out_height, out_width).

Both $weight$ and $bias$ are learnable parameters.

There are other options to tune the performance.

  • cudnn_tune: enabling this option leads to a higher startup time but may give faster speed. Options are:

    • off: no tuning
    • limited_workspace: run a test and pick the fastest algorithm that doesn't exceed the workspace limit.
    • fastest: pick the fastest algorithm and ignore the workspace limit.
    • None (default): the behavior is determined by the environment variable $MXNET_CUDNN_AUTOTUNE_DEFAULT$: 0 for off, 1 for limited workspace (default), 2 for fastest.
  • workspace: a large number leads to more (GPU) memory usage but may improve the performance.

Defined in src/operator/nn/convolution.cc:L475

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the ConvolutionOp.
  • weight::NDArray-or-SymbolicNode: Weight matrix.
  • bias::NDArray-or-SymbolicNode: Bias parameter.
  • kernel::Shape(tuple), required: Convolution kernel size: (w,), (h, w) or (d, h, w)
  • stride::Shape(tuple), optional, default=[]: Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • dilate::Shape(tuple), optional, default=[]: Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • pad::Shape(tuple), optional, default=[]: Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding.
  • num_filter::int (non-negative), required: Convolution filter(channel) number
  • num_group::int (non-negative), optional, default=1: Number of group partitions.
  • workspace::long (non-negative), optional, default=1024: Maximum temporary workspace allowed (MB) in convolution. This parameter has two usages. When CUDNN is not used, it determines the effective batch size of the convolution kernel. When CUDNN is used, it controls the maximum temporary storage used for tuning the best CUDNN kernel when limited_workspace strategy is used.
  • no_bias::boolean, optional, default=0: Whether to disable bias parameter.
  • cudnn_tune::{None, 'fastest', 'limited_workspace', 'off'},optional, default='None': Whether to pick convolution algo by running performance test.
  • cudnn_off::boolean, optional, default=0: Turn off cudnn for this layer.
  • layout::{None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'},optional, default='None': Set layout for input, output and weight. Empty for default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d. NHWC and NDHWC are only supported on GPU.

source

# MXNet.mx.Convolution_v1Method.

Convolution_v1(data, weight, bias, kernel, stride, dilate, pad, num_filter, num_group, workspace, no_bias, cudnn_tune, cudnn_off, layout)

This operator is DEPRECATED. Apply convolution to input then add a bias.

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the ConvolutionV1Op.
  • weight::NDArray-or-SymbolicNode: Weight matrix.
  • bias::NDArray-or-SymbolicNode: Bias parameter.
  • kernel::Shape(tuple), required: convolution kernel size: (h, w) or (d, h, w)
  • stride::Shape(tuple), optional, default=[]: convolution stride: (h, w) or (d, h, w)
  • dilate::Shape(tuple), optional, default=[]: convolution dilate: (h, w) or (d, h, w)
  • pad::Shape(tuple), optional, default=[]: pad for convolution: (h, w) or (d, h, w)
  • num_filter::int (non-negative), required: convolution filter(channel) number
  • num_group::int (non-negative), optional, default=1: Number of group partitions. Equivalent to slicing input into num_group partitions, apply convolution on each, then concatenate the results
  • workspace::long (non-negative), optional, default=1024: Maximum temporary workspace allowed for convolution (MB). This parameter determines the effective batch size of the convolution kernel, which may be smaller than the given batch size. Also, the workspace will be automatically enlarged to make sure that we can run the kernel with batch_size=1
  • no_bias::boolean, optional, default=0: Whether to disable bias parameter.
  • cudnn_tune::{None, 'fastest', 'limited_workspace', 'off'},optional, default='None': Whether to pick convolution algo by running performance test. Leads to higher startup time but may give faster speed. Options are: 'off': no tuning; 'limited_workspace': run test and pick the fastest algorithm that doesn't exceed workspace limit; 'fastest': pick the fastest algorithm and ignore workspace limit. If set to None (default), behavior is determined by environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT: 0 for off, 1 for limited workspace (default), 2 for fastest.
  • cudnn_off::boolean, optional, default=0: Turn off cudnn for this layer.
  • layout::{None, 'NCDHW', 'NCHW', 'NDHWC', 'NHWC'},optional, default='None': Set layout for input, output and weight. Empty for default layout: NCHW for 2d and NCDHW for 3d.

source

# MXNet.mx.CorrelationMethod.

Correlation(data1, data2, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply)

Applies correlation to inputs.

The correlation layer performs multiplicative patch comparisons between two feature maps.

Given two multi-channel feature maps $f_{1}, f_{2}$, with $w$, $h$, and $c$ being their width, height, and number of channels, the correlation layer lets the network compare each patch from $f_{1}$ with each patch from $f_{2}$.

For now we consider only a single comparison of two patches. The 'correlation' of two patches centered at $x_{1}$ in the first map and $x_{2}$ in the second map is then defined as:

$$c(x_{1}, x_{2}) = \sum_{o \in [-k,k] \times [-k,k]} \langle f_{1}(x_{1} + o), f_{2}(x_{2} + o) \rangle$$

for a square patch of size $K := 2k+1$, where $\langle \cdot, \cdot \rangle$ is the dot product over the $c$ channels.

Note that the equation above is identical to one step of a convolution in neural networks, but instead of convolving data with a filter, it convolves data with other data. For this reason, it has no training weights.

Computing $c(x_{1}, x_{2})$ involves $c \cdot K^{2}$ multiplications. Comparing all patch combinations involves $w^{2} h^{2}$ such computations.

Given a maximum displacement $d$, for each location $x_{1}$ it computes correlations $c(x_{1}, x_{2})$ only in a neighborhood of size $D := 2d+1$, by limiting the range of $x_{2}$. We use strides $s_{1}, s_{2}$ to quantize $x_{1}$ globally and to quantize $x_{2}$ within the neighborhood centered around $x_{1}$.

The final output is defined by the following expression:

$$\mathrm{out}[n, q, i, j] = c(x_{i,j}, x_{q})$$

where $i$ and $j$ enumerate spatial locations in $f_{1}$, and $q$ denotes the $q^{th}$ neighborhood of $x_{i,j}$.
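For a single pair of patches, the definition of $c(x_{1}, x_{2})$ can be written directly in plain Python. This is a hypothetical helper, not the operator: it evaluates the sum exactly as written (no normalization, strides, or max_displacement loop), assumes both maps have the same spatial size, and treats positions outside the feature map as zero, which is an assumption here.

```python
def patch_correlation(f1, f2, x1, x2, k):
    """c(x1, x2) = sum over o in [-k, k] x [-k, k] of <f1(x1 + o), f2(x2 + o)>.

    f1 and f2 are indexed as f[channel][row][col]; x1 and x2 are (row, col) centres.
    Positions outside the feature map are treated as zero (an assumption).
    """
    height, width = len(f1[0]), len(f1[0][0])

    def at(f, ch, r, c):
        return f[ch][r][c] if 0 <= r < height and 0 <= c < width else 0.0

    total = 0.0
    for dr in range(-k, k + 1):
        for dc in range(-k, k + 1):
            for ch in range(len(f1)):  # dot product over the c channels
                total += (at(f1, ch, x1[0] + dr, x1[1] + dc)
                          * at(f2, ch, x2[0] + dr, x2[1] + dc))
    return total


# Two channels of ones on a 3x3 map, k = 1 (K = 3): c * K^2 = 18 products of 1 * 1.
ones = [[[1.0] * 3 for _ in range(3)] for _ in range(2)]
print(patch_correlation(ones, ones, (1, 1), (1, 1), k=1))  # 18.0
# Displacing x2 one column to the right pushes one patch column off the map.
print(patch_correlation(ones, ones, (1, 1), (1, 2), k=1))  # 12.0
```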

Defined in src/operator/correlation.cc:L197

Arguments

  • data1::NDArray-or-SymbolicNode: Input data1 to the correlation.
  • data2::NDArray-or-SymbolicNode: Input data2 to the correlation.
  • kernel_size::int (non-negative), optional, default=1: Kernel size for Correlation; must be an odd number.
  • max_displacement::int (non-negative), optional, default=1: Max displacement of Correlation
  • stride1::int (non-negative), optional, default=1: stride1 quantize data1 globally
  • stride2::int (non-negative), optional, default=1: stride2 quantize data2 within the neighborhood centered around data1
  • pad_size::int (non-negative), optional, default=0: pad for Correlation
  • is_multiply::boolean, optional, default=1: Operation type: either multiplication or subtraction.

source

# MXNet.mx.CropMethod.

Crop(data, num_args, offset, h_w, center_crop)

Note: Crop takes a variable number of positional inputs. So instead of calling Crop([x, y, z], num_args=3), call Crop(x, y, z); num_args is determined automatically.

.. note:: Crop is deprecated. Use slice instead.

Crops the 2nd and 3rd dimensions of the input data, either to the size given by h_w or to the height and width of the second input symbol. That is, with one input, h_w must specify the crop height and width; otherwise, the size of the second input symbol is used.

Defined in src/operator/crop.cc:L49

Arguments

  • data::SymbolicNode or SymbolicNode[]: Tensor or List of Tensors, the second input will be used as crop_like shape reference
  • num_args::int, required: Number of inputs for crop. If it equals one, h_w is used for the crop height and width; if it equals two, the height and width of the second input symbol (called crop_like here) are used.
  • offset::Shape(tuple), optional, default=[0,0]: crop offset coordinate: (y, x)
  • h_w::Shape(tuple), optional, default=[0,0]: crop height and width: (h, w)
  • center_crop::boolean, optional, default=0: If set to true, perform a center crop; otherwise, crop using the shape of crop_like.

source

# MXNet.mx.CustomMethod.

Custom(data, op_type)

Apply a custom operator implemented in a frontend language (like Python).

Custom operators should override required methods like forward and backward. The custom operator must be registered before it can be used. Please check the tutorial here: https://mxnet.incubator.apache.org/api/faq/new_op

Defined in src/operator/custom/custom.cc:L546

Arguments

  • data::NDArray-or-SymbolicNode[]: Input data for the custom operator.
  • op_type::string: Name of the custom operator. This is the name that is passed to mx.operator.register to register the operator.

source

# MXNet.mx.DeconvolutionMethod.

Deconvolution(data, weight, bias, kernel, stride, dilate, pad, adj, target_shape, num_filter, num_group, workspace, no_bias, cudnn_tune, cudnn_off, layout)

Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of the input tensor. This operation can be seen as the gradient of Convolution operation with respect to its input. Convolution usually reduces the size of the input. Transposed convolution works the other way, going from a smaller input to a larger output while preserving the connectivity pattern.
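Since transposed convolution runs the convolution shape rule in reverse, a plain-Python sketch can show why an adjustment such as $adj$ is needed. The transposed-convolution size formula used here, s(x - 1) + d(k - 1) + 1 - 2p + adj, is an assumption stated for illustration (it inverts the convolution formula f given for Convolution), and both helper names are hypothetical.

```python
def conv_out_size(x, k, p=0, s=1, d=1):
    """Convolution output size along one axis: floor((x + 2p - d(k-1) - 1) / s) + 1."""
    return (x + 2 * p - d * (k - 1) - 1) // s + 1


def deconv_out_size(x, k, p=0, s=1, d=1, adj=0):
    """Assumed transposed-convolution size: s(x - 1) + d(k - 1) + 1 - 2p + adj."""
    return s * (x - 1) + d * (k - 1) + 1 - 2 * p + adj


# Widths 5 and 6 both shrink to 3 under kernel 3, pad 1, stride 2 ...
print(conv_out_size(5, 3, p=1, s=2), conv_out_size(6, 3, p=1, s=2))  # 3 3
# ... so going back up is ambiguous, and adj picks which size to recover.
print(deconv_out_size(3, 3, p=1, s=2))         # 5
print(deconv_out_size(3, 3, p=1, s=2, adj=1))  # 6
```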

Arguments

  • data::NDArray-or-SymbolicNode: Input tensor to the deconvolution operation.
  • weight::NDArray-or-SymbolicNode: Weights representing the kernel.
  • bias::NDArray-or-SymbolicNode: Bias added to the result after the deconvolution operation.
  • kernel::Shape(tuple), required: Deconvolution kernel size: (w,), (h, w) or (d, h, w). This is same as the kernel size used for the corresponding convolution
  • stride::Shape(tuple), optional, default=[]: The stride used for the corresponding convolution: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • dilate::Shape(tuple), optional, default=[]: Dilation factor for each dimension of the input: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • pad::Shape(tuple), optional, default=[]: The amount of implicit zero padding added during convolution for each dimension of the input: (w,), (h, w) or (d, h, w). $(kernel-1)/2$ is usually a good choice. If target_shape is set, pad will be ignored and a padding that will generate the target shape will be used. Defaults to no padding.
  • adj::Shape(tuple), optional, default=[]: Adjustment for output shape: (w,), (h, w) or (d, h, w). If target_shape is set, adj will be ignored and computed accordingly.
  • target_shape::Shape(tuple), optional, default=[]: Shape of the output tensor: (w,), (h, w) or (d, h, w).
  • num_filter::int (non-negative), required: Number of output filters.
  • num_group::int (non-negative), optional, default=1: Number of group partitions.
  • workspace::long (non-negative), optional, default=512: Maximum temporary workspace allowed (MB) in deconvolution. This parameter has two usages. When CUDNN is not used, it determines the effective batch size of the deconvolution kernel. When CUDNN is used, it controls the maximum temporary storage used for tuning the best CUDNN kernel when limited_workspace strategy is used.
  • no_bias::boolean, optional, default=1: Whether to disable bias parameter.
  • cudnn_tune::{None, 'fastest', 'limited_workspace', 'off'},optional, default='None': Whether to pick convolution algorithm by running performance test.
  • cudnn_off::boolean, optional, default=0: Turn off cudnn for this layer.
  • layout::{None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'},optional, default='None': Set layout for input, output and weight. Empty for default layout, NCW for 1d, NCHW for 2d and NCDHW for 3d. NHWC and NDHWC are only supported on GPU.

source

# MXNet.mx.DropoutMethod.

Dropout(data, p, mode, axes, cudnn_off)

Applies dropout operation to input array.

  • During training, each element of the input is set to zero with probability p. The whole array is rescaled by $1/(1-p)$ to keep the expected sum of the input unchanged.
  • During testing, this operator does not change the input if mode is 'training'. If mode is 'always', the same computation as during training is applied.

Example::

random.seed(998)
input_array = array([[3., 0.5, -0.5, 2., 7.], [2., -0.4, 7., 3., 0.2]])
a = symbol.Variable('a')
dropout = symbol.Dropout(a, p = 0.2)
executor = dropout.simple_bind(a = input_array.shape)

If training:

executor.forward(is_train = True, a = input_array)
executor.outputs
[[ 3.75 0.625 -0. 2.5 8.75 ] [ 2.5 -0.5 8.75 3.75 0. ]]

If testing:

executor.forward(is_train = False, a = input_array)
executor.outputs
[[ 3. 0.5 -0.5 2. 7. ] [ 2. -0.4 7. 3. 0.2 ]]
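The training-time rule can be sketched in plain Python as inverted dropout. The helper below is hypothetical, and its random drop pattern will not match MXNet's generator even with the same seed; what carries over is that each kept element is multiplied by 1/(1-p) (1.25 for p = 0.2, as in the outputs above) and each dropped element becomes 0.

```python
import random


def dropout(values, p=0.5, training=True, rng=random):
    """Inverted dropout on a flat list: zero each element with probability p
    and scale the survivors by 1 / (1 - p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("this sketch requires 0 <= p < 1")
    if not training:
        return list(values)  # identity, as with mode='training' at inference
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]


x = [3.0, 0.5, -0.5, 2.0, 7.0]
print(dropout(x, p=0.2, rng=random.Random(998)))  # each entry is 0.0 or 1.25 * the input
print(dropout(x, p=0.2, training=False))          # [3.0, 0.5, -0.5, 2.0, 7.0]
```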

Defined in src/operator/nn/dropout.cc:L95

Arguments

  • data::NDArray-or-SymbolicNode: Input array to which dropout will be applied.
  • p::float, optional, default=0.5: Fraction of the input that gets dropped out during training time.
  • mode::{'always', 'training'},optional, default='training': Whether to only turn on dropout during training or to also turn on for inference.
  • axes::Shape(tuple), optional, default=[]: Axes for variational dropout kernel.
  • cudnn_off::boolean or None, optional, default=0: Whether to turn off cudnn in dropout operator. This option is ignored if axes is specified.

source

# MXNet.mx.ElementWiseSumMethod.

ElementWiseSum(args)

ElementWiseSum is an alias of add_n.

Note: ElementWiseSum takes a variable number of positional inputs. So instead of calling it as ElementWiseSum([x, y, z], num_args=3), one should call ElementWiseSum(x, y, z), and num_args will be determined automatically.

Adds all input arguments element-wise.

.. math:: add\_n(a_1, a_2, ..., a_n) = a_1 + a_2 + ... + a_n

$add_n$ is potentially more efficient than calling $add$ n times.
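One plausible reason add_n can beat n-1 chained add calls is that it accumulates into a single output buffer instead of materializing an intermediate array after every pair. A minimal plain-Python sketch of that accumulation (the helper name is illustrative; this is not MXNet's implementation):

```python
def add_n(*arrays):
    """Element-wise sum of any number of equal-length flat lists."""
    if not arrays:
        raise ValueError("add_n needs at least one input")
    n = len(arrays[0])
    if any(len(a) != n for a in arrays):
        raise ValueError("all inputs must have the same length")
    out = [0.0] * n  # a single output buffer, reused for every input
    for a in arrays:
        for i, v in enumerate(a):
            out[i] += v
    return out


total = add_n([1.0, 2.0], [3.0, 4.0], [5.0, 6.0])  # [9.0, 12.0]
```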

The storage type of $add_n$ output depends on the storage types of the inputs:

  • add_n(row_sparse, row_sparse, ..) = row_sparse
  • add_n(default, csr, default) = default
  • add_n(any input combination longer than 4 (>4) with at least one default type) = default
  • otherwise, $add_n$ falls back to default storage for all inputs and generates default storage

Defined in src/operator/tensor/elemwise_sum.cc:L155

Arguments

  • args::NDArray-or-SymbolicNode[]: Positional input arguments

source

# MXNet.mx.EmbeddingMethod.

Embedding(data, weight, input_dim, output_dim, dtype, sparse_grad)

Maps integer indices to vector representations (embeddings).

This operator maps words to real-valued vectors in a high-dimensional space, called word embeddings. These embeddings can capture semantic and syntactic properties of the words. For example, it has been noted that in the learned embedding spaces, similar words tend to be close to each other and dissimilar words far apart.

For an input array of shape (d1, ..., dK), the shape of the output array is (d1, ..., dK, output_dim). All the input values should be integers in the range [0, input_dim).

If input_dim is ip0 and output_dim is op0, then the shape of the embedding weight matrix must be (ip0, op0).

When "sparse_grad" is False, any index that is too large is replaced by the index that addresses the last vector in the embedding matrix. When "sparse_grad" is True, an error is raised if invalid indices are found.
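The lookup and the out-of-range handling described above can be sketched in plain Python. Clipping negative indices to 0 is an assumption of this sketch (the text only covers indices that are too large), and `embedding` is a hypothetical helper, not the MXNet operator.

```python
def embedding(indices, weight, sparse_grad=False):
    """Map each integer index in a 2-D list to a copy of the matching row of `weight`."""
    input_dim = len(weight)
    out = []
    for row in indices:
        mapped = []
        for idx in row:
            i = int(idx)
            if not 0 <= i < input_dim:
                if sparse_grad:
                    raise IndexError(f"index {i} not in [0, {input_dim})")
                # Assumption: out-of-range indices are clipped into [0, input_dim - 1],
                # so a too-large index selects the last row.
                i = min(max(i, 0), input_dim - 1)
            mapped.append(list(weight[i]))
        out.append(mapped)
    return out


# Weight matrix of the example that follows: row r holds 5*r .. 5*r + 4.
y = [[float(5 * r + c) for c in range(5)] for r in range(4)]
x = [[1, 3], [0, 2]]
result = embedding(x, y)
```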

Examples::

    input_dim = 4
    output_dim = 5

    // Each row in weight matrix y represents a word. So, y = (w0,w1,w2,w3)
    y = [[  0.,   1.,   2.,   3.,   4.],
         [  5.,   6.,   7.,   8.,   9.],
         [ 10.,  11.,  12.,  13.,  14.],
         [ 15.,  16.,  17.,  18.,  19.]]

    // Input array x represents n-grams(2-gram). So, x = [(w1,w3), (w0,w2)]
    x = [[  1.,   3.],
         [  0.,   2.]]

    // Mapped input x to its vector representation y.
    Embedding(x, y, 4, 5) = [[[  5.,   6.,   7.,   8.,   9.],
                              [ 15.,  16.,  17.,  18.,  19.]],

                             [[  0.,   1.,   2.,   3.,   4.],
                              [ 10.,  11.,  12.,  13.,  14.]]]

The storage type of weight can be either row_sparse or default.

.. Note::

If "sparse_grad" is set to True, the storage type of the gradient w.r.t. weights will be
"row_sparse". Only a subset of optimizers support sparse gradients, including SGD, AdaGrad
and Adam. Note that lazy updates are turned on by default, which may behave differently
from standard updates. For more details, please check the Optimization API at:
https://mxnet.incubator.apache.org/api/python/optimization/optimization.html

Defined in src/operator/tensor/indexing_op.cc:L597

Arguments

  • data::NDArray-or-SymbolicNode: The input array to the embedding operator.
  • weight::NDArray-or-SymbolicNode: The embedding weight matrix.
  • input_dim::int, required: Vocabulary size of the input indices.
  • output_dim::int, required: Dimension of the embedding vectors.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Data type of weight.
  • sparse_grad::boolean, optional, default=0: Compute row sparse gradient in the backward calculation. If set to True, the grad's storage type is row_sparse.

source

# MXNet.mx.FullyConnectedMethod.

FullyConnected(data, weight, bias, num_hidden, no_bias, flatten)

Applies a linear transformation: Y = XW^T + b.

If $flatten$ is set to true, then the shapes are:

  • data: (batch_size, x1, x2, ..., xn)
  • weight: (num_hidden, x1 * x2 * ... * xn)
  • bias: (num_hidden,)
  • out: (batch_size, num_hidden)

If $flatten$ is set to false, then the shapes are:

  • data: (x1, x2, ..., xn, input_dim)
  • weight: (num_hidden, input_dim)
  • bias: (num_hidden,)
  • out: (x1, x2, ..., xn, num_hidden)

The learnable parameters include both $weight$ and $bias$.

If $no_bias$ is set to true, then the $bias$ term is ignored.
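For the flatten=true case, the transformation can be sketched in plain Python: each sample is flattened to x1 * x2 * ... * xn values and multiplied by every weight row. The helper names are hypothetical; passing bias=None stands in for no_bias=true, and the flatten=false case is not covered.

```python
def _flatten(sample):
    """Recursively flatten a nested list into a flat list of numbers."""
    if isinstance(sample, (list, tuple)):
        return [v for item in sample for v in _flatten(item)]
    return [sample]


def fully_connected(data, weight, bias=None):
    """Y = X W^T + b with flatten=True.

    data: list of batch_size samples (possibly nested); weight: num_hidden rows.
    Pass bias=None to mimic no_bias=True.
    """
    out = []
    for sample in data:
        x = _flatten(sample)
        row = []
        for j, w in enumerate(weight):
            if len(w) != len(x):
                raise ValueError("weight rows must have x1 * ... * xn entries")
            s = sum(xi * wi for xi, wi in zip(x, w))
            row.append(s + bias[j] if bias is not None else s)
        out.append(row)
    return out


y = fully_connected([[1.0, 2.0], [3.0, 4.0]],
                    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                    [0.0, 0.0, 1.0])  # [[1.0, 2.0, 4.0], [3.0, 4.0, 8.0]]
```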

.. Note::

The sparse support for FullyConnected is limited to forward evaluation with `row_sparse`
weight and bias, where the length of `weight.indices` and `bias.indices` must be equal
to `num_hidden`. This could be useful for model inference with `row_sparse` weights
trained with importance sampling or noise contrastive estimation.

To compute linear transformation with 'csr' sparse data, sparse.dot is recommended instead
of sparse.FullyConnected.

Defined in src/operator/nn/fully_connected.cc:L286

Arguments

  • data::NDArray-or-SymbolicNode: Input data.
  • weight::NDArray-or-SymbolicNode: Weight matrix.
  • bias::NDArray-or-SymbolicNode: Bias parameter.
  • num_hidden::int, required: Number of hidden nodes of the output.
  • no_bias::boolean, optional, default=0: Whether to disable bias parameter.
  • flatten::boolean, optional, default=1: Whether to collapse all but the first axis of the input data tensor.

source

# MXNet.mx.GridGeneratorMethod.

GridGenerator(data, transform_type, target_shape)

Generates 2D sampling grid for bilinear sampling.

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the function.
  • transform_type::{'affine', 'warp'}, required: The type of transformation. For affine, input data should be an affine matrix of size (batch, 6). For warp, input data should be an optical flow of size (batch, 2, h, w).
  • target_shape::Shape(tuple), optional, default=[0,0]: Specifies the output shape (H, W). This is required if transformation type is affine. If transformation type is warp, this parameter is ignored.
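For transform_type='affine', each row of the (batch, 6) input can be read as a 2x3 matrix theta that maps every target pixel to a source sampling location. The sketch below covers one batch element and rests on assumptions not stated above: target coordinates are normalized to [-1, 1] along each axis, and theta multiplies the homogeneous vector (x, y, 1) with x running along the width. `affine_grid` is a hypothetical helper.

```python
def affine_grid(theta, height, width):
    """Return (xs, ys): source x and y coordinates for each target pixel (H x W lists)."""
    a, b, c, d, e, f = theta  # row-major 2x3 matrix [[a, b, c], [d, e, f]]

    def normalize(i, n):
        # Assumed convention: pixel index 0 -> -1.0 and index n - 1 -> +1.0.
        return -1.0 + 2.0 * i / (n - 1) if n > 1 else 0.0

    xs = [[0.0] * width for _ in range(height)]
    ys = [[0.0] * width for _ in range(height)]
    for r in range(height):
        y = normalize(r, height)
        for col in range(width):
            x = normalize(col, width)
            xs[r][col] = a * x + b * y + c
            ys[r][col] = d * x + e * y + f
    return xs, ys


# The identity transform reproduces the normalized target grid.
xs, ys = affine_grid([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], height=2, width=3)
```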

source

# MXNet.mx.GroupNormMethod.

GroupNorm(data, gamma, beta, num_groups, eps, output_mean_var)

Group normalization.

The input channels are separated into $num_groups$ groups, each containing $num_channels / num_groups$ channels. The mean and standard deviation are calculated separately over each group.

.. math::

data = data.reshape((N, num\_groups, C // num\_groups, ...))
out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis) + \epsilon}} * gamma + beta

Both $gamma$ and $beta$ are learnable parameters.
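The grouping and per-group normalization can be sketched in plain Python for a single sample. The population (biased) variance is an assumption here, and the $gamma$/$beta$ scale-and-shift step is left out so as not to presuppose their shape; `group_norm` is a hypothetical helper.

```python
import math


def group_norm(channels, num_groups, eps=1e-5):
    """Normalize one sample's channels group by group (scale/shift omitted).

    channels: list of C channels, each a flat list of spatial values.
    """
    num_channels = len(channels)
    if num_channels % num_groups:
        raise ValueError("num_channels must be divisible by num_groups")
    size = num_channels // num_groups  # channels per group
    out = []
    for g in range(num_groups):
        group = channels[g * size:(g + 1) * size]
        values = [v for ch in group for v in ch]  # statistics span the whole group
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        denom = math.sqrt(var + eps)
        out.extend([[(v - mean) / denom for v in ch] for ch in group])
    return out


# Two groups of one channel each: every channel is normalized on its own.
normalized = group_norm([[1.0, 3.0], [5.0, 7.0]], num_groups=2, eps=0.0)
```

With num_groups=1 the same input would instead be normalized with one mean and variance shared by both channels.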

Defined in src/operator/nn/group_norm.cc:L76

Arguments

  • data::NDArray-or-SymbolicNode: Input data
  • gamma::NDArray-or-SymbolicNode: gamma array
  • beta::NDArray-or-SymbolicNode: beta array
  • num_groups::int, optional, default='1': Total number of groups.
  • eps::float, optional, default=9.99999975e-06: An epsilon parameter to prevent division by 0.
  • output_mean_var::boolean, optional, default=0: Output the mean and std calculated along the given axis.

source

# MXNet.mx.IdentityAttachKLSparseRegMethod.

IdentityAttachKLSparseReg(data, sparseness_target, penalty, momentum)

Applies a sparse regularization to the output of a sigmoid activation function.

Arguments

  • data::NDArray-or-SymbolicNode: Input data.
  • sparseness_target::float, optional, default=0.100000001: The sparseness target
  • penalty::float, optional, default=0.00100000005: The tradeoff parameter for the sparseness penalty
  • momentum::float, optional, default=0.899999976: The momentum for running average

source

# MXNet.mx.InstanceNormMethod.

InstanceNorm(data, gamma, beta, eps)

Applies instance normalization to the n-dimensional input array.

This operator takes an n-dimensional input array (n > 2) and normalizes the input using the following formula:

.. math::

out = \frac{data - mean(data)}{\sqrt{var(data) + \epsilon}} * gamma + beta

This layer is similar to batch normalization layer (BatchNorm) with two differences: first, the normalization is carried out per example (instance), not over a batch. Second, the same normalization is applied both at test and train time. This operation is also known as contrast normalization.

If the input data is of shape [batch, channel, spatial_dim1, spatial_dim2, ...], gamma and beta must be vectors of shape [channel].

This implementation is based on this paper [1]_

.. [1] Instance Normalization: The Missing Ingredient for Fast Stylization, D. Ulyanov, A. Vedaldi, V. Lempitsky, 2016 (arXiv:1607.08022v2).

Examples::

    // Input of shape (2,1,2)
    x = [[[ 1.1,  2.2]],
         [[ 3.3,  4.4]]]

    // gamma parameter of length 1
    gamma = [1.5]

    // beta parameter of length 1
    beta = [0.5]

    // Instance normalization is calculated with the above formula
    InstanceNorm(x,gamma,beta) = [[[-0.997527  ,  1.99752665]],
                                  [[-0.99752653,  1.99752724]]]
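The example can be reproduced with a plain-Python sketch of the formula. It assumes the population (biased) variance, eps = 0.001 (the default listed under Arguments), and eps added inside the square root; under those assumptions the outputs agree with the printed values to about six decimal places. The function name is illustrative.

```python
import math


def instance_norm(x, gamma, beta, eps=1e-3):
    """x: batch x channel x flat spatial values; gamma and beta have one entry per channel."""
    out = []
    for sample in x:
        normed = []
        for c, values in enumerate(sample):
            # Statistics are computed per instance and per channel.
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            scale = gamma[c] / math.sqrt(var + eps)  # eps inside the square root
            normed.append([(v - mean) * scale + beta[c] for v in values])
        out.append(normed)
    return out


y = instance_norm([[[1.1, 2.2]], [[3.3, 4.4]]], gamma=[1.5], beta=[0.5])
```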

Defined in src/operator/instance_norm.cc:L94

Arguments

  • data::NDArray-or-SymbolicNode: An n-dimensional input array (n > 2) of the form [batch, channel, spatial_dim1, spatial_dim2, ...].
  • gamma::NDArray-or-SymbolicNode: A vector of length 'channel', which multiplies the normalized input.
  • beta::NDArray-or-SymbolicNode: A vector of length 'channel', which is added to the product of the normalized input and the weight.
  • eps::float, optional, default=0.00100000005: An epsilon parameter to prevent division by 0.

source

# MXNet.mx.L2NormalizationMethod.

L2Normalization(data, eps, mode)

Normalize the input array using the L2 norm.

For 1-D NDArray, it computes::

out = data / sqrt(sum(data ** 2) + eps)

For an N-D NDArray of shape (N, C, d_1, ..., d_k):

with $mode$ = $instance$, it normalizes each instance in the multidimensional array by its L2 norm::

for n in 0...N-1:
    out[n,:,...,:] = data[n,:,...,:] / sqrt(sum(data[n,:,...,:] ** 2) + eps)

with $mode$ = $channel$, it normalizes across the channel axis (axis 1) at each spatial position of each instance::

for each n and each spatial position (s_1, ..., s_k):
    out[n,:,s_1,...,s_k] = data[n,:,s_1,...,s_k] / sqrt(sum(data[n,:,s_1,...,s_k] ** 2) + eps)

with $mode$ = $spatial$, it normalizes across all spatial positions (axes 2 and above) for each channel of each instance::

for each n and c:
    out[n,c,:,...,:] = data[n,c,:,...,:] / sqrt(sum(data[n,c,:,...,:] ** 2) + eps)

Example::

    x = [[[1,2],
          [3,4]],
         [[2,2],
          [5,6]]]

    L2Normalization(x, mode='instance')
    =[[[ 0.18257418  0.36514837]
       [ 0.54772252  0.73029673]]
      [[ 0.24077171  0.24077171]
       [ 0.60192931  0.72231513]]]

    L2Normalization(x, mode='channel')
    =[[[ 0.31622776  0.44721359]
       [ 0.94868326  0.89442718]]
      [[ 0.37139067  0.31622776]
       [ 0.92847669  0.94868326]]]

    L2Normalization(x, mode='spatial')
    =[[[ 0.44721359  0.89442718]
       [ 0.60000002  0.80000001]]
      [[ 0.70710677  0.70710677]
       [ 0.6401844   0.76822126]]]
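The three modes differ only in which elements share one norm. A plain-Python sketch for a 3-D input of shape (N, C, W), with a single spatial axis for brevity (`l2_normalize` is a hypothetical helper, not the operator itself):

```python
import math


def l2_normalize(x, mode="instance", eps=1e-10):
    """x: nested list of shape (N, C, W). Returns a new list of the same shape."""
    n_batch, n_chan, width = len(x), len(x[0]), len(x[0][0])
    out = [[[0.0] * width for _ in range(n_chan)] for _ in range(n_batch)]
    for n in range(n_batch):
        if mode == "instance":   # one norm over the whole instance
            groups = [[(c, w) for c in range(n_chan) for w in range(width)]]
        elif mode == "channel":  # one norm per spatial position, across channels
            groups = [[(c, w) for c in range(n_chan)] for w in range(width)]
        elif mode == "spatial":  # one norm per channel, across spatial positions
            groups = [[(c, w) for w in range(width)] for c in range(n_chan)]
        else:
            raise ValueError(f"unknown mode {mode!r}")
        for group in groups:
            norm = math.sqrt(sum(x[n][c][w] ** 2 for c, w in group) + eps)
            for c, w in group:
                out[n][c][w] = x[n][c][w] / norm
    return out


x = [[[1, 2], [3, 4]], [[2, 2], [5, 6]]]
```

On this input, `l2_normalize(x, 'channel')[0][0]` is approximately `[0.3162, 0.4472]` and `l2_normalize(x, 'spatial')[0][0]` is approximately `[0.4472, 0.8944]`, in line with the example outputs.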

Defined in src/operator/l2_normalization.cc:L195

Arguments

  • data::NDArray-or-SymbolicNode: Input array to normalize.
  • eps::float, optional, default=1.00000001e-10: A small constant for numerical stability.
  • mode::{'channel', 'instance', 'spatial'},optional, default='instance': Specify the dimension along which to compute L2 norm.

source

# MXNet.mx.LRNMethod.

LRN(data, alpha, beta, knorm, nsize)

Applies local response normalization to the input.

The local response normalization layer performs "lateral inhibition" by normalizing over local input regions.

If a_{x,y}^{i} is the activity of a neuron computed by applying kernel i at position (x, y) and then applying the ReLU nonlinearity, the response-normalized activity b_{x,y}^{i} is given by the expression:

.. math:: b_{x,y}^{i} = \frac{a_{x,y}^{i}}{\Bigg({k + \frac{\alpha}{n} \sum_{j=\max(0, i-\frac{n}{2})}^{\min(N-1, i+\frac{n}{2})} (a_{x,y}^{j})^{2}}\Bigg)^{\beta}}

where the sum runs over n "adjacent" kernel maps at the same spatial position, and N is the total number of kernels in the layer.
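The formula for a single spatial position (x, y) can be sketched in plain Python. Taking the window half-width n/2 as integer division of nsize and reusing the defaults from the argument list are assumptions of this sketch; `lrn_across_channels` is a hypothetical helper.

```python
def lrn_across_channels(a, nsize, alpha=1e-4, beta=0.75, knorm=2.0):
    """a: activities of all N kernels (channels) at one spatial position."""
    num_kernels = len(a)
    half = nsize // 2
    out = []
    for i in range(num_kernels):
        # Window of "adjacent" kernel maps, clipped at both ends.
        lo, hi = max(0, i - half), min(num_kernels - 1, i + half)
        sq_sum = sum(a[j] ** 2 for j in range(lo, hi + 1))
        out.append(a[i] / (knorm + alpha / nsize * sq_sum) ** beta)
    return out


b = lrn_across_channels([1.0, 2.0, 3.0], nsize=3, alpha=3.0, beta=1.0, knorm=0.0)
```

With these toy parameters the denominators reduce to the plain windowed sums of squares (5, 14 and 13), which makes the windowing easy to check by hand.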

Defined in src/operator/nn/lrn.cc:L157

Arguments

  • data::NDArray-or-SymbolicNode: Input data to LRN
  • alpha::float, optional, default=9.99999975e-05: The variance scaling parameter alpha in the LRN expression.
  • beta::float, optional, default=0.75: The power parameter beta in the LRN expression.
  • knorm::float, optional, default=2: The parameter k in the LRN expression.
  • nsize::int (non-negative), required: normalization window width in elements.

source

# MXNet.mx.LayerNormMethod.

LayerNorm(data, gamma, beta, axis, eps, output_mean_var)

Layer normalization.

Normalizes the channels of the input tensor by mean and variance, and applies a scale $gamma$ as well as offset $beta$.

Assume the input has more than one dimension and we normalize along axis 1. We first compute the mean and variance along this axis and then compute the normalized output, which has the same shape as the input, as follows:

.. math::

out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis) + \epsilon}} * gamma + beta

Both $gamma$ and $beta$ are learnable parameters.

Unlike BatchNorm and InstanceNorm, the mean and var are computed along the channel dimension.

Assume the input has size k on axis 1; then both $gamma$ and $beta$ have shape (k,). If $output_mean_var$ is set to true, the operator also outputs $data_mean$ and $data_std$. Note that no gradient is passed through these two outputs.

The parameter $axis$ specifies which axis of the input shape denotes the 'channel' (separately normalized groups). The default is -1, which sets the channel axis to be the last item in the input shape.
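With the default axis=-1, each sample is normalized over its last axis, and gamma/beta hold one entry per position along that axis. A plain-Python sketch for a 2-D input, assuming the population variance (`layer_norm` is a hypothetical helper):

```python
import math


def layer_norm(rows, gamma, beta, eps=1e-5):
    """Normalize each row of a 2-D list over its last axis, then scale and shift."""
    out = []
    for row in rows:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        denom = math.sqrt(var + eps)
        out.append([(v - mean) / denom * g + b for v, g, b in zip(row, gamma, beta)])
    return out


result = layer_norm([[1.0, 3.0], [2.0, 6.0]], gamma=[2.0, 2.0], beta=[1.0, 1.0], eps=0.0)
```

Both rows normalize to [-1, 1] before the affine step, so `result` is `[[-1.0, 3.0], [-1.0, 3.0]]`.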

Defined in src/operator/nn/layer_norm.cc:L201

Arguments

  • data::NDArray-or-SymbolicNode: Input data to layer normalization
  • gamma::NDArray-or-SymbolicNode: gamma array
  • beta::NDArray-or-SymbolicNode: beta array
  • axis::int, optional, default='-1': The axis to perform layer normalization. Usually, this should be the axis of the channel dimension. Negative values mean indexing from right to left.
  • eps::float, optional, default=9.99999975e-06: An epsilon parameter to prevent division by 0.
  • output_mean_var::boolean, optional, default=0: Output the mean and std calculated along the given axis.

source

# MXNet.mx.LeakyReLUMethod.

LeakyReLU(data, gamma, act_type, slope, lower_bound, upper_bound)

Applies Leaky rectified linear unit activation element-wise to the input.

Leaky ReLUs attempt to fix the "dying ReLU" problem by allowing a small slope when the input is negative and has a slope of one when input is positive.

The following modified ReLU Activation functions are supported:

  • elu: Exponential Linear Unit. y = x > 0 ? x : slope * (exp(x)-1)
  • selu: Scaled Exponential Linear Unit. y = lambda * (x > 0 ? x : alpha * (exp(x) - 1)) where lambda = 1.0507009873554804934193349852946 and alpha = 1.6732632423543772848170429916717.
  • leaky: Leaky ReLU. y = x > 0 ? x : slope * x
  • prelu: Parametric ReLU. This is the same as leaky except that the slope is learnt during training.
  • rrelu: Randomized ReLU. The same as leaky, but the slope is chosen uniformly at random from [lower_bound, upper_bound) during training and fixed to (lower_bound + upper_bound)/2 for inference.
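The formulas listed above can be sketched as a scalar plain-Python function. For prelu, the `slope` argument stands in for the learnt parameter, 'gelu' is not covered, and the function name and `training` flag are illustrative assumptions.

```python
import math
import random

SELU_LAMBDA = 1.0507009873554804934193349852946
SELU_ALPHA = 1.6732632423543772848170429916717


def leaky_relu(x, act_type="leaky", slope=0.25, lower_bound=0.125,
               upper_bound=0.334, training=False, rng=None):
    """Scalar version of the activation formulas listed above."""
    if act_type in ("leaky", "prelu"):
        # For prelu, `slope` stands in for the learnt parameter.
        return x if x > 0 else slope * x
    if act_type == "elu":
        return x if x > 0 else slope * (math.exp(x) - 1.0)
    if act_type == "selu":
        return SELU_LAMBDA * (x if x > 0 else SELU_ALPHA * (math.exp(x) - 1.0))
    if act_type == "rrelu":
        if training:  # random slope drawn uniformly between the bounds
            s = (rng or random.Random()).uniform(lower_bound, upper_bound)
        else:         # fixed mean slope at inference
            s = (lower_bound + upper_bound) / 2.0
        return x if x > 0 else s * x
    raise ValueError(f"act_type {act_type!r} is not covered by this sketch")
```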

Defined in src/operator/leaky_relu.cc:L162

Arguments

  • data::NDArray-or-SymbolicNode: Input data to activation function.
  • gamma::NDArray-or-SymbolicNode: Slope parameter for prelu, learnt during training. Only used when act_type is 'prelu'.
  • act_type::{'elu', 'gelu', 'leaky', 'prelu', 'rrelu', 'selu'}, optional, default='leaky': Activation function to be applied.
  • slope::float, optional, default=0.25: Init slope for the activation. (For leaky and elu only)
  • lower_bound::float, optional, default=0.125: Lower bound of random slope. (For rrelu only)
  • upper_bound::float, optional, default=0.333999991: Upper bound of random slope. (For rrelu only)

source

# MXNet.mx.LinearRegressionOutputMethod.

LinearRegressionOutput(data, label, grad_scale)

Computes and optimizes for squared loss during backward propagation. Just outputs $data$ during forward propagation.

If \hat{y}_i is the predicted value of the i-th sample, and y_i is the corresponding target value, then the squared loss estimated over n samples is defined as

\text{SquaredLoss}(\textbf{Y}, \hat{\textbf{Y}} ) = \frac{1}{n} \sum_{i=0}^{n-1} \lVert \textbf{y}_i - \hat{\textbf{y}}_i \rVert_2^2

.. note:: Use the LinearRegressionOutput as the final output layer of a net.

The storage type of $label$ can be $default$ or $csr$

  • LinearRegressionOutput(default, default) = default
  • LinearRegressionOutput(default, csr) = default

By default, gradients of this loss function are scaled by a factor of 1/m, where m is the number of regression outputs of a training example. The parameter grad_scale can be used to change this scale to grad_scale/m.
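The forward/backward behaviour described above can be sketched for one training example in plain Python. The backward rule (prediction minus label, times grad_scale/m) is the gradient of half the squared error and is an assumption about the implementation, not something stated in the text; the helper name is hypothetical.

```python
def linear_regression_output(data, label, grad_scale=1.0):
    """Return (forward output, gradient w.r.t. data) for one example with m outputs."""
    m = len(data)
    forward = list(data)  # forward pass just outputs data
    grad = [(p - y) * grad_scale / m for p, y in zip(data, label)]
    return forward, grad


out, grad = linear_regression_output([1.0, 3.0], [0.0, 1.0])  # grad == [0.5, 1.0]
```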

Defined in src/operator/regression_output.cc:L92

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the function.
  • label::NDArray-or-SymbolicNode: Input label to the function.
  • grad_scale::float, optional, default=1: Scale the gradient by a float factor

source

# MXNet.mx.LogisticRegressionOutputMethod.

LogisticRegressionOutput(data, label, grad_scale)

Applies a logistic function to the input.

The logistic function, also known as the sigmoid function, is computed as \frac{1}{1+\exp(-\textbf{x})}.

Commonly, the sigmoid is used to squash the real-valued output of a linear model w^T x + b into the [0,1] range so that it can be interpreted as a probability. It is suitable for binary classification or probability prediction tasks.

.. note:: Use the LogisticRegressionOutput as the final output layer of a net.

The storage type of $label$ can be $default$ or $csr$

  • LogisticRegressionOutput(default, default) = default
  • LogisticRegressionOutput(default, csr) = default

The loss function used is the Binary Cross Entropy Loss:

-(y\log(p) + (1 - y)\log(1 - p))

where y is the ground-truth probability of a positive outcome for a given example, and p is the probability predicted by the model. By default, gradients of this loss function are scaled by a factor of 1/m, where m is the number of regression outputs of a training example. The parameter grad_scale can be used to change this scale to grad_scale/m.
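Composing binary cross entropy with the sigmoid gives a gradient of p - y with respect to the pre-sigmoid input. A plain-Python sketch for one example, assuming the same grad_scale/m scaling as above (the helper names are hypothetical):

```python
import math


def _sigmoid(x):
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)  # avoids overflow for large negative x
    return e / (1.0 + e)


def logistic_regression_output(data, label, grad_scale=1.0):
    """Return (sigmoid probabilities, gradient w.r.t. the pre-sigmoid input)."""
    m = len(data)
    probs = [_sigmoid(x) for x in data]
    # d/dx of binary cross entropy composed with the sigmoid simplifies to p - y.
    grad = [(p - y) * grad_scale / m for p, y in zip(probs, label)]
    return probs, grad


probs, grad = logistic_regression_output([0.0], [1.0])  # probs == [0.5], grad == [-0.5]
```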

Defined in src/operator/regression_output.cc:L152

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the function.
  • label::NDArray-or-SymbolicNode: Input label to the function.
  • grad_scale::float, optional, default=1: Scale the gradient by a float factor

source

# MXNet.mx.MAERegressionOutputMethod.

MAERegressionOutput(data, label, grad_scale)

Computes mean absolute error of the input.

MAE is a risk metric corresponding to the expected value of the absolute error.

If \hat{y}_i is the predicted value of the i-th sample, and y_i is the corresponding target value, then the mean absolute error (MAE) estimated over n samples is defined as

\text{MAE}(\textbf{Y}, \hat{\textbf{Y}} ) = \frac{1}{n} \sum_{i=0}^{n-1} \lVert \textbf{y}_i - \hat{\textbf{y}}_i \rVert_1

.. note:: Use the MAERegressionOutput as the final output layer of a net.

The storage type of $label$ can be $default$ or $csr$

  • MAERegressionOutput(default, default) = default
  • MAERegressionOutput(default, csr) = default

By default, gradients of this loss function are scaled by a factor of 1/m, where m is the number of regression outputs of a training example. The parameter grad_scale can be used to change this scale to grad_scale/m.
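The gradient of the absolute error is the sign of the residual. A plain-Python sketch of the backward rule for one example; the 1/m scaling mirrors the paragraph above, while using 0 as the gradient at a zero residual is an assumption (the helper name is hypothetical):

```python
def mae_regression_output(data, label, grad_scale=1.0):
    """Return (forward output, gradient w.r.t. data) for one example with m outputs."""
    m = len(data)

    def sign(v):
        return (v > 0) - (v < 0)  # -1, 0 or +1

    grad = [sign(p - y) * grad_scale / m for p, y in zip(data, label)]
    return list(data), grad


_, grad = mae_regression_output([2.0, 0.0, 1.0], [1.0, 1.0, 1.0])
```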

Defined in src/operator/regression_output.cc:L120

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the function.
  • label::NDArray-or-SymbolicNode: Input label to the function.
  • grad_scale::float, optional, default=1: Scale the gradient by a float factor

source

# MXNet.mx.MakeLossMethod.

MakeLoss(data, grad_scale, valid_thresh, normalization)

Make your own loss function in network construction.

This operator accepts a customized loss function symbol as a terminal loss and the symbol should be an operator with no backward dependency. The output of this function is the gradient of loss with respect to the input data.

For example, if you are making a cross entropy loss function, assume $out$ is the predicted output and $label$ is the true label; then the cross entropy can be defined as::

cross_entropy = -(label * log(out) + (1 - label) * log(1 - out))
loss = MakeLoss(cross_entropy)

$MakeLoss$ is needed when we create our own loss function or want to combine multiple loss functions. We may also want to stop the gradients of some variables from propagating backward; see $BlockGrad$ or $stop_gradient$ for more detail.

In addition, we can give a scale to the loss by setting $grad_scale$, so that the gradient of the loss will be rescaled in the backpropagation.

.. note:: This operator should be used as a Symbol instead of NDArray.

Defined in src/operator/make_loss.cc:L70

Arguments

  • data::NDArray-or-SymbolicNode: Input array.
  • grad_scale::float, optional, default=1: Gradient scale as a supplement to unary and binary operators
  • valid_thresh::float, optional, default=0: clip each element in the array to 0 when it is less than $valid_thresh$. This is used when $normalization$ is set to $'valid'$.
  • normalization::{'batch', 'null', 'valid'}, optional, default='null': If this is set to null, the output gradient will not be normalized. If this is set to batch, the output gradient will be divided by the batch size. If this is set to valid, the output gradient will be divided by the number of valid input elements.
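How the three normalization options rescale the gradient can be sketched in plain Python. This assumes, beyond what is stated above, that the loss is treated as a sum of its elements (so each element's raw gradient is 1) and that an element counts as valid when it is strictly greater than valid_thresh. `make_loss_grad` is a hypothetical helper.

```python
def make_loss_grad(data, grad_scale=1.0, normalization="null",
                   valid_thresh=0.0, batch_size=1):
    """Gradient a MakeLoss-style terminal node sends back to each element of `data`."""
    if normalization == "null":
        scale = grad_scale
    elif normalization == "batch":
        scale = grad_scale / batch_size
    elif normalization == "valid":
        # Assumption: elements above valid_thresh count as valid; at least 1 avoids 0-division.
        num_valid = max(1, sum(1 for v in data if v > valid_thresh))
        scale = grad_scale / num_valid
    else:
        raise ValueError(f"unknown normalization {normalization!r}")
    return [scale] * len(data)  # d(sum of loss elements)/d(element) is 1, then rescaled
```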

source

# MXNet.mx.PadMethod.

Pad(data, mode, pad_width, constant_value)

Pads an input array with a constant value or with edge values of the array.

.. note:: Pad is deprecated. Use pad instead.

.. note:: Current implementation only supports 4D and 5D input arrays with padding applied only on axes 1, 2 and 3. Expects axes 4 and 5 in pad_width to be zero.

This operation pads an input array with either a constant_value or edge values along each axis of the input array. The amount of padding is specified by pad_width.

pad_width is a tuple of integer padding widths for each axis of the format $(before_1, after_1, ... , before_N, after_N)$. The pad_width should be of length $2*N$ where $N$ is the number of dimensions of the array.

For dimension $N$ of the input array, $before_N$ and $after_N$ indicate how many values to add before and after the elements of the array along dimension $N$. The widths of the first two dimensions, $before_1$, $after_1$, $before_2$ and $after_2$, must be 0.

Example::

x = [[[[  1.   2.   3.]
       [  4.   5.   6.]]

      [[  7.   8.   9.]
       [ 10.  11.  12.]]]


     [[[ 11.  12.  13.]
       [ 14.  15.  16.]]

      [[ 17.  18.  19.]
       [ 20.  21.  22.]]]]

pad(x,mode="edge", pad_width=(0,0,0,0,1,1,1,1)) =

     [[[[  1.   1.   2.   3.   3.]
        [  1.   1.   2.   3.   3.]
        [  4.   4.   5.   6.   6.]
        [  4.   4.   5.   6.   6.]]

       [[  7.   7.   8.   9.   9.]
        [  7.   7.   8.   9.   9.]
        [ 10.  10.  11.  12.  12.]
        [ 10.  10.  11.  12.  12.]]]


      [[[ 11.  11.  12.  13.  13.]
        [ 11.  11.  12.  13.  13.]
        [ 14.  14.  15.  16.  16.]
        [ 14.  14.  15.  16.  16.]]

       [[ 17.  17.  18.  19.  19.]
        [ 17.  17.  18.  19.  19.]
        [ 20.  20.  21.  22.  22.]
        [ 20.  20.  21.  22.  22.]]]]

pad(x, mode="constant", constant_value=0, pad_width=(0,0,0,0,1,1,1,1)) =

     [[[[  0.   0.   0.   0.   0.]
        [  0.   1.   2.   3.   0.]
        [  0.   4.   5.   6.   0.]
        [  0.   0.   0.   0.   0.]]

       [[  0.   0.   0.   0.   0.]
        [  0.   7.   8.   9.   0.]
        [  0.  10.  11.  12.   0.]
        [  0.   0.   0.   0.   0.]]]


      [[[  0.   0.   0.   0.   0.]
        [  0.  11.  12.  13.   0.]
        [  0.  14.  15.  16.   0.]
        [  0.   0.   0.   0.   0.]]

       [[  0.   0.   0.   0.   0.]
        [  0.  17.  18.  19.   0.]
        [  0.  20.  21.  22.   0.]
        [  0.   0.   0.   0.   0.]]]]
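With pad_width=(0,0,0,0,1,1,1,1), only the last two axes are padded, so every (batch, channel) plane is handled independently. A plain-Python sketch of the 'constant' and 'edge' modes on one such plane ('reflect' is left out, and `pad2d` is a hypothetical helper):

```python
def pad2d(plane, before_h, after_h, before_w, after_w,
          mode="constant", constant_value=0.0):
    """Pad a 2-D list (the last two axes of the input) with 'constant' or 'edge' values."""
    height, width = len(plane), len(plane[0])

    def value_at(r, c):
        if mode == "edge":
            # Clamp out-of-range coordinates onto the nearest border element.
            r = min(max(r, 0), height - 1)
            c = min(max(c, 0), width - 1)
            return plane[r][c]
        if mode == "constant":
            inside = 0 <= r < height and 0 <= c < width
            return plane[r][c] if inside else constant_value
        raise ValueError(f"mode {mode!r} is not covered by this sketch")

    return [[value_at(r, c) for c in range(-before_w, width + after_w)]
            for r in range(-before_h, height + after_h)]


first = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # the first plane of x in the example
edge = pad2d(first, 1, 1, 1, 1, mode="edge")
const = pad2d(first, 1, 1, 1, 1, mode="constant")
```

Both results match the first 4x5 block of the corresponding outputs in the example.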

Defined in src/operator/pad.cc:L765

Arguments

  • data::NDArray-or-SymbolicNode: An n-dimensional input array.
  • mode::{'constant', 'edge', 'reflect'}, required: Padding type to use. "constant" pads with constant_value; "edge" pads using the edge values of the input array; "reflect" pads by reflecting values with respect to the edges.
  • pad_width::Shape(tuple), required: Widths of the padding regions applied to the edges of each axis. It is a tuple of integer padding widths for each axis of the format $(before_1, after_1, ... , before_N, after_N)$. It should be of length $2*N$ where $N$ is the number of dimensions of the array. This is equivalent to pad_width in numpy.pad, but flattened.
  • constant_value::double, optional, default=0: The value used for padding when mode is "constant".

source

# MXNet.mx.PoolingMethod.

Pooling(data, kernel, pool_type, global_pool, cudnn_off, pooling_convention, stride, pad, p_value, count_include_pad, layout)

Performs pooling on the input.

The shapes for 1-D pooling are

  • data and out: (batch_size, channel, width) (NCW layout) or (batch_size, width, channel) (NWC layout).

The shapes for 2-D pooling are

  • data and out: (batch_size, channel, height, width) (NCHW layout) or (batch_size, height, width, channel) (NHWC layout), with::

    out_height = f(height, kernel[0], pad[0], stride[0])
    out_width = f(width, kernel[1], pad[1], stride[1])

The definition of f depends on $pooling_convention$. The two conventions documented here are:

  • valid (default)::

    f(x, k, p, s) = floor((x+2*p-k)/s)+1

  • full, which is compatible with Caffe::

    f(x, k, p, s) = ceil((x+2*p-k)/s)+1
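The valid and full formulas differ only when (x+2*p-k) is not a multiple of s. A small plain-Python helper (hypothetical name; the 'same' convention is not covered) that evaluates them with integer arithmetic:

```python
def pooled_size(x, k, p, s, convention="valid"):
    """Output length along one axis for input length x, kernel k, padding p, stride s."""
    span = x + 2 * p - k
    if convention == "valid":
        return span // s + 1          # floor division
    if convention == "full":
        return -(-span // s) + 1      # integer ceiling division
    raise ValueError(f"convention {convention!r} is not covered by this sketch")


# Input width 5, kernel 2, no padding, stride 2.
print(pooled_size(5, 2, 0, 2, "valid"), pooled_size(5, 2, 0, 2, "full"))  # prints 2 3
```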

When $global_pool$ is set to true, global pooling is performed. It resets $kernel=(height, width)$ and sets the appropriate padding to 0.

Four pooling options are supported by $pool_type$:

  • avg: average pooling
  • max: max pooling
  • sum: sum pooling
  • lp: Lp pooling

For 3-D pooling, an additional depth dimension is added before height. Namely the input data and output will have shape (batch_size, channel, depth, height, width) (NCDHW layout) or (batch_size, depth, height, width, channel) (NDHWC layout).

Notes on Lp pooling:

Lp pooling was first introduced in this paper: https://arxiv.org/pdf/1204.3968.pdf. L-1 pooling is simply sum pooling, while L-inf pooling is simply max pooling. Lp pooling lies between these two; in practice, the most common value of p is 2.

For each window $X$, the mathematical expression for Lp pooling is:

f(X) = \sqrt[p]{\sum_{x \in X} x^p}
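Per window, the four pool_type options reduce to simple formulas. The sketch below applies the Lp expression literally (it raises x itself, not |x|, to the power p) and ignores padding, so count_include_pad plays no role; `pool_window` is a hypothetical helper.

```python
def pool_window(window, pool_type="max", p_value=2):
    """Reduce one pooling window (a flat list of values) to a single number."""
    if pool_type == "max":
        return max(window)
    if pool_type == "sum":
        return sum(window)
    if pool_type == "avg":
        return sum(window) / len(window)  # padding handling (count_include_pad) ignored
    if pool_type == "lp":
        return sum(v ** p_value for v in window) ** (1.0 / p_value)
    raise ValueError(f"unknown pool_type {pool_type!r}")


print(pool_window([3.0, 4.0], "lp", p_value=2))  # prints 5.0
```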

Defined in src/operator/nn/pooling.cc:L416

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the pooling operator.
  • kernel::Shape(tuple), optional, default=[]: Pooling kernel size: (y, x) or (d, y, x)
  • pool_type::{'avg', 'lp', 'max', 'sum'}, optional, default='max': Pooling type to be applied.
  • global_pool::boolean, optional, default=0: Ignore kernel size, do global pooling based on current input feature map.
  • cudnn_off::boolean, optional, default=0: Turn off cudnn pooling and use MXNet pooling operator.
  • pooling_convention::{'full', 'same', 'valid'}, optional, default='valid': Pooling convention to be applied.
  • stride::Shape(tuple), optional, default=[]: Stride: for pooling (y, x) or (d, y, x). Defaults to 1 for each dimension.
  • pad::Shape(tuple), optional, default=[]: Pad for pooling: (y, x) or (d, y, x). Defaults to no padding.
  • p_value::int or None, optional, default='None': Value of p for Lp pooling, can be 1 or 2, required for Lp Pooling.
  • count_include_pad::boolean or None, optional, default=None: Only used for AvgPool, specify whether to count padding elements for average calculation. For example, with a 5*5 kernel on a 3*3 corner of an image, the sum of the 9 valid elements will be divided by 25 if this is set to true, or by 9 if this is set to false. Defaults to true.
  • layout::{None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC', 'NWC'}, optional, default='None': Set layout for input and output. Empty for default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.

source

# MXNet.mx.Pooling_v1Method.

Pooling_v1(data, kernel, pool_type, global_pool, pooling_convention, stride, pad)

This operator is DEPRECATED. Perform pooling on the input.

The shapes for 2-D pooling are

  • data: (batch_size, channel, height, width)
  • out: (batch_size, num_filter, out_height, out_width), with::

    out_height = f(height, kernel[0], pad[0], stride[0])
    out_width = f(width, kernel[1], pad[1], stride[1])

The definition of f depends on $pooling_convention$, which has two options:

  • valid (default)::

    f(x, k, p, s) = floor((x+2*p-k)/s)+1

  • full, which is compatible with Caffe::

    f(x, k, p, s) = ceil((x+2*p-k)/s)+1

If $global_pool$ is set to true, global pooling is performed, namely $kernel$ is reset to $(height, width)$.

Three pooling options are supported by $pool_type$:

  • avg: average pooling
  • max: max pooling
  • sum: sum pooling

1-D pooling is a special case of 2-D pooling with width=1 and kernel[1]=1.

For 3-D pooling, an additional depth dimension is added before height. Namely the input data will have shape (batch_size, channel, depth, height, width).

Defined in src/operator/pooling_v1.cc:L103

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the pooling operator.
  • kernel::Shape(tuple), optional, default=[]: pooling kernel size: (y, x) or (d, y, x)
  • pool_type::{'avg', 'max', 'sum'}, optional, default='max': Pooling type to be applied.
  • global_pool::boolean, optional, default=0: Ignore kernel size, do global pooling based on current input feature map.
  • pooling_convention::{'full', 'valid'}, optional, default='valid': Pooling convention to be applied.
  • stride::Shape(tuple), optional, default=[]: stride: for pooling (y, x) or (d, y, x)
  • pad::Shape(tuple), optional, default=[]: pad for pooling: (y, x) or (d, y, x)

source

# MXNet.mx.RNNMethod.

RNN(data, parameters, state, state_cell, sequence_length, state_size, num_layers, bidirectional, mode, p, state_outputs, projection_size, lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan, use_sequence_length)

Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support.

When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use pseudo-float16 precision (float32 math with float16 I/O) in order to use Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.

Vanilla RNN

Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported: ReLU and Tanh.

With ReLU activation function:

.. math:: h_t = relu(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})

With Tanh activation function:

.. math:: h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})

Reference paper: Finding structure in time - Elman, 1988. https://crl.ucsd.edu/~elman/Papers/fsit.pdf

LSTM

Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf

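.. math::

\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}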

When $projection_size$ is set, LSTM can use the projection feature to reduce the number of parameters and gain some speedup without significantly hurting accuracy.

Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech Recognition - Sak et al. 2014. https://arxiv.org/abs/1402.1128

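.. math::

\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rg} r_{(t-1)} + b_{rg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ro} r_{(t-1)} + b_{ro}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t) \\
r_t = W_{hr} h_t
\end{array}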

GRU

Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078

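The definition of GRU here is slightly different from the paper but compatible with CUDNN.

.. math::

\begin{array}{ll}
r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)} + b_{hn})) \\
h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
\end{array}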
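One time step of the CUDNN-compatible GRU, in which the reset gate r_t multiplies (W_{hn} h_{(t-1)} + b_{hn}) rather than h_{(t-1)} itself, can be sketched in plain Python. Weights are passed as nested lists keyed by gate name; this layout is an illustrative assumption and is not MXNet's packed `parameters` vector.

```python
import math


def _matvec(mat, vec):
    return [sum(w * v for w, v in zip(row, vec)) for row in mat]


def _sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def gru_step(x, h_prev, w_i, w_h, b_i, b_h):
    """One GRU step. w_i/w_h/b_i/b_h are dicts keyed by gate: 'r', 'z', 'n'."""
    gi = {g: [a + b for a, b in zip(_matvec(w_i[g], x), b_i[g])] for g in "rzn"}
    gh = {g: [a + b for a, b in zip(_matvec(w_h[g], h_prev), b_h[g])] for g in "rzn"}
    r = [_sigmoid(a + b) for a, b in zip(gi["r"], gh["r"])]  # reset gate
    z = [_sigmoid(a + b) for a, b in zip(gi["z"], gh["z"])]  # update gate
    # The reset gate scales the already-projected hidden term (CUDNN convention).
    n = [math.tanh(a + rr * b) for a, rr, b in zip(gi["n"], r, gh["n"])]
    return [(1.0 - zz) * nn + zz * hp for zz, nn, hp in zip(z, n, h_prev)]


zeros = {g: [[0.0]] for g in "rzn"}
no_bias = {g: [0.0] for g in "rzn"}
h = gru_step([1.0], [1.0], zeros, zeros, no_bias, no_bias)  # [0.5]
```

With all weights and biases at zero, both gates sit at 0.5 and the candidate n_t is 0, so the new state is half the previous one.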

Defined in src/operator/rnn.cc:L375

Arguments

  • data::NDArray-or-SymbolicNode: Input data to RNN
  • parameters::NDArray-or-SymbolicNode: Vector of all RNN trainable parameters concatenated
  • state::NDArray-or-SymbolicNode: initial hidden state of the RNN
  • state_cell::NDArray-or-SymbolicNode: initial cell state for LSTM networks (only for LSTM)
  • sequence_length::NDArray-or-SymbolicNode: Vector of valid sequence lengths for each element in the batch. (Only used if the use_sequence_length kwarg is True.)
  • state_size::int (non-negative), required: size of the state for each layer
  • num_layers::int (non-negative), required: number of stacked layers
  • bidirectional::boolean, optional, default=0: whether to use bidirectional recurrent layers
  • mode::{'gru', 'lstm', 'rnn_relu', 'rnn_tanh'}, required: the type of RNN to compute
  • p::float, optional, default=0: drop rate of the dropout on the outputs of each RNN layer, except the last layer.
  • state_outputs::boolean, optional, default=0: Whether to have the states as symbol outputs.
  • projection_size::int or None, optional, default='None': size of the projection
  • lstm_state_clip_min::double or None, optional, default=None: Minimum clip value of LSTM states. This option must be used together with lstm_state_clip_max.
  • lstm_state_clip_max::double or None, optional, default=None: Maximum clip value of LSTM states. This option must be used together with lstm_state_clip_min.
  • lstm_state_clip_nan::boolean, optional, default=0: Whether to stop NaN from propagating in state by clipping it to min/max. If clipping range is not specified, this option is ignored.
  • use_sequence_length::boolean, optional, default=0: If set to true, this layer takes in an extra input parameter sequence_length to specify variable-length sequences

source

# MXNet.mx.ROIPoolingMethod.

ROIPooling(data, rois, pooled_size, spatial_scale)

Performs region of interest (ROI) pooling on the input array.

ROI pooling is a variant of a max pooling layer, in which the output size is fixed and region of interest is a parameter. Its purpose is to perform max pooling on the inputs of non-uniform sizes to obtain fixed-size feature maps. ROI pooling is a neural-net layer mostly used in training a Fast R-CNN network for object detection.

This operator takes a 4D feature map as an input array and region proposals as rois, then it pools over sub-regions of input and produces a fixed-sized output array regardless of the ROI size.

To crop the feature map accordingly, you can resize the bounding box coordinates by changing the parameters rois and spatial_scale.

The cropped feature maps are max-pooled to the fixed output size given by the pooled_size parameter. After ROIPooling, the batch dimension equals the number of region bounding boxes.

The size of each region of interest doesn't have to be perfectly divisible by the number of pooling sections (pooled_size).

Example::

x = [[[[  0.,   1.,   2.,   3.,   4.,   5.],
       [  6.,   7.,   8.,   9.,  10.,  11.],
       [ 12.,  13.,  14.,  15.,  16.,  17.],
       [ 18.,  19.,  20.,  21.,  22.,  23.],
       [ 24.,  25.,  26.,  27.,  28.,  29.],
       [ 30.,  31.,  32.,  33.,  34.,  35.],
       [ 36.,  37.,  38.,  39.,  40.,  41.],
       [ 42.,  43.,  44.,  45.,  46.,  47.]]]]

// region of interest i.e. bounding box coordinates.
y = [[0,0,0,4,4]]

// returns a (2,2) pooled map for the given roi using max pooling.
ROIPooling(x, y, (2,2), 1.0) = [[[[ 14.,  16.],
                                  [ 26.,  28.]]]]

// region of interest is changed due to the change in the spatial_scale parameter.
ROIPooling(x, y, (2,2), 0.7) = [[[[  7.,   9.],
                                  [ 19.,  21.]]]]
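Here is a plain-Python sketch of ROI max pooling for a single image with one channel. It follows the bin arithmetic of the classic Fast R-CNN ROI pooling kernel:

- ROI corners are scaled by spatial_scale and rounded.
- Bin edges use floor/ceil.
- Empty bins yield 0.

Treat that as an assumption about MXNet's kernel, not a copy of it. The function names are hypothetical. With the 8 x 6 feature map from the example it reproduces both outputs.

```python
import math

def round_half_away(v):
    # C-style round(); Python's round() uses banker's rounding on .5 ties.
    return math.floor(v + 0.5) if v >= 0 else math.ceil(v - 0.5)

def roi_pool(fmap, roi, pooled_size, spatial_scale):
    """Max-pool one ROI [batch_index, x1, y1, x2, y2] of a single-channel map fmap[h][w]."""
    _, x1, y1, x2, y2 = roi  # batch_index ignored: only one image here
    height, width = len(fmap), len(fmap[0])
    out_h, out_w = pooled_size
    start_w, start_h = round_half_away(x1 * spatial_scale), round_half_away(y1 * spatial_scale)
    end_w, end_h = round_half_away(x2 * spatial_scale), round_half_away(y2 * spatial_scale)
    bin_h = max(end_h - start_h + 1, 1) / out_h
    bin_w = max(end_w - start_w + 1, 1) / out_w
    out = []
    for ph in range(out_h):
        hs = min(max(math.floor(ph * bin_h) + start_h, 0), height)
        he = min(max(math.ceil((ph + 1) * bin_h) + start_h, 0), height)
        row = []
        for pw in range(out_w):
            ws = min(max(math.floor(pw * bin_w) + start_w, 0), width)
            we = min(max(math.ceil((pw + 1) * bin_w) + start_w, 0), width)
            cells = [fmap[r][c] for r in range(hs, he) for c in range(ws, we)]
            row.append(max(cells) if cells else 0.0)  # empty bin -> 0
        out.append(row)
    return out

fmap = [[float(6 * r + c) for c in range(6)] for r in range(8)]
pooled_1 = roi_pool(fmap, [0, 0, 0, 4, 4], (2, 2), 1.0)   # [[14.0, 16.0], [26.0, 28.0]]
pooled_07 = roi_pool(fmap, [0, 0, 0, 4, 4], (2, 2), 0.7)  # [[7.0, 9.0], [19.0, 21.0]]
```

With spatial_scale = 0.7 the corner (4, 4) maps to round(2.8) = 3. This shrinks the ROI to a 4 x 4 window, which is why the second result differs.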

Defined in src/operator/roi_pooling.cc:L224

Arguments

  • data::NDArray-or-SymbolicNode: The input array to the pooling operator, a 4D Feature maps
  • rois::NDArray-or-SymbolicNode: Bounding box coordinates, a 2D array of [[batch_index, x1, y1, x2, y2]], where (x1, y1) and (x2, y2) are the top left and bottom right corners of the designated region of interest. batch_index indicates the index of the corresponding image in the input array
  • pooled_size::Shape(tuple), required: ROI pooling output shape (h,w)
  • spatial_scale::float, required: Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal of total stride in convolutional layers

source

# MXNet.mx.SVMOutputMethod.

SVMOutput(data, label, margin, regularization_coefficient, use_linear)

Computes support vector machine based transformation of the input.

This tutorial demonstrates using SVM as output layer for classification instead of softmax: https://github.com/apache/mxnet/tree/v1.x/example/svm_mnist.

Arguments

  • data::NDArray-or-SymbolicNode: Input data for SVM transformation.
  • label::NDArray-or-SymbolicNode: Class label for the input data.
  • margin::float, optional, default=1: The loss function penalizes outputs that lie outside this margin. Default margin is 1.
  • regularization_coefficient::float, optional, default=1: Regularization parameter for the SVM. This balances the tradeoff between coefficient size and error.
  • use_linear::boolean, optional, default=0: Whether to use L1-SVM objective. L2-SVM objective is used by default.

source

# MXNet.mx.SequenceLastMethod.

SequenceLast(data, sequence_length, use_sequence_length, axis)

Takes the last element of a sequence.

This function takes an n-dimensional input array of the form [max_sequence_length, batch_size, other_feature_dims] and returns an (n-1)-dimensional array of the form [batch_size, other_feature_dims].

The parameter sequence_length is used to handle variable-length sequences. sequence_length should be an input array of positive ints of dimension [batch_size]. To use this parameter, set use_sequence_length to True; otherwise, each example in the batch is assumed to have the max sequence length.

.. note:: Alternatively, you can also use the take operator.

Example::

x = [[[  1.,   2.,   3.],
      [  4.,   5.,   6.],
      [  7.,   8.,   9.]],

     [[ 10.,  11.,  12.],
      [ 13.,  14.,  15.],
      [ 16.,  17.,  18.]],

     [[ 19.,  20.,  21.],
      [ 22.,  23.,  24.],
      [ 25.,  26.,  27.]]]

// returns the last sequence step when the sequence_length parameter is not used
SequenceLast(x) = [[ 19.,  20.,  21.],
                   [ 22.,  23.,  24.],
                   [ 25.,  26.,  27.]]

// sequence_length is used
SequenceLast(x, sequence_length=[1,1,1], use_sequence_length=True) =
    [[  1.,   2.,   3.],
     [  4.,   5.,   6.],
     [  7.,   8.,   9.]]

// sequence_length is used
SequenceLast(x, sequence_length=[1,2,3], use_sequence_length=True) =
    [[  1.,   2.,   3.],
     [ 13.,  14.,  15.],
     [ 25.,  26.,  27.]]
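The following plain-Python sketch captures the same semantics for axis=0. It assumes nested lists laid out as [max_sequence_length][batch_size][features]. The function name is hypothetical, and the sketch reproduces the example above.

```python
def sequence_last(x, sequence_length=None):
    """Pick step sequence_length[b] - 1 for every batch column b (axis 0 is time)."""
    if sequence_length is None:
        return x[-1]  # every example assumed to use the full length
    batch_size = len(x[0])
    return [x[int(sequence_length[b]) - 1][b] for b in range(batch_size)]

# The 3 x 3 x 3 input from the example: x[t][b] = [9t + 3b + 1, ..., 9t + 3b + 3].
x = [[[9.0 * t + 3 * b + k + 1 for k in range(3)] for b in range(3)] for t in range(3)]
last_full = sequence_last(x)
last_var = sequence_last(x, [1, 2, 3])
```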

Defined in src/operator/sequence_last.cc:L105

Arguments

  • data::NDArray-or-SymbolicNode: n-dimensional input array of the form [max_sequence_length, batch_size, other_feature_dims] where n>2
  • sequence_length::NDArray-or-SymbolicNode: vector of sequence lengths of the form [batch_size]
  • use_sequence_length::boolean, optional, default=0: If set to true, this layer takes in an extra input parameter sequence_length to specify variable length sequence
  • axis::int, optional, default='0': The sequence axis. Only values of 0 and 1 are currently supported.

source

# MXNet.mx.SequenceMaskMethod.

SequenceMask(data, sequence_length, use_sequence_length, value, axis)

Sets all elements outside the sequence to a constant value.

This function takes an n-dimensional input array of the form [max_sequence_length, batch_size, other_feature_dims] and returns an array of the same shape.

The parameter sequence_length is used to handle variable-length sequences. sequence_length should be an input array of positive ints of dimension [batch_size]. To use this parameter, set use_sequence_length to True; otherwise, each example in the batch is assumed to have the max sequence length and this operator works as the identity operator.

Example::

x = [[[  1.,   2.,   3.],
      [  4.,   5.,   6.]],

     [[  7.,   8.,   9.],
      [ 10.,  11.,  12.]],

     [[ 13.,  14.,  15.],
      [ 16.,  17.,  18.]]]

// Batch 1
B1 = [[  1.,   2.,   3.],
      [  7.,   8.,   9.],
      [ 13.,  14.,  15.]]

// Batch 2
B2 = [[  4.,   5.,   6.],
      [ 10.,  11.,  12.],
      [ 16.,  17.,  18.]]

// works as the identity operator when the sequence_length parameter is not used
SequenceMask(x) = [[[  1.,   2.,   3.],
                    [  4.,   5.,   6.]],

                   [[  7.,   8.,   9.],
                    [ 10.,  11.,  12.]],

                   [[ 13.,  14.,  15.],
                    [ 16.,  17.,  18.]]]

// sequence_length [1,1] means the first step of each batch will be kept
// and the other rows are masked with the default mask value = 0
SequenceMask(x, sequence_length=[1,1], use_sequence_length=True) =
    [[[  1.,   2.,   3.],
      [  4.,   5.,   6.]],

     [[  0.,   0.,   0.],
      [  0.,   0.,   0.]],

     [[  0.,   0.,   0.],
      [  0.,   0.,   0.]]]

// sequence_length [2,3] means 2 steps of batch B1 and 3 steps of batch B2 will be kept
// and the other rows are masked with value = 1
SequenceMask(x, sequence_length=[2,3], use_sequence_length=True, value=1) =
    [[[  1.,   2.,   3.],
      [  4.,   5.,   6.]],

     [[  7.,   8.,   9.],
      [ 10.,  11.,  12.]],

     [[  1.,   1.,   1.],
      [ 16.,  17.,  18.]]]
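Below is a plain-Python sketch of the masking rule for axis=0. It assumes nested lists of layout [time][batch][features] with 1-D features. The function name is hypothetical. Every step t with t >= sequence_length[b] is overwritten with `value`, and the input is left untouched.

```python
import copy

def sequence_mask(x, sequence_length=None, value=0.0):
    """Return a copy of x with steps t >= sequence_length[b] set to value."""
    out = copy.deepcopy(x)
    if sequence_length is None:
        return out  # identity when no lengths are supplied
    for t, step in enumerate(out):
        for b, length in enumerate(sequence_length):
            if t >= length:
                step[b] = [value] * len(step[b])
    return out

# The 3 x 2 x 3 input from the example: x[t][b] = [6t + 3b + 1, ..., 6t + 3b + 3].
x = [[[6.0 * t + 3 * b + k + 1 for k in range(3)] for b in range(2)] for t in range(3)]
masked = sequence_mask(x, [2, 3], value=1.0)
```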

Defined in src/operator/sequence_mask.cc:L185

Arguments

  • data::NDArray-or-SymbolicNode: n-dimensional input array of the form [max_sequence_length, batch_size, other_feature_dims] where n>2
  • sequence_length::NDArray-or-SymbolicNode: vector of sequence lengths of the form [batch_size]
  • use_sequence_length::boolean, optional, default=0: If set to true, this layer takes in an extra input parameter sequence_length to specify variable length sequence
  • value::float, optional, default=0: The value to be used as a mask.
  • axis::int, optional, default='0': The sequence axis. Only values of 0 and 1 are currently supported.

source

# MXNet.mx.SequenceReverseMethod.

SequenceReverse(data, sequence_length, use_sequence_length, axis)

Reverses the elements of each sequence.

This function takes an n-dimensional input array of the form [max_sequence_length, batch_size, other_feature_dims] and returns an array of the same shape.

The parameter sequence_length is used to handle variable-length sequences. sequence_length should be an input array of positive ints of dimension [batch_size]. To use this parameter, set use_sequence_length to True; otherwise, each example in the batch is assumed to have the max sequence length.

Example::

x = [[[  1.,   2.,   3.],
      [  4.,   5.,   6.]],

     [[  7.,   8.,   9.],
      [ 10.,  11.,  12.]],

     [[ 13.,  14.,  15.],
      [ 16.,  17.,  18.]]]

// Batch 1
B1 = [[  1.,   2.,   3.],
      [  7.,   8.,   9.],
      [ 13.,  14.,  15.]]

// Batch 2
B2 = [[  4.,   5.,   6.],
      [ 10.,  11.,  12.],
      [ 16.,  17.,  18.]]

// returns the reversed sequence when the sequence_length parameter is not used
SequenceReverse(x) = [[[ 13.,  14.,  15.],
                       [ 16.,  17.,  18.]],

                      [[  7.,   8.,   9.],
                       [ 10.,  11.,  12.]],

                      [[  1.,   2.,   3.],
                       [  4.,   5.,   6.]]]

// sequence_length [2,2] means 2 rows of
// both batch B1 and B2 will be reversed.
SequenceReverse(x, sequence_length=[2,2], use_sequence_length=True) =
    [[[  7.,   8.,   9.],
      [ 10.,  11.,  12.]],

     [[  1.,   2.,   3.],
      [  4.,   5.,   6.]],

     [[ 13.,  14.,  15.],
      [ 16.,  17.,  18.]]]

// sequence_length [2,3] means 2 rows of batch B1 and 3 rows of batch B2
// will be reversed.
SequenceReverse(x, sequence_length=[2,3], use_sequence_length=True) =
    [[[  7.,   8.,   9.],
      [ 16.,  17.,  18.]],

     [[  1.,   2.,   3.],
      [ 10.,  11.,  12.]],

     [[ 13.,  14.,  15.],
      [  4.,   5.,   6.]]]
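The following plain-Python sketch shows the reversal rule for axis=0 on nested lists of layout [time][batch][features]. The function name is hypothetical. For each batch column b, only the first sequence_length[b] steps are reversed, and later steps stay in place. The result's rows are shared with x rather than deep-copied.

```python
def sequence_reverse(x, sequence_length=None):
    """Reverse the first sequence_length[b] steps of each batch column b (axis 0 = time)."""
    max_len, batch_size = len(x), len(x[0])
    lengths = sequence_length if sequence_length is not None else [max_len] * batch_size
    out = [list(step) for step in x]  # new per-step lists; feature rows are shared with x
    for b, length in enumerate(lengths):
        length = int(length)
        for t in range(length):
            out[t][b] = x[length - 1 - t][b]
    return out

# The 3 x 2 x 3 input from the example: x[t][b] = [6t + 3b + 1, ..., 6t + 3b + 3].
x = [[[6.0 * t + 3 * b + k + 1 for k in range(3)] for b in range(2)] for t in range(3)]
reversed_23 = sequence_reverse(x, [2, 3])
```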

Defined in src/operator/sequence_reverse.cc:L121

Arguments

  • data::NDArray-or-SymbolicNode: n-dimensional input array of the form [max_sequence_length, batch_size, other_feature_dims] where n>2
  • sequence_length::NDArray-or-SymbolicNode: vector of sequence lengths of the form [batch_size]
  • use_sequence_length::boolean, optional, default=0: If set to true, this layer takes in an extra input parameter sequence_length to specify variable length sequence
  • axis::int, optional, default='0': The sequence axis. Only 0 is currently supported.

source

# MXNet.mx.SliceChannelMethod.

SliceChannel(data, num_outputs, axis, squeeze_axis)

Splits an array along a particular axis into multiple sub-arrays.

.. note:: $SliceChannel$ is deprecated. Use $split$ instead.

Note that num_outputs should evenly divide the length of the axis along which to split the array.

Example::

x = [[[ 1.]
      [ 2.]]
     [[ 3.]
      [ 4.]]
     [[ 5.]
      [ 6.]]]
x.shape = (3, 2, 1)

y = split(x, axis=1, num_outputs=2) // a list of 2 arrays with shape (3, 1, 1)
y = [[[ 1.]]
     [[ 3.]]
     [[ 5.]]]

    [[[ 2.]]
     [[ 4.]]
     [[ 6.]]]

y[0].shape = (3, 1, 1)

z = split(x, axis=0, num_outputs=3) // a list of 3 arrays with shape (1, 2, 1)
z = [[[ 1.]
      [ 2.]]]

    [[[ 3.]
      [ 4.]]]

    [[[ 5.]
      [ 6.]]]

z[0].shape = (1, 2, 1)

squeeze_axis=1 removes the axis with length 1 from the shapes of the output arrays. Note that setting squeeze_axis to $1$ removes a length-1 axis only along the axis on which the array is split. Also, squeeze_axis can be set to true only if $input.shape[axis] == num_outputs$.

Example::

z = split(x, axis=0, num_outputs=3, squeeze_axis=1) // a list of 3 arrays with shape (2, 1)
z = [[ 1.]
     [ 2.]]

    [[ 3.]
     [ 4.]]

    [[ 5.]
     [ 6.]]

z[0].shape = (2, 1)
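Here is a plain-Python sketch of the splitting rule on nested lists. The function name is hypothetical. Splitting along axis 0 slices the outer list into equal chunks. Splitting along a deeper axis splits every sub-array recursively and then regroups the pieces. squeeze_axis drops the length-1 axis, which is only valid when the axis length equals num_outputs.

```python
def split(x, num_outputs, axis=1, squeeze_axis=False):
    """Split nested list x into num_outputs equal parts along axis."""
    if axis == 0:
        if len(x) % num_outputs:
            raise ValueError("num_outputs must evenly divide the axis length")
        step = len(x) // num_outputs
        parts = [x[i * step:(i + 1) * step] for i in range(num_outputs)]
        if squeeze_axis:
            if step != 1:
                raise ValueError("squeeze_axis requires shape[axis] == num_outputs")
            parts = [p[0] for p in parts]  # drop the length-1 axis
        return parts
    # Split every sub-array along the next axis down, then regroup by output index.
    per_elem = [split(e, num_outputs, axis - 1, squeeze_axis) for e in x]
    return [[pieces[k] for pieces in per_elem] for k in range(num_outputs)]

x = [[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [6.0]]]  # shape (3, 2, 1)
y = split(x, 2, axis=1)                        # 2 arrays of shape (3, 1, 1)
z = split(x, 3, axis=0, squeeze_axis=True)     # 3 arrays of shape (2, 1)
```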

Defined in src/operator/slice_channel.cc:L106

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • num_outputs::int, required: Number of splits. Note that this should evenly divide the length of the axis.
  • axis::int, optional, default='1': Axis along which to split.
  • squeeze_axis::boolean, optional, default=0: If true, removes the axis with length 1 from the shapes of the output arrays. Note that setting squeeze_axis to $true$ removes a length-1 axis only along the axis on which the array is split. Also, squeeze_axis can be set to $true$ only if $input.shape[axis] == num_outputs$.

source

# MXNet.mx.SoftmaxActivationMethod.

SoftmaxActivation(data, mode)

Applies softmax activation to input. This is intended for internal layers.

.. note::

This operator has been deprecated, please use softmax.

If mode = $instance$, this operator will compute a softmax for each instance in the batch. This is the default mode.

If mode = $channel$, this operator will compute a k-class softmax at each position of each instance, where k = $num_channel$. This mode can only be used when the input array has at least 3 dimensions. This can be used for fully convolutional network, image segmentation, etc.

Example::

import mxnet as mx
input_array = mx.nd.array([[3., 0.5, -0.5, 2., 7.],
                           [2., -.4, 7., 3., 0.2]])
softmax_act = mx.nd.SoftmaxActivation(input_array)
print(softmax_act.asnumpy())

[[ 1.78322066e-02  1.46375655e-03  5.38485940e-04  6.56010211e-03  9.73605454e-01]
 [ 6.56221947e-03  5.95310994e-04  9.73919690e-01  1.78379621e-02  1.08472735e-03]]
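Below is a plain-Python sketch of the default instance mode, which computes one softmax per row of a 2-D batch. It subtracts the row maximum first, the usual trick for numerical stability. It should agree with the printed values above to the displayed precision. Channel mode is not covered, and the function name is hypothetical.

```python
import math

def softmax_instance(batch):
    """Softmax over each row of a 2-D batch ('instance' mode)."""
    out = []
    for row in batch:
        m = max(row)  # subtract the max so exp() cannot overflow
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

probs = softmax_instance([[3., 0.5, -0.5, 2., 7.], [2., -.4, 7., 3., 0.2]])
```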

Defined in src/operator/nn/softmax_activation.cc:L58

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
  • mode::{'channel', 'instance'}, optional, default='instance': Specifies how to compute the softmax. If set to $instance$, it computes softmax for each instance. If set to $channel$, it computes cross-channel softmax for each position of each instance.

source

# MXNet.mx.SoftmaxOutputMethod.

SoftmaxOutput(data, label, grad_scale, ignore_label, multi_output, use_ignore, preserve_shape, normalization, out_grad, smooth_alpha)

Computes the gradient of cross entropy loss with respect to softmax output.

  • This operator computes the gradient in two steps. The cross entropy loss does not actually need to be computed.

    • Applies softmax function on the input array.
    • Computes and returns the gradient of cross entropy loss w.r.t. the softmax output.

  • The softmax function, cross entropy loss and gradient are given by:

    • Softmax Function:

      .. math:: \text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    • Cross Entropy Function:

      .. math:: \text{CE}(\text{label}, \text{output}) = - \sum_i \text{label}_i \log(\text{output}_i)

    • The gradient of cross entropy loss w.r.t. softmax output:

      .. math:: \text{gradient} = \text{output} - \text{label}

  • During forward propagation, the softmax function is computed for each instance in the input array.

    For general N-D input arrays with shape :math:(d_1, d_2, \ldots, d_n), the size is :math:s = d_1 \cdot d_2 \cdots d_n. We can use the parameters preserve_shape and multi_output to specify the way to compute softmax:

    • By default, preserve_shape is $false$. This operator will reshape the input array into a 2-D array with shape :math:(d_1, \frac{s}{d_1}) and then compute the softmax function for each row in the reshaped array, and afterwards reshape it back to the original shape :math:(d_1, d_2, \ldots, d_n).
    • If preserve_shape is $true$, the softmax function will be computed along the last axis (axis = $-1$).
    • If multi_output is $true$, the softmax function will be computed along the second axis (axis = $1$).

  • During backward propagation, the gradient of cross-entropy loss w.r.t. the softmax output array is computed. The provided label can be a one-hot label array or a probability label array.

    • If the parameter use_ignore is $true$, ignore_label can specify input instances with a particular label to be ignored during backward propagation. This has no effect when the softmax output has the same shape as the label.

      Example::

      data = [[1,2,3,4],[2,2,2,2],[3,3,3,3],[4,4,4,4]]
      label = [1,0,2,3]
      ignore_label = 1
      SoftmaxOutput(data=data, label=label, multi_output=true, use_ignore=true, ignore_label=ignore_label)

      // forward softmax output
      [[ 0.0320586   0.08714432  0.23688284  0.64391428]
       [ 0.25        0.25        0.25        0.25      ]
       [ 0.25        0.25        0.25        0.25      ]
       [ 0.25        0.25        0.25        0.25      ]]

      // backward gradient output
      [[ 0.    0.    0.    0.  ]
       [-0.75  0.25  0.25  0.25]
       [ 0.25  0.25 -0.75  0.25]
       [ 0.25  0.25  0.25 -0.75]]

      // notice that the first row is all 0 because label[0] is 1, which is equal to ignore_label.

    • The parameter grad_scale can be used to rescale the gradient, which is often used to give each loss function different weights.
    • This operator also supports various ways to normalize the gradient via normalization. The normalization is applied if the softmax output has a different shape than the labels. The normalization mode can be set to one of the following:

      • $'null'$: do nothing.
      • $'batch'$: divide the gradient by the batch size.
      • $'valid'$: divide the gradient by the number of instances which are not ignored.
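The forward/backward behaviour described above can be sketched in plain Python for the 2-D case with integer class labels. This is a simplified model under stated assumptions, not the operator's implementation:

- One-hot targets are built from `label`.
- `use_ignore` zeroes whole rows.
- The three normalization modes only scale the result.

The function name is hypothetical. With the inputs from the example it reproduces the printed softmax and gradient rows.

```python
import math

def softmax_output_2d(data, label, use_ignore=False, ignore_label=-1,
                      grad_scale=1.0, normalization="null"):
    """Return (softmax probabilities, gradient w.r.t. the softmax input) for 2-D data."""
    probs = []
    for row in data:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        s = sum(exps)
        probs.append([e / s for e in exps])
    grads, valid = [], 0
    for p, lab in zip(probs, label):
        if use_ignore and lab == ignore_label:
            grads.append([0.0] * len(p))  # ignored instances contribute no gradient
            continue
        valid += 1
        # output - one_hot(label)
        grads.append([pk - (1.0 if k == int(lab) else 0.0) for k, pk in enumerate(p)])
    if normalization == "batch":
        scale = grad_scale / len(data)
    elif normalization == "valid":
        scale = grad_scale / max(valid, 1)
    else:  # 'null'
        scale = grad_scale
    return probs, [[g * scale for g in row] for row in grads]

data = [[1, 2, 3, 4], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]
label = [1, 0, 2, 3]
probs, grads = softmax_output_2d(data, label, use_ignore=True, ignore_label=1)
```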

Defined in src/operator/softmax_output.cc:L242

Arguments

  • data::NDArray-or-SymbolicNode: Input array.
  • label::NDArray-or-SymbolicNode: Ground truth label.
  • grad_scale::float, optional, default=1: Scales the gradient by a float factor.
  • ignore_label::float, optional, default=-1: The instances whose labels == ignore_label will be ignored during backward, if use_ignore is set to $true$.
  • multi_output::boolean, optional, default=0: If set to $true$, the softmax function will be computed along axis $1$. This is applied when the shape of input array differs from the shape of label array.
  • use_ignore::boolean, optional, default=0: If set to $true$, the ignore_label value will not contribute to the backward gradient.
  • preserve_shape::boolean, optional, default=0: If set to $true$, the softmax function will be computed along the last axis ($-1$).
  • normalization::{'batch', 'null', 'valid'}, optional, default='null': Normalizes the gradient.
  • out_grad::boolean, optional, default=0: Multiplies gradient with output gradient element-wise.
  • smooth_alpha::float, optional, default=0: Constant for computing a label-smoothed version of cross-entropy for the backwards pass. This constant gets subtracted from the one-hot encoding of the gold label and distributed uniformly to all other labels.

source

# MXNet.mx.SpatialTransformerMethod.

SpatialTransformer(data, loc, target_shape, transform_type, sampler_type, cudnn_off)

Applies a spatial transformer to input feature map.

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the SpatialTransformerOp.
  • loc::NDArray-or-SymbolicNode: localisation net, the output dim should be 6 when transform_type is affine. You should initialize the weight and bias with the identity transform.
  • target_shape::Shape(tuple), optional, default=[0,0]: output shape(h, w) of spatial transformer: (y, x)
  • transform_type::{'affine'}, required: transformation type
  • sampler_type::{'bilinear'}, required: sampling type
  • cudnn_off::boolean or None, optional, default=None: whether to turn cudnn off

source

# MXNet.mx.SwapAxisMethod.

SwapAxis(data, dim1, dim2)

Interchanges two axes of an array.

Examples::

x = [[1, 2, 3]]
swapaxes(x, 0, 1) = [[ 1],
                     [ 2],
                     [ 3]]

x = [[[ 0, 1],
      [ 2, 3]],
     [[ 4, 5],
      [ 6, 7]]]  // (2,2,2) array

swapaxes(x, 0, 2) = [[[ 0, 4],
                      [ 2, 6]],
                     [[ 1, 5],
                      [ 3, 7]]]
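The following plain-Python sketch shows what swapping two axes means for nested lists. Output element [i_0, ..., i_n] is read from the input at the same index with positions dim1 and dim2 exchanged. The helper names are hypothetical. Both examples above are reproduced.

```python
def shape_of(x):
    # Shape of a rectangular nested list.
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims

def swapaxes(x, dim1, dim2):
    """Interchange axes dim1 and dim2 of a rectangular nested list."""
    in_shape = shape_of(x)
    out_shape = list(in_shape)
    out_shape[dim1], out_shape[dim2] = in_shape[dim2], in_shape[dim1]

    def build(prefix):
        if len(prefix) == len(out_shape):
            src = list(prefix)
            src[dim1], src[dim2] = prefix[dim2], prefix[dim1]  # map output index to input index
            value = x
            for i in src:
                value = value[i]
            return value
        return [build(prefix + [i]) for i in range(out_shape[len(prefix)])]

    return build([])

col = swapaxes([[1, 2, 3]], 0, 1)
cube = swapaxes([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], 0, 2)
```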

Defined in src/operator/swapaxis.cc:L69

Arguments

  • data::NDArray-or-SymbolicNode: Input array.
  • dim1::int, optional, default='0': the first axis to be swapped.
  • dim2::int, optional, default='0': the second axis to be swapped.

source

# MXNet.mx.UpSamplingMethod.

UpSampling(data, scale, num_filter, sample_type, multi_input_mode, num_args, workspace)

Note: UpSampling takes a variable number of positional inputs. So instead of calling it as UpSampling([x, y, z], num_args=3), one should call it via UpSampling(x, y, z), and num_args will be determined automatically.

Upsamples the given input data.

Two algorithms ($sample_type$) are available for upsampling:

  • Nearest Neighbor
  • Bilinear

Nearest Neighbor Upsampling

Input data is expected to be NCHW.

Example::

x = [[[[1. 1. 1.]
       [1. 1. 1.]
       [1. 1. 1.]]]]

UpSampling(x, scale=2, sample_type='nearest') = [[[[1. 1. 1. 1. 1. 1.]
                                                   [1. 1. 1. 1. 1. 1.]
                                                   [1. 1. 1. 1. 1. 1.]
                                                   [1. 1. 1. 1. 1. 1.]
                                                   [1. 1. 1. 1. 1. 1.]
                                                   [1. 1. 1. 1. 1. 1.]]]]
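Nearest-neighbour upsampling replicates each input pixel into a scale x scale block. Here is a plain-Python sketch for a single H x W channel, with batch and channel loops omitted. The function name is hypothetical. The all-ones example above is too uniform to show the replication, so the usage also upsamples a 2 x 2 input.

```python
def upsample_nearest(img, scale):
    """Nearest-neighbour upsampling of one H x W channel by an integer factor."""
    height, width = len(img), len(img[0])
    return [[img[r // scale][c // scale] for c in range(width * scale)]
            for r in range(height * scale)]

ones_up = upsample_nearest([[1.0] * 3 for _ in range(3)], 2)
small_up = upsample_nearest([[1, 2], [3, 4]], 2)
```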

Bilinear Upsampling

Uses deconvolution under the hood. You need to provide both the input data and the kernel.

Input data is expected to be NCHW.

num_filter is expected to be the same as the number of channels.

Example::

x = [[[[1. 1. 1.]
       [1. 1. 1.]
       [1. 1. 1.]]]]

w = [[[[1. 1. 1. 1.]
       [1. 1. 1. 1.]
       [1. 1. 1. 1.]
       [1. 1. 1. 1.]]]]

UpSampling(x, w, scale=2, sample_type='bilinear', num_filter=1) = [[[[1. 2. 2. 2. 2. 1.]
                                                                     [2. 4. 4. 4. 4. 2.]
                                                                     [2. 4. 4. 4. 4. 2.]
                                                                     [2. 4. 4. 4. 4. 2.]
                                                                     [2. 4. 4. 4. 4. 2.]
                                                                     [1. 2. 2. 2. 2. 1.]]]]
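Bilinear mode is described as a deconvolution with a user-supplied kernel. Below is a single-channel transposed convolution in plain Python. The kernel size 2*scale - scale % 2, stride scale and padding ceil((scale - 1) / 2) are assumptions about how the operator configures its deconvolution. They were chosen because they reproduce the 3 x 3 to 6 x 6 example above. In practice the kernel is typically initialised with bilinear interpolation weights rather than ones. The function names are hypothetical.

```python
import math

def deconv2d(x, w, stride, pad):
    """Single-channel transposed convolution (no bias) on nested lists."""
    k = len(w)
    h_in, w_in = len(x), len(x[0])
    h_out = (h_in - 1) * stride - 2 * pad + k
    w_out = (w_in - 1) * stride - 2 * pad + k
    out = [[0.0] * w_out for _ in range(h_out)]
    for i in range(h_in):
        for j in range(w_in):
            for ki in range(k):
                for kj in range(k):
                    # Scatter input pixel (i, j) through the kernel onto the output.
                    r, c = i * stride + ki - pad, j * stride + kj - pad
                    if 0 <= r < h_out and 0 <= c < w_out:
                        out[r][c] += x[i][j] * w[ki][kj]
    return out

def upsample_bilinear(x, w, scale):
    # Assumed geometry: kernel 2*scale - scale % 2, stride scale, pad ceil((scale - 1) / 2).
    k = 2 * scale - scale % 2
    if len(w) != k:
        raise ValueError("weight must be a %d x %d kernel" % (k, k))
    return deconv2d(x, w, stride=scale, pad=math.ceil((scale - 1) / 2))

result = upsample_bilinear([[1.0] * 3 for _ in range(3)], [[1.0] * 4 for _ in range(4)], 2)
```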

Defined in src/operator/nn/upsampling.cc:L172

Arguments

  • data::NDArray-or-SymbolicNode[]: Array of tensors to upsample. For bilinear upsampling, there should be 2 inputs - 1 data and 1 weight.
  • scale::int, required: Up sampling scale
  • num_filter::int, optional, default='0': Input filter. Only used by the bilinear sample_type. Since bilinear upsampling uses deconvolution, num_filter is set to the number of channels.
  • sample_type::{'bilinear', 'nearest'}, required: upsampling method
  • multi_input_mode::{'concat', 'sum'}, optional, default='concat': How to handle multiple inputs. concat means concatenating upsampled images along the channel dimension. sum means adding all images together, only available for nearest neighbor upsampling.
  • num_args::int, required: Number of inputs to be upsampled. For nearest neighbor upsampling, this can be 1-N; the size of the output will be (scale*h_0, scale*w_0) and all other inputs will be upsampled to the same size. For bilinear upsampling this must be 2; 1 input and 1 weight.
  • workspace::long (non-negative), optional, default=512: Tmp workspace for deconvolution (MB)

source

# MXNet.mx._CachedOpMethod.

_CachedOp(data)

Arguments

  • data::NDArray-or-SymbolicNode[]: input data list

source

# MXNet.mx._CachedOpThreadSafeMethod.

_CachedOpThreadSafe(data)

Arguments

  • data::NDArray-or-SymbolicNode[]: input data list

source

# MXNet.mx._CrossDeviceCopyMethod.

_CrossDeviceCopy()

Special op to copy data across devices.

Arguments

source

# MXNet.mx._CustomFunctionMethod.

_CustomFunction()

Arguments

source

# MXNet.mx._DivMethod.

_Div(lhs, rhs)

_Div is an alias of elemwise_div.

Divides arguments element-wise.

The storage type of $elemwise_div$ output is always dense.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._DivScalarMethod.

_DivScalar(data, scalar, is_int)

_DivScalar is an alias of _div_scalar.

Divide an array with a scalar.

$_div_scalar$ only operates on data array of input if input is sparse.

For example, if an input of shape (100, 100) has only 2 non-zero elements, i.e. input.data = [5, 6], and scalar = nan, the result will be output.data = [nan, nan] instead of 10000 nans.
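The following plain-Python sketch shows why operating only on the stored values matters. It contrasts a hypothetical sparse layout (flat indices plus values, not MXNet's actual CSR or row_sparse structures) with a dense array when dividing by NaN. The sparse path touches only the two stored values. The dense path also turns every implicit zero into NaN, since 0.0 / nan is nan.

```python
import math

def div_scalar_sparse(indices, values, shape, scalar):
    """Divide only the explicitly stored values; implicit zeros are never touched."""
    return indices, [v / scalar for v in values], shape

def div_scalar_dense(flat, scalar):
    """Dense reference: every element, including the zeros, is divided."""
    return [v / scalar for v in flat]

# Hypothetical (flat index, value) storage for a 100 x 100 array with two non-zeros.
indices, values, shape = [3, 42], [5.0, 6.0], (100, 100)
_, sparse_out, _ = div_scalar_sparse(indices, values, shape, math.nan)

dense = [0.0] * (shape[0] * shape[1])
for i, v in zip(indices, values):
    dense[i] = v
dense_out = div_scalar_dense(dense, math.nan)
```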

Defined in src/operator/tensor/elemwise_binary_scalar_op_basic.cc:L174

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._EqualMethod.

_Equal(lhs, rhs)

_Equal is an alias of _equal.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._EqualScalarMethod.

_EqualScalar(data, scalar, is_int)

_EqualScalar is an alias of _equal_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._GreaterMethod.

_Greater(lhs, rhs)

_Greater is an alias of _greater.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._GreaterEqualScalarMethod.

_GreaterEqualScalar(data, scalar, is_int)

_GreaterEqualScalar is an alias of _greater_equal_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._GreaterScalarMethod.

_GreaterScalar(data, scalar, is_int)

_GreaterScalar is an alias of _greater_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._Greater_EqualMethod.

_Greater_Equal(lhs, rhs)

_Greater_Equal is an alias of _greater_equal.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._HypotMethod.

_Hypot(lhs, rhs)

_Hypot is an alias of _hypot.

Given the "legs" of a right triangle, return its hypotenuse.

Defined in src/operator/tensor/elemwise_binary_op_extended.cc:L78

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._HypotScalarMethod.

_HypotScalar(data, scalar, is_int)

_HypotScalar is an alias of _hypot_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._LesserMethod.

_Lesser(lhs, rhs)

_Lesser is an alias of _lesser.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._LesserEqualScalarMethod.

_LesserEqualScalar(data, scalar, is_int)

_LesserEqualScalar is an alias of _lesser_equal_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._LesserScalarMethod.

_LesserScalar(data, scalar, is_int)

_LesserScalar is an alias of _lesser_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._Lesser_EqualMethod.

_Lesser_Equal(lhs, rhs)

_Lesser_Equal is an alias of _lesser_equal.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._LogicalAndScalarMethod.

_LogicalAndScalar(data, scalar, is_int)

_LogicalAndScalar is an alias of _logical_and_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._LogicalOrScalarMethod.

_LogicalOrScalar(data, scalar, is_int)

_LogicalOrScalar is an alias of _logical_or_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._LogicalXorScalarMethod.

_LogicalXorScalar(data, scalar, is_int)

_LogicalXorScalar is an alias of _logical_xor_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._Logical_AndMethod.

_Logical_And(lhs, rhs)

_Logical_And is an alias of _logical_and.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._Logical_OrMethod.

_Logical_Or(lhs, rhs)

_Logical_Or is an alias of _logical_or.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._Logical_XorMethod.

_Logical_Xor(lhs, rhs)

_Logical_Xor is an alias of _logical_xor.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._MaximumMethod.

_Maximum(lhs, rhs)

_Maximum is an alias of _maximum.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._MaximumScalarMethod.

_MaximumScalar(data, scalar, is_int)

_MaximumScalar is an alias of _maximum_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._MinimumMethod.

_Minimum(lhs, rhs)

_Minimum is an alias of _minimum.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._MinimumScalarMethod.

_MinimumScalar(data, scalar, is_int)

_MinimumScalar is an alias of _minimum_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._MinusScalarMethod.

_MinusScalar(data, scalar, is_int)

_MinusScalar is an alias of _minus_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._ModScalarMethod.

_ModScalar(data, scalar, is_int)

_ModScalar is an alias of _mod_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._MulMethod.

_Mul(lhs, rhs)

_Mul is an alias of elemwise_mul.

Multiplies arguments element-wise.

The storage type of $elemwise_mul$ output depends on storage types of inputs

  • elemwise_mul(default, default) = default
  • elemwise_mul(row_sparse, row_sparse) = row_sparse
  • elemwise_mul(default, row_sparse) = row_sparse
  • elemwise_mul(row_sparse, default) = row_sparse
  • elemwise_mul(csr, csr) = csr
  • otherwise, $elemwise_mul$ generates output with default storage
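The storage-type rules listed above amount to a small lookup. Here is a plain-Python sketch with storage types encoded as strings. That encoding is hypothetical and is not MXNet's internal stype enum.

```python
def elemwise_mul_storage(lhs: str, rhs: str) -> str:
    """Output storage type of elemwise_mul for the given input storage types."""
    if lhs == rhs and lhs in ("default", "row_sparse", "csr"):
        return lhs  # same-type inputs keep their storage type
    if {lhs, rhs} == {"default", "row_sparse"}:
        return "row_sparse"  # zeros in the row_sparse operand stay zero
    return "default"  # every other combination falls back to dense
```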

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._MulScalarMethod.

_MulScalar(data, scalar, is_int)

_MulScalar is an alias of _mul_scalar.

Multiplies an array by a scalar.

If the input is sparse, _mul_scalar operates only on the stored (non-zero) values of the input.

For example, if an input of shape (100, 100) has only 2 non-zero elements, i.e. input.data = [5, 6], and scalar = nan, the result is output.data = [nan, nan] instead of 10000 NaNs.
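Julia's SparseArrays standard library can reproduce the difference described above, as an analogy only: scaling just the stored values of a sparse matrix, versus scaling a dense copy in which every 0 * NaN becomes NaN. No MXNet types are involved; A, B, and D are illustrative names.

```julia
using SparseArrays

# A 100×100 matrix with only two stored (non-zero) values, 5 and 6.
A = sparse([1, 2], [1, 2], [5.0, 6.0], 100, 100)

# Sparse-style scalar multiply: touch only the stored values.
B = copy(A)
vals = nonzeros(B)   # the stored values of B (not a copy)
vals .*= NaN         # stored values become [NaN, NaN]; implicit zeros are untouched

# Dense-style scalar multiply: 0 * NaN is NaN, so all 10000 entries become NaN.
D = Matrix(A) .* NaN
```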

Defined in src/operator/tensor/elemwise_binary_scalar_op_basic.cc:L152

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._NDArrayMethod.

_NDArray(data, info)

Stub for implementing an operator in a native frontend language with NDArray.

Arguments

  • data::NDArray-or-SymbolicNode[]: Input data for the custom operator.
  • info::ptr, required:

source

# MXNet.mx._NativeMethod.

_Native(data, info, need_top_grad)

Stub for implementing an operator in a native frontend language.

Arguments

  • data::NDArray-or-SymbolicNode[]: Input data for the custom operator.
  • info::ptr, required:
  • need_top_grad::boolean, optional, default=1: Whether this layer needs the output gradient for the backward pass. Should be false for loss layers.

source

# MXNet.mx._NoGradientMethod.

_NoGradient()

Placeholder for a variable for which no gradient can be computed.

Arguments

source

# MXNet.mx._NotEqualScalarMethod.

_NotEqualScalar(data, scalar, is_int)

_NotEqualScalar is an alias of _not_equal_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._Not_EqualMethod.

_Not_Equal(lhs, rhs)

_Not_Equal is an alias of _not_equal.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._PlusScalarMethod.

_PlusScalar(data, scalar, is_int)

_PlusScalar is an alias of _plus_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._PowerMethod.

_Power(lhs, rhs)

_Power is an alias of _power.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._PowerScalarMethod.

_PowerScalar(data, scalar, is_int)

_PowerScalar is an alias of _power_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._RDivScalarMethod.

_RDivScalar(data, scalar, is_int)

_RDivScalar is an alias of _rdiv_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._RMinusScalarMethod.

_RMinusScalar(data, scalar, is_int)

_RMinusScalar is an alias of _rminus_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._RModScalarMethod.

_RModScalar(data, scalar, is_int)

_RModScalar is an alias of _rmod_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._RPowerScalarMethod.

_RPowerScalar(data, scalar, is_int)

_RPowerScalar is an alias of _rpower_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type
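The R-prefixed scalar operators (_RDivScalar, _RMinusScalar, _RModScalar, _RPowerScalar) reverse the operand order of their non-R counterparts, so the scalar is the left operand; for example, _rdiv_scalar computes scalar / data. The plain-Julia sketch below illustrates this reading with hypothetical function names; it does not call MXNet, and it uses Julia's mod on the assumption that it agrees with MXNet's modulo for positive operands.

```julia
# Hypothetical plain-Julia stand-ins (not MXNet functions) for the scalar
# operators above; `x` plays the role of `data`, `s` the role of `scalar`.
minus_scalar(x, s)  = x .- s        # data - scalar
rminus_scalar(x, s) = s .- x        # scalar - data
div_scalar(x, s)    = x ./ s        # data / scalar
rdiv_scalar(x, s)   = s ./ x        # scalar / data
power_scalar(x, s)  = x .^ s        # data ^ scalar
rpower_scalar(x, s) = s .^ x        # scalar ^ data
mod_scalar(x, s)    = mod.(x, s)    # data mod scalar
rmod_scalar(x, s)   = mod.(s, x)    # scalar mod data

x = [1.0, 2.0, 4.0]
println(rdiv_scalar(x, 8.0))        # [8.0, 4.0, 2.0]
```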

source

# MXNet.mx._adamw_updateMethod.

_adamw_update(weight, grad, mean, var, rescale_grad, lr, beta1, beta2, epsilon, wd, eta, clip_gradient)

Update function for the AdamW optimizer. AdamW can be seen as a modification of Adam that decouples weight decay from the optimization steps taken with respect to the loss function.

The Adam update consists of the following steps, where g represents the gradient and m, v are the 1st and 2nd order moment estimates (mean and variance).

$$
\begin{aligned}
g_t &= \nabla J(W_{t-1}) \\
m_t &= \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
W_t &= W_{t-1} - \eta_t \left( \alpha \frac{m_t}{\sqrt{v_t} + \epsilon} + wd \, W_{t-1} \right)
\end{aligned}
$$

It updates the weights using:

    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * (grad ** 2)
    w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)

Note that the gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, the update is skipped.

Defined in src/operator/contrib/adamw.cc:L100

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • mean::NDArray-or-SymbolicNode: Moving mean
  • var::NDArray-or-SymbolicNode: Moving variance
  • rescale_grad::NDArray-or-SymbolicNode: Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped.
  • lr::float, required: Learning rate
  • beta1::float, optional, default=0.899999976: The decay rate for the 1st moment estimates.
  • beta2::float, optional, default=0.999000013: The decay rate for the 2nd moment estimates.
  • epsilon::float, optional, default=9.99999994e-09: A small constant for numerical stability.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • eta::float, required: Learning rate schedule multiplier
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
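The update rule and arguments above can be sketched in plain Julia on ordinary arrays. This is not MXNet's kernel: adamw_step! is a hypothetical name, the keyword defaults mirror the argument list (as Float64), and applying rescaling before clipping is an assumption consistent with the notes above.

```julia
# Hypothetical sketch of one AdamW step on plain Julia arrays (not MXNet code).
# w, m, v are updated in place; g is the raw gradient and is not modified.
function adamw_step!(w, g, m, v; lr, eta, beta1=0.9, beta2=0.999,
                     epsilon=1e-8, wd=0.0, rescale_grad=1.0, clip_gradient=-1.0)
    # Skip the update entirely if rescale_grad is NaN, Inf, or 0.
    if !isfinite(rescale_grad) || rescale_grad == 0
        return w
    end
    g = rescale_grad .* g                       # grad = rescale_grad * grad
    if clip_gradient > 0                        # clipping is off when <= 0
        g = clamp.(g, -clip_gradient, clip_gradient)
    end
    @. m = beta1 * m + (1 - beta1) * g          # 1st moment estimate
    @. v = beta2 * v + (1 - beta2) * g^2        # 2nd moment estimate
    @. w -= eta * (lr * m / (sqrt(v) + epsilon) + wd * w)  # decoupled weight decay
    return w
end

w, m, v = [1.0], [0.0], [0.0]
adamw_step!(w, [0.5], m, v; lr=0.1, eta=1.0)
```

Because the wd * w term sits outside lr * m / (sqrt(v) + epsilon), weight decay is scaled only by eta and not by the adaptive step, which is the decoupling described at the top of this entry.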

source

# MXNet.mx._addMethod.

_add(lhs, rhs)

_add is an alias of elemwise_add.

Adds arguments element-wise.

The storage type of elemwise_add output depends on storage types of inputs:

  • elemwise_add(row_sparse, row_sparse) = row_sparse
  • elemwise_add(csr, csr) = csr
  • elemwise_add(default, csr) = default
  • elemwise_add(csr, default) = default
  • elemwise_add(default, rsp) = default
  • elemwise_add(rsp, default) = default
  • otherwise, elemwise_add generates output with default storage
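Unlike elemwise_mul, elemwise_add returns default storage when one input is dense and the other row_sparse. One way to see why (an illustration, not taken from the MXNet source): multiplying by an all-zero row keeps that row zero, while adding a dense row fills it in. The sketch uses ordinary Julia matrices as stand-ins for the storage types; nonzero_rows is a hypothetical helper.

```julia
# Ordinary dense matrices stand in for MXNet storage types (illustration only).
rsp_like = [0.0 0.0; 1.0 2.0; 0.0 0.0]   # only row 2 is non-zero, like row_sparse
dense    = [5.0 6.0; 7.0 8.0; 9.0 1.0]   # every row is non-zero

# Hypothetical helper: number of rows that contain at least one non-zero.
nonzero_rows(A) = count(r -> any(!iszero, r), eachrow(A))

product = rsp_like .* dense   # zero rows stay zero, so the result is still row-sparse
total   = rsp_like .+ dense   # zero rows pick up dense values, so the result is dense

println(nonzero_rows(product))  # 1
println(nonzero_rows(total))    # 3
```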

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._arangeMethod.

_arange(start, stop, step, repeat, infer_range, ctx, dtype)

Returns evenly spaced values within a given interval, similar to NumPy's arange.

Arguments

  • start::double, required: Start of interval. The interval includes this value. The default start value is 0.
  • stop::double or None, optional, default=None: End of interval. The interval does not include this value, except in some cases where step is not an integer and floating point round-off affects the length of out.
  • step::double, optional, default=1: Spacing between values.
  • repeat::int, optional, default='1': The number of times each element is repeated. E.g. with repeat=3, the element a is repeated three times -> a, a, a.
  • infer_range::boolean, optional, default=0: When set to True, infer the stop position from the start, step, repeat, and output tensor size.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Target data type.
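The start/stop/step/repeat semantics described above can be sketched in plain Julia. arange_sketch is a hypothetical name that does not call MXNet; it assumes stop is given explicitly, and it ignores infer_range, ctx, and dtype.

```julia
# Hypothetical plain-Julia sketch of _arange's documented semantics:
# start is included, stop is excluded, and each value is repeated `repeat` times.
function arange_sketch(start, stop, step=1.0; repeat::Int=1)
    n = max(0, ceil(Int, (stop - start) / step))   # number of distinct values
    vals = [start + i * step for i in 0:n-1]
    return [x for x in vals for _ in 1:repeat]     # a, a, a, b, b, b, ...
end

println(arange_sketch(0.0, 2.0; repeat=3))   # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```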

source

# MXNet.mx._backward_ActivationMethod.

_backward_Activation()

Arguments

source

# MXNet.mx._backward_BatchNormMethod.

_backward_BatchNorm()

Arguments

source

# MXNet.mx._backward_BatchNorm_v1Method.

_backward_BatchNorm_v1()

Arguments

source

# MXNet.mx._backward_BilinearSamplerMethod.

_backward_BilinearSampler()

Arguments

source

# MXNet.mx._backward_CachedOpMethod.

_backward_CachedOp()

Arguments

source

# MXNet.mx._backward_ConcatMethod.

_backward_Concat()

Arguments

source

# MXNet.mx._backward_ConvolutionMethod.

_backward_Convolution()

Arguments

source

# MXNet.mx._backward_Convolution_v1Method.

_backward_Convolution_v1()

Arguments

source

# MXNet.mx._backward_CorrelationMethod.

_backward_Correlation()

Arguments

source

# MXNet.mx._backward_CropMethod.

_backward_Crop()

Arguments

source

# MXNet.mx._backward_CustomMethod.

_backward_Custom()

Arguments

source

# MXNet.mx._backward_CustomFunctionMethod.

_backward_CustomFunction()

Arguments

source

# MXNet.mx._backward_DeconvolutionMethod.

_backward_Deconvolution()

Arguments

source

# MXNet.mx._backward_DropoutMethod.

_backward_Dropout()

Arguments

source

# MXNet.mx._backward_EmbeddingMethod.

_backward_Embedding()

Arguments

source

# MXNet.mx._backward_FullyConnectedMethod.

_backward_FullyConnected()

Arguments

source

# MXNet.mx._backward_GridGeneratorMethod.

_backward_GridGenerator()

Arguments

source

# MXNet.mx._backward_GroupNormMethod.

_backward_GroupNorm()

Arguments

source

# MXNet.mx._backward_IdentityAttachKLSparseRegMethod.

_backward_IdentityAttachKLSparseReg()

Arguments

source

# MXNet.mx._backward_InstanceNormMethod.

_backward_InstanceNorm()

Arguments

source

# MXNet.mx._backward_L2NormalizationMethod.

_backward_L2Normalization()

Arguments

source

# MXNet.mx._backward_LRNMethod.

_backward_LRN()

Arguments

source

# MXNet.mx._backward_LayerNormMethod.

_backward_LayerNorm()

Arguments

source

# MXNet.mx._backward_LeakyReLUMethod.

_backward_LeakyReLU()

Arguments

source

# MXNet.mx._backward_MakeLossMethod.

_backward_MakeLoss()

Arguments

source

# MXNet.mx._backward_PadMethod.

_backward_Pad()

Arguments

source

# MXNet.mx._backward_PoolingMethod.

_backward_Pooling()

Arguments

source

# MXNet.mx._backward_Pooling_v1Method.

_backward_Pooling_v1()

Arguments

source

# MXNet.mx._backward_RNNMethod.

_backward_RNN()

Arguments

source

# MXNet.mx._backward_ROIAlignMethod.

_backward_ROIAlign()

Arguments

source

# MXNet.mx._backward_ROIPoolingMethod.

_backward_ROIPooling()

Arguments

source

# MXNet.mx._backward_RROIAlignMethod.

_backward_RROIAlign()

Arguments

source

# MXNet.mx._backward_SVMOutputMethod.

_backward_SVMOutput()

Arguments

source

# MXNet.mx._backward_SequenceLastMethod.

_backward_SequenceLast()

Arguments

source

# MXNet.mx._backward_SequenceMaskMethod.

_backward_SequenceMask()

Arguments

source

# MXNet.mx._backward_SequenceReverseMethod.

_backward_SequenceReverse()

Arguments

source

# MXNet.mx._backward_SliceChannelMethod.

_backward_SliceChannel()

Arguments

source

# MXNet.mx._backward_SoftmaxActivationMethod.

_backward_SoftmaxActivation()

Arguments

source

# MXNet.mx._backward_SoftmaxOutputMethod.

_backward_SoftmaxOutput()

Arguments

source

# MXNet.mx._backward_SparseEmbeddingMethod.

_backward_SparseEmbedding()

Arguments

source

# MXNet.mx._backward_SpatialTransformerMethod.

_backward_SpatialTransformer()

Arguments

source

# MXNet.mx._backward_SwapAxisMethod.

_backward_SwapAxis()

Arguments

source

# MXNet.mx._backward_UpSamplingMethod.

_backward_UpSampling()

Arguments

source

# MXNet.mx._backward__CrossDeviceCopyMethod.

_backward__CrossDeviceCopy()

Arguments

source

# MXNet.mx._backward__NDArrayMethod.

_backward__NDArray()

Arguments

source

# MXNet.mx._backward__NativeMethod.

_backward__Native()

Arguments

source

# MXNet.mx._backward__contrib_DeformableConvolutionMethod.

_backward__contrib_DeformableConvolution()

Arguments

source

# MXNet.mx._backward__contrib_DeformablePSROIPoolingMethod.

_backward__contrib_DeformablePSROIPooling()

Arguments

source

# MXNet.mx._backward__contrib_ModulatedDeformableConvolutionMethod.

_backward__contrib_ModulatedDeformableConvolution()

Arguments

source

# MXNet.mx._backward__contrib_MultiBoxDetectionMethod.

_backward__contrib_MultiBoxDetection()

Arguments

source

# MXNet.mx._backward__contrib_MultiBoxPriorMethod.

_backward__contrib_MultiBoxPrior()

Arguments

source

# MXNet.mx._backward__contrib_MultiBoxTargetMethod.

_backward__contrib_MultiBoxTarget()

Arguments

source

# MXNet.mx._backward__contrib_MultiProposalMethod.

_backward__contrib_MultiProposal()

Arguments

source

# MXNet.mx._backward__contrib_PSROIPoolingMethod.

_backward__contrib_PSROIPooling()

Arguments

source

# MXNet.mx._backward__contrib_ProposalMethod.

_backward__contrib_Proposal()

Arguments

source

# MXNet.mx._backward__contrib_SyncBatchNormMethod.

_backward__contrib_SyncBatchNorm()

Arguments

source

# MXNet.mx._backward__contrib_count_sketchMethod.

_backward__contrib_count_sketch()

Arguments

source

# MXNet.mx._backward__contrib_fftMethod.

_backward__contrib_fft()

Arguments

source

# MXNet.mx._backward__contrib_ifftMethod.

_backward__contrib_ifft()

Arguments

source

# MXNet.mx._backward_absMethod.

_backward_abs(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_addMethod.

_backward_add()

Arguments

source

# MXNet.mx._backward_amp_castMethod.

_backward_amp_cast()

Arguments

source

# MXNet.mx._backward_amp_multicastMethod.

_backward_amp_multicast(grad, num_outputs, cast_narrow)

Arguments

  • grad::NDArray-or-SymbolicNode[]: Gradients
  • num_outputs::int, required: Number of input/output pairs to be cast to the widest type.
  • cast_narrow::boolean, optional, default=0: Whether to cast to the narrowest type
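num_outputs and cast_narrow describe the casting rule of the corresponding forward amp_multicast operator: all inputs are cast to the widest floating-point type among them, or to the narrowest when cast_narrow is true. The plain-Julia sketch below shows only that rule, using promote_type to pick the widest type; multicast_sketch is a hypothetical name, and no gradient is computed.

```julia
# Hypothetical sketch of the multicast casting rule on plain Julia arrays.
# It only casts; it computes no gradients and does not call MXNet.
function multicast_sketch(xs::AbstractArray{<:AbstractFloat}...; cast_narrow::Bool=false)
    ts = collect(map(eltype, xs))          # element type of each input
    target = if cast_narrow
        ts[argmin(sizeof.(ts))]            # narrowest: fewest bytes per element
    else
        promote_type(ts...)                # widest: Julia's type promotion
    end
    return map(x -> convert.(target, x), xs)   # one output per input
end

a, b = Float16[1, 2], Float32[3, 4]
println(map(eltype, multicast_sketch(a, b)))   # (Float32, Float32)
```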

source

# MXNet.mx._backward_arccosMethod.

_backward_arccos(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_arccoshMethod.

_backward_arccosh(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_arcsinMethod.

_backward_arcsin(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_arcsinhMethod.

_backward_arcsinh(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_arctanMethod.

_backward_arctan(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_arctanhMethod.

_backward_arctanh(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_backward_FullyConnectedMethod.

_backward_backward_FullyConnected()

Arguments

source

# MXNet.mx._backward_broadcast_addMethod.

_backward_broadcast_add()

Arguments

source

# MXNet.mx._backward_broadcast_divMethod.

_backward_broadcast_div()

Arguments

source

# MXNet.mx._backward_broadcast_exponentialMethod.

_backward_broadcast_exponential(scale, size, ctx)

Arguments

  • scale::float or None, optional, default=1:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._backward_broadcast_gumbelMethod.

_backward_broadcast_gumbel(loc, scale, size, ctx)

Arguments

  • loc::float or None, required:
  • scale::float or None, required:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._backward_broadcast_hypotMethod.

_backward_broadcast_hypot()

Arguments

source

# MXNet.mx._backward_broadcast_logisticMethod.

_backward_broadcast_logistic(loc, scale, size, ctx)

Arguments

  • loc::float or None, required:
  • scale::float or None, required:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._backward_broadcast_maximumMethod.

_backward_broadcast_maximum()

Arguments

source

# MXNet.mx._backward_broadcast_minimumMethod.

_backward_broadcast_minimum()

Arguments

source

# MXNet.mx._backward_broadcast_modMethod.

_backward_broadcast_mod()

Arguments

source

# MXNet.mx._backward_broadcast_mulMethod.

_backward_broadcast_mul()

Arguments

source

# MXNet.mx._backward_broadcast_normalMethod.

_backward_broadcast_normal(loc, scale, size, ctx, dtype)

Arguments

  • loc::float or None, required:
  • scale::float or None, required:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'float16', 'float32', 'float64'}, optional, default='float32': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx._backward_broadcast_paretoMethod.

_backward_broadcast_pareto(a, size, ctx)

Arguments

  • a::float or None, optional, default=None:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._backward_broadcast_powerMethod.

_backward_broadcast_power()

Arguments

source

# MXNet.mx._backward_broadcast_rayleighMethod.

_backward_broadcast_rayleigh(scale, size, ctx)

Arguments

  • scale::float or None, optional, default=1:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._backward_broadcast_subMethod.

_backward_broadcast_sub()

Arguments

source

# MXNet.mx._backward_broadcast_weibullMethod.

_backward_broadcast_weibull(a, size, ctx)

Arguments

  • a::float or None, optional, default=None:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._backward_castMethod.

_backward_cast()

Arguments

source

# MXNet.mx._backward_cbrtMethod.

_backward_cbrt(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_clipMethod.

_backward_clip()

Arguments

source

# MXNet.mx._backward_col2imMethod.

_backward_col2im()

Arguments

source

# MXNet.mx._backward_condMethod.

_backward_cond()

Arguments

source

# MXNet.mx._backward_contrib_AdaptiveAvgPooling2DMethod.

_backward_contrib_AdaptiveAvgPooling2D()

Arguments

source

# MXNet.mx._backward_contrib_BatchNormWithReLUMethod.

_backward_contrib_BatchNormWithReLU()

Arguments

source

# MXNet.mx._backward_contrib_BilinearResize2DMethod.

_backward_contrib_BilinearResize2D()

Arguments

source

# MXNet.mx._backward_contrib_bipartite_matchingMethod.

_backward_contrib_bipartite_matching(is_ascend, threshold, topk)

Arguments

  • is_ascend::boolean, optional, default=0: Use ascend order for scores instead of descending. Please set threshold accordingly.
  • threshold::float, required: Ignore matching when score < thresh if is_ascend=false, or when score > thresh if is_ascend=true.
  • topk::int, optional, default='-1': Limit the number of matches to topk, set -1 for no limit

source

# MXNet.mx._backward_contrib_boolean_maskMethod.

_backward_contrib_boolean_mask(axis)

Arguments

  • axis::int, optional, default='0': An integer that represents the axis in NDArray to mask from.

source

# MXNet.mx._backward_contrib_box_iouMethod.

_backward_contrib_box_iou(format)

Arguments

  • format::{'center', 'corner'}, optional, default='corner': The box encoding type. "corner" means boxes are encoded as [xmin, ymin, xmax, ymax]; "center" means boxes are encoded as [x, y, width, height].

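The two encodings can be converted into each other, and the intersection-over-union (IoU) of two boxes is conveniently computed in corner format. A plain-Julia sketch of both, where center_to_corner and iou are hypothetical helpers (not MXNet functions) and (x, y) in the center format is taken to be the box center.

```julia
# Hypothetical helpers (not MXNet functions). Boxes are 4-element vectors.
# "center": [x, y, width, height]  ->  "corner": [xmin, ymin, xmax, ymax]
center_to_corner(b) = [b[1] - b[3] / 2, b[2] - b[4] / 2, b[1] + b[3] / 2, b[2] + b[4] / 2]

# Intersection-over-union of two boxes given in "corner" format.
function iou(a, b)
    w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))   # overlap width
    h = max(0.0, min(a[4], b[4]) - max(a[2], b[2]))   # overlap height
    inter = w * h
    area(r) = (r[3] - r[1]) * (r[4] - r[2])
    uni = area(a) + area(b) - inter
    return uni > 0 ? inter / uni : 0.0
end

println(iou([0.0, 0.0, 2.0, 2.0], [1.0, 1.0, 3.0, 3.0]))   # 1/7 ≈ 0.142857
```
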
source

# MXNet.mx._backward_contrib_box_nmsMethod.

_backward_contrib_box_nms(overlap_thresh, valid_thresh, topk, coord_start, score_index, id_index, background_id, force_suppress, in_format, out_format)

Arguments

  • overlap_thresh::float, optional, default=0.5: Overlap (IoU) threshold for suppressing objects with smaller scores.
  • valid_thresh::float, optional, default=0: Filter input boxes to those whose scores are greater than valid_thresh.
  • topk::int, optional, default='-1': Apply NMS to the topk boxes with the highest scores; -1 means no restriction.
  • coord_start::int, optional, default='2': Start index of the 4 consecutive coordinates.
  • score_index::int, optional, default='1': Index of the scores/confidence of boxes.
  • id_index::int, optional, default='-1': Optional, index of the class categories, -1 to disable.
  • background_id::int, optional, default='-1': Optional, id of the background class, which is ignored in NMS.
  • force_suppress::boolean, optional, default=0: Optional, if set false and id_index is provided, NMS is applied only among boxes belonging to the same category.
  • in_format::{'center', 'corner'}, optional, default='corner': The input box encoding type. "corner" means boxes are encoded as [xmin, ymin, xmax, ymax]; "center" means boxes are encoded as [x, y, width, height].
  • out_format::{'center', 'corner'}, optional, default='corner': The output box encoding type. "corner" means boxes are encoded as [xmin, ymin, xmax, ymax]; "center" means boxes are encoded as [x, y, width, height].

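The arguments above configure the forward box_nms operator. A greedy non-maximum suppression sketch in plain Julia shows how overlap_thresh, valid_thresh, and topk interact. It is not MXNet's implementation: nms_sketch and iou are hypothetical names, boxes are assumed to be in corner format, boxes beyond topk are simply dropped, and id_index, background_id, and force_suppress are ignored (a single class is assumed).

```julia
# Hypothetical greedy NMS sketch (not MXNet's implementation).
# IoU of two boxes in "corner" format [xmin, ymin, xmax, ymax].
function iou(a, b)
    w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    h = max(0.0, min(a[4], b[4]) - max(a[2], b[2]))
    inter = w * h
    uni = (a[3] - a[1]) * (a[4] - a[2]) + (b[3] - b[1]) * (b[4] - b[2]) - inter
    return uni > 0 ? inter / uni : 0.0
end

# Returns the indices of the kept boxes, highest score first.
function nms_sketch(boxes, scores; overlap_thresh=0.5, valid_thresh=0.0, topk=-1)
    # Keep only boxes scoring above valid_thresh, sorted by descending score.
    order = [i for i in sortperm(scores; rev=true) if scores[i] > valid_thresh]
    topk > 0 && (order = order[1:min(topk, length(order))])
    keep = Int[]
    for i in order
        # Keep box i unless it overlaps an already-kept (higher-scoring) box too much.
        if all(j -> iou(boxes[i], boxes[j]) <= overlap_thresh, keep)
            push!(keep, i)
        end
    end
    return keep
end

boxes  = [[0.0, 0.0, 2.0, 2.0], [0.1, 0.1, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0]]
scores = [0.9, 0.8, 0.7]
println(nms_sketch(boxes, scores))   # [1, 3]
```
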
source

# MXNet.mx._backward_copyMethod.

_backward_copy()

Arguments

source

# MXNet.mx._backward_cosMethod.

_backward_cos(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_coshMethod.

_backward_cosh(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_ctc_lossMethod.

_backward_ctc_loss()

Arguments

source

# MXNet.mx._backward_degreesMethod.

_backward_degrees(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_diagMethod.

_backward_diag()

Arguments

source

# MXNet.mx._backward_divMethod.

_backward_div()

Arguments

source

# MXNet.mx._backward_div_scalarMethod.

_backward_div_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_dotMethod.

_backward_dot(transpose_a, transpose_b, forward_stype)

Arguments

  • transpose_a::boolean, optional, default=0: If true then transpose the first input before dot.
  • transpose_b::boolean, optional, default=0: If true then transpose the second input before dot.
  • forward_stype::{None, 'csr', 'default', 'row_sparse'}, optional, default='None': The desired storage type of the forward output given by the user. If the combination of input storage types and this hint does not match any implemented ones, the dot operator will perform a fallback operation and still produce an output of the desired storage type.

source

# MXNet.mx._backward_erfMethod.

_backward_erf(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_erfinvMethod.

_backward_erfinv(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_expm1Method.

_backward_expm1(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_foreachMethod.

_backward_foreach()

Arguments

source

# MXNet.mx._backward_gammaMethod.

_backward_gamma(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_gammalnMethod.

_backward_gammaln(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_gather_ndMethod.

_backward_gather_nd(data, indices, shape)

Accumulates data according to indices and returns the result. It is the backward of gather_nd.

Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with shape (M, Y_0, ..., Y_{K-1}), the output will have shape (X_0, X_1, ..., X_{N-1}), where M <= N. If M == N, data shape should simply be (Y_0, ..., Y_{K-1}).

The elements in output are defined as follows:

    output[indices[0, y_0, ..., y_{K-1}], ..., indices[M-1, y_0, ..., y_{K-1}], x_M, ..., x_{N-1}] += data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]

All other entries in output are 0, or keep their original values if AddTo is triggered.

Examples:

    data = [2, 3, 0]
    indices = [[1, 1, 0], [0, 1, 0]]
    shape = (2, 2)
    _backward_gather_nd(data, indices, shape) = [[0, 0], [2, 3]]  # Same as scatter_nd

The difference between scatter_nd and scatter_nd_acc is that the latter accumulates the values that point to the same index:

    data = [2, 3, 0]
    indices = [[1, 1, 0], [1, 1, 0]]
    shape = (2, 2)
    _backward_gather_nd(data, indices, shape) = [[0, 0], [0, 5]]
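The accumulation rule can be sketched in plain Julia for the M == N case used in both examples. backward_gather_nd_sketch is a hypothetical name, not an MXNet function; it takes indices as an M×K matrix of 0-based indices, as in the examples, and shifts them to Julia's 1-based indexing.

```julia
# Hypothetical plain-Julia sketch of the accumulation rule above (not MXNet code).
# `indices` is an M×K matrix of 0-based indices; column k addresses one output
# element, and data[k] is added there (the M == N case, no trailing X dims).
function backward_gather_nd_sketch(data::AbstractVector, indices::AbstractMatrix{<:Integer},
                                   shape::Tuple)
    out = zeros(eltype(data), shape)
    for k in axes(indices, 2)
        pos = Tuple(indices[:, k] .+ 1)   # shift 0-based indices to 1-based
        out[pos...] += data[k]            # accumulate, so repeated indices add up
    end
    return out
end

println(backward_gather_nd_sketch([2, 3, 0], [1 1 0; 1 1 0], (2, 2)))   # [0 0; 0 5]
```

The second call pattern reproduces the second example above: both (1, 1) entries land on the same element, so 2 + 3 = 5.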

Arguments

  • data::NDArray-or-SymbolicNode: data
  • indices::NDArray-or-SymbolicNode: indices
  • shape::Shape(tuple), required: Shape of output.

source

# MXNet.mx._backward_hard_sigmoidMethod.

_backward_hard_sigmoid()

Arguments

source

# MXNet.mx._backward_hypotMethod.

_backward_hypot()

Arguments

source

# MXNet.mx._backward_hypot_scalarMethod.

_backward_hypot_scalar(lhs, rhs, scalar, is_int)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_im2colMethod.

_backward_im2col()

Arguments

source

# MXNet.mx._backward_image_cropMethod.

_backward_image_crop()

Arguments

source

# MXNet.mx._backward_image_normalizeMethod.

_backward_image_normalize()

Arguments

source

# MXNet.mx._backward_interleaved_matmul_encdec_qkMethod.

_backward_interleaved_matmul_encdec_qk()

Arguments

source

# MXNet.mx._backward_interleaved_matmul_encdec_valattMethod.

_backward_interleaved_matmul_encdec_valatt()

Arguments

source

# MXNet.mx._backward_interleaved_matmul_selfatt_qkMethod.

_backward_interleaved_matmul_selfatt_qk()

Arguments

source

# MXNet.mx._backward_interleaved_matmul_selfatt_valattMethod.

_backward_interleaved_matmul_selfatt_valatt()

Arguments

source

# MXNet.mx._backward_linalg_detMethod.

_backward_linalg_det()

Arguments

source

# MXNet.mx._backward_linalg_extractdiagMethod.

_backward_linalg_extractdiag()

Arguments

source

# MXNet.mx._backward_linalg_extracttrianMethod.

_backward_linalg_extracttrian()

Arguments

source

# MXNet.mx._backward_linalg_gelqfMethod.

_backward_linalg_gelqf()

Arguments

source

# MXNet.mx._backward_linalg_gemmMethod.

_backward_linalg_gemm()

Arguments

source

# MXNet.mx._backward_linalg_gemm2Method.

_backward_linalg_gemm2()

Arguments

source

# MXNet.mx._backward_linalg_inverseMethod.

_backward_linalg_inverse()

Arguments

source

# MXNet.mx._backward_linalg_makediagMethod.

_backward_linalg_makediag()

Arguments

source

# MXNet.mx._backward_linalg_maketrianMethod.

_backward_linalg_maketrian()

Arguments

source

# MXNet.mx._backward_linalg_potrfMethod.

_backward_linalg_potrf()

Arguments

source

# MXNet.mx._backward_linalg_potriMethod.

_backward_linalg_potri()

Arguments

source

# MXNet.mx._backward_linalg_slogdetMethod.

_backward_linalg_slogdet()

Arguments

source

# MXNet.mx._backward_linalg_sumlogdiagMethod.

_backward_linalg_sumlogdiag()

Arguments

source

# MXNet.mx._backward_linalg_syevdMethod.

_backward_linalg_syevd()

Arguments

source

# MXNet.mx._backward_linalg_syrkMethod.

_backward_linalg_syrk()

Arguments

source

# MXNet.mx._backward_linalg_trmmMethod.

_backward_linalg_trmm()

Arguments

source

# MXNet.mx._backward_linalg_trsmMethod.

_backward_linalg_trsm()

Arguments

source

# MXNet.mx._backward_linear_reg_outMethod.

_backward_linear_reg_out()

Arguments

source

# MXNet.mx._backward_logMethod.

_backward_log(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_log10Method.

_backward_log10(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_log1pMethod.

_backward_log1p(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_log2Method.

_backward_log2(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_log_softmaxMethod.

_backward_log_softmax(args)

Arguments

  • args::NDArray-or-SymbolicNode[]: Positional input arguments

source

# MXNet.mx._backward_logistic_reg_outMethod.

_backward_logistic_reg_out()

Arguments

source

# MXNet.mx._backward_mae_reg_outMethod.

_backward_mae_reg_out()

Arguments

source

# MXNet.mx._backward_maxMethod.

_backward_max()

Arguments

source

# MXNet.mx._backward_maximumMethod.

_backward_maximum()

Arguments

source

# MXNet.mx._backward_maximum_scalarMethod.

_backward_maximum_scalar(lhs, rhs, scalar, is_int)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_meanMethod.

_backward_mean()

Arguments

source

# MXNet.mx._backward_minMethod.

_backward_min()

Arguments

source

# MXNet.mx._backward_minimumMethod.

_backward_minimum()

Arguments

source

# MXNet.mx._backward_minimum_scalarMethod.

_backward_minimum_scalar(lhs, rhs, scalar, is_int)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_modMethod.

_backward_mod()

Arguments

source

# MXNet.mx._backward_mod_scalarMethod.

_backward_mod_scalar(lhs, rhs, scalar, is_int)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_momentsMethod.

_backward_moments()

Arguments

source

# MXNet.mx._backward_mulMethod.

_backward_mul()

Arguments

source

# MXNet.mx._backward_mul_scalarMethod.

_backward_mul_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_nanprodMethod.

_backward_nanprod()

Arguments

source

# MXNet.mx._backward_nansumMethod.

_backward_nansum()

Arguments

source

# MXNet.mx._backward_normMethod.

_backward_norm()

Arguments

source

# MXNet.mx._backward_np_averageMethod.

_backward_np_average()

Arguments

source

# MXNet.mx._backward_np_broadcast_toMethod.

_backward_np_broadcast_to()

Arguments

source

# MXNet.mx._backward_np_column_stackMethod.

_backward_np_column_stack()

Arguments

source

# MXNet.mx._backward_np_concatMethod.

_backward_np_concat()

Arguments

source

# MXNet.mx._backward_np_cumsumMethod.

_backward_np_cumsum()

Arguments

source

# MXNet.mx._backward_np_diagMethod.

_backward_np_diag()

Arguments

source

# MXNet.mx._backward_np_diagflatMethod.

_backward_np_diagflat()

Arguments

source

# MXNet.mx._backward_np_diagonalMethod.

_backward_np_diagonal()

Arguments

source

# MXNet.mx._backward_np_dotMethod.

_backward_np_dot()

Arguments

source

# MXNet.mx._backward_np_dstackMethod.

_backward_np_dstack()

Arguments

source

# MXNet.mx._backward_np_hstackMethod.

_backward_np_hstack()

Arguments

source

# MXNet.mx._backward_np_matmulMethod.

_backward_np_matmul()

Arguments

source

# MXNet.mx._backward_np_maxMethod.

_backward_np_max()

Arguments

source

# MXNet.mx._backward_np_meanMethod.

_backward_np_mean()

Arguments

source

# MXNet.mx._backward_np_minMethod.

_backward_np_min()

Arguments

source

# MXNet.mx._backward_np_prodMethod.

_backward_np_prod()

Arguments

source

# MXNet.mx._backward_np_sumMethod.

_backward_np_sum()

Arguments

source

# MXNet.mx._backward_np_traceMethod.

_backward_np_trace()

Arguments

source

# MXNet.mx._backward_np_vstackMethod.

_backward_np_vstack()

Arguments

source

# MXNet.mx._backward_np_whereMethod.

_backward_np_where()

Arguments

source

# MXNet.mx._backward_np_where_lscalarMethod.

_backward_np_where_lscalar()

Arguments

source

# MXNet.mx._backward_np_where_rscalarMethod.

_backward_np_where_rscalar()

Arguments

source

# MXNet.mx._backward_npi_arctan2Method.

_backward_npi_arctan2()

Arguments

source

# MXNet.mx._backward_npi_arctan2_scalarMethod.

_backward_npi_arctan2_scalar(lhs, rhs, scalar, is_int)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_npi_broadcast_addMethod.

_backward_npi_broadcast_add()

Arguments

source

# MXNet.mx._backward_npi_broadcast_divMethod.

_backward_npi_broadcast_div()

Arguments

source

# MXNet.mx._backward_npi_broadcast_modMethod.

_backward_npi_broadcast_mod()

Arguments

source

# MXNet.mx._backward_npi_broadcast_mulMethod.

_backward_npi_broadcast_mul()

Arguments

source

# MXNet.mx._backward_npi_broadcast_powerMethod.

_backward_npi_broadcast_power()

Arguments

source

# MXNet.mx._backward_npi_broadcast_subMethod.

_backward_npi_broadcast_sub()

Arguments

source

# MXNet.mx._backward_npi_copysignMethod.

_backward_npi_copysign()

Arguments

source

# MXNet.mx._backward_npi_copysign_scalarMethod.

_backward_npi_copysign_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_npi_diffMethod.

_backward_npi_diff()

Arguments

source

# MXNet.mx._backward_npi_einsumMethod.

_backward_npi_einsum()

Arguments

source

# MXNet.mx._backward_npi_flipMethod.

_backward_npi_flip()

Arguments

source

# MXNet.mx._backward_npi_hypotMethod.

_backward_npi_hypot()

Arguments

source

# MXNet.mx._backward_npi_ldexpMethod.

_backward_npi_ldexp()

Arguments

source

# MXNet.mx._backward_npi_ldexp_scalarMethod.

_backward_npi_ldexp_scalar(lhs, rhs, scalar, is_int)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_npi_normMethod.

_backward_npi_norm()

Arguments

source

# MXNet.mx._backward_npi_padMethod.

_backward_npi_pad()

Arguments

source

# MXNet.mx._backward_npi_rarctan2_scalarMethod.

_backward_npi_rarctan2_scalar(lhs, rhs, scalar, is_int)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_npi_rcopysign_scalarMethod.

_backward_npi_rcopysign_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_npi_rldexp_scalarMethod.

_backward_npi_rldexp_scalar(lhs, rhs, scalar, is_int)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_npi_solveMethod.

_backward_npi_solve()

Arguments

source

# MXNet.mx._backward_npi_svdMethod.

_backward_npi_svd()

Arguments

source

# MXNet.mx._backward_npi_tensordotMethod.

_backward_npi_tensordot()

Arguments

source

# MXNet.mx._backward_npi_tensordot_int_axesMethod.

_backward_npi_tensordot_int_axes()

Arguments

source

# MXNet.mx._backward_npi_tensorinvMethod.

_backward_npi_tensorinv()

Arguments

source

# MXNet.mx._backward_npi_tensorsolveMethod.

_backward_npi_tensorsolve()

Arguments

source

# MXNet.mx._backward_pdf_dirichletMethod.

_backward_pdf_dirichlet()

Arguments

source

# MXNet.mx._backward_pdf_exponentialMethod.

_backward_pdf_exponential()

Arguments

source

# MXNet.mx._backward_pdf_gammaMethod.

_backward_pdf_gamma()

Arguments

source

# MXNet.mx._backward_pdf_generalized_negative_binomialMethod.

_backward_pdf_generalized_negative_binomial()

Arguments

source

# MXNet.mx._backward_pdf_negative_binomialMethod.

_backward_pdf_negative_binomial()

Arguments

source

# MXNet.mx._backward_pdf_normalMethod.

_backward_pdf_normal()

Arguments

source

# MXNet.mx._backward_pdf_poissonMethod.

_backward_pdf_poisson()

Arguments

source

# MXNet.mx._backward_pdf_uniformMethod.

_backward_pdf_uniform()

Arguments

source

# MXNet.mx._backward_pickMethod.

_backward_pick()

Arguments

source

# MXNet.mx._backward_powerMethod.

_backward_power()

Arguments

source

# MXNet.mx._backward_power_scalarMethod.

_backward_power_scalar(lhs, rhs, scalar, is_int)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_prodMethod.

_backward_prod()

Arguments

source

# MXNet.mx._backward_radiansMethod.

_backward_radians(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_rcbrtMethod.

_backward_rcbrt(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_rdiv_scalarMethod.

_backward_rdiv_scalar(lhs, rhs, scalar, is_int)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_reciprocalMethod.

_backward_reciprocal(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_reluMethod.

_backward_relu(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_repeatMethod.

_backward_repeat()

Arguments

source

# MXNet.mx._backward_reshapeMethod.

_backward_reshape()

Arguments

source

# MXNet.mx._backward_reverseMethod.

_backward_reverse()

Arguments

source

# MXNet.mx._backward_rmod_scalarMethod.

_backward_rmod_scalar(lhs, rhs, scalar, is_int)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_rpower_scalarMethod.

_backward_rpower_scalar(lhs, rhs, scalar, is_int)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._backward_rsqrtMethod.

_backward_rsqrt(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_sample_multinomialMethod.

_backward_sample_multinomial()

Arguments

source

# MXNet.mx._backward_sigmoidMethod.

_backward_sigmoid(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_signMethod.

_backward_sign(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_sinMethod.

_backward_sin(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_sinhMethod.

_backward_sinh(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_sliceMethod.

_backward_slice()

Arguments

source

# MXNet.mx._backward_slice_axisMethod.

_backward_slice_axis()

Arguments

source

# MXNet.mx._backward_slice_likeMethod.

_backward_slice_like()

Arguments

source

# MXNet.mx._backward_smooth_l1Method.

_backward_smooth_l1(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_softmaxMethod.

_backward_softmax(args)

Arguments

  • args::NDArray-or-SymbolicNode[]: Positional input arguments

source

# MXNet.mx._backward_softmax_cross_entropyMethod.

_backward_softmax_cross_entropy()

Arguments

source

# MXNet.mx._backward_softminMethod.

_backward_softmin(args)

Arguments

  • args::NDArray-or-SymbolicNode[]: Positional input arguments

source

# MXNet.mx._backward_softsignMethod.

_backward_softsign(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_sparse_retainMethod.

_backward_sparse_retain()

Arguments

source

# MXNet.mx._backward_sqrtMethod.

_backward_sqrt(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_squareMethod.

_backward_square(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_square_sumMethod.

_backward_square_sum()

Arguments

source

# MXNet.mx._backward_squeezeMethod.

_backward_squeeze()

Arguments

source

# MXNet.mx._backward_stackMethod.

_backward_stack()

Arguments

source

# MXNet.mx._backward_subMethod.

_backward_sub()

Arguments

source

# MXNet.mx._backward_sumMethod.

_backward_sum()

Arguments

source

# MXNet.mx._backward_takeMethod.

_backward_take()

Arguments

source

# MXNet.mx._backward_tanMethod.

_backward_tan(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_tanhMethod.

_backward_tanh(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._backward_tileMethod.

_backward_tile()

Arguments

source

# MXNet.mx._backward_topkMethod.

_backward_topk()

Arguments

source

# MXNet.mx._backward_trilMethod.

_backward_tril()

Arguments

source

# MXNet.mx._backward_whereMethod.

_backward_where()

Arguments

source

# MXNet.mx._backward_while_loopMethod.

_backward_while_loop()

Arguments

source

# MXNet.mx._broadcast_backwardMethod.

_broadcast_backward()

Arguments

source

# MXNet.mx._condMethod.

_cond(cond, then_branch, else_branch, data, num_args, num_outputs, cond_input_locs, then_input_locs, else_input_locs)

Note: _cond takes a variable number of positional inputs. Instead of calling it as _cond([x, y, z], num_args=3), call it as _cond(x, y, z); num_args is then determined automatically.

Runs an if-then-else using a user-defined condition and computation.

From: src/operator/control_flow.cc:1212

Arguments

  • cond::SymbolicNode: Input graph for the condition.
  • then_branch::SymbolicNode: Input graph for the then branch.
  • else_branch::SymbolicNode: Input graph for the else branch.
  • data::NDArray-or-SymbolicNode[]: The input arrays that include data arrays and states.
  • num_args::int, required: Number of input arguments, including cond, then and else as three symbol inputs.
  • num_outputs::int, required: The number of outputs of the subgraph.
  • cond_input_locs::tuple of <long>, required: The locations of cond's inputs in the given inputs.
  • then_input_locs::tuple of <long>, required: The locations of then's inputs in the given inputs.
  • else_input_locs::tuple of <long>, required: The locations of else's inputs in the given inputs.
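The role of the three *_input_locs arguments can be modeled in plain Julia. This is a minimal sketch of the input-routing semantics only, assuming 0-based locations into the data inputs; run_cond and the branch closures are hypothetical and not part of MXNet.jl.

```julia
# Pure-Julia model of how `_cond` routes inputs (hypothetical helper, not an
# MXNet.jl API). Locations are assumed to be 0-based indices into `data`.
function run_cond(cond_f, then_f, else_f, data::AbstractVector,
                  cond_locs::AbstractVector{<:Integer},
                  then_locs::AbstractVector{<:Integer},
                  else_locs::AbstractVector{<:Integer})
    pick(locs) = [data[i + 1] for i in locs]   # 0-based location -> 1-based Julia index
    if cond_f(pick(cond_locs)...)               # the condition sees only its own inputs
        return then_f(pick(then_locs)...)
    else
        return else_f(pick(else_locs)...)
    end
end

# data = [x, y]: cond reads x, the then-branch reads x and y, the else-branch reads y.
x, y = 3.0, 5.0
result = run_cond(a -> a > 0, (a, b) -> a + b, b -> -b, [x, y], [0], [0, 1], [1])
```

Each branch only receives the inputs listed in its own locations, which is why the operator needs separate location tuples for cond, then_branch and else_branch.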

source

# MXNet.mx._contrib_AdaptiveAvgPooling2DMethod.

_contrib_AdaptiveAvgPooling2D(data, output_size)

Applies 2-D adaptive average pooling over a 4-D input with shape (N, C, H, W). The pooling kernel and stride sizes are chosen automatically to produce the requested output size.

  • If a single integer is provided for output_size, the output shape is (N x C x output_size x output_size) for any input (NCHW).
  • If a tuple of integers (height, width) is provided for output_size, the output shape is (N x C x height x width) for any input (NCHW).
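A single-channel sketch of adaptive average pooling follows. It assumes the common floor/ceil bin convention (bin i covers rows floor(i*H/out_h) to ceil((i+1)*H/out_h), 0-based), which is an assumption here rather than a statement about MXNet's kernel; adaptive_avg_pool2d is a hypothetical name. The 4-D operator applies the same computation to every (N, C) slice.

```julia
# Sketch of 2-D adaptive average pooling on one H x W channel.
# Bin boundaries (assumed convention, 0-based i):
#   start = floor(i * H / out_h), stop = ceil((i + 1) * H / out_h)
function adaptive_avg_pool2d(x::AbstractMatrix{<:Real}, out_h::Int, out_w::Int)
    H, W = size(x)
    y = zeros(Float64, out_h, out_w)
    for i in 0:out_h-1, j in 0:out_w-1
        r0 = fld(i * H, out_h);  r1 = cld((i + 1) * H, out_h)
        c0 = fld(j * W, out_w);  c1 = cld((j + 1) * W, out_w)
        window = @view x[r0+1:r1, c0+1:c1]   # shift to 1-based Julia indices
        y[i+1, j+1] = sum(window) / length(window)
    end
    return y
end

# A single integer output_size s is treated as (s, s).
adaptive_avg_pool2d(x::AbstractMatrix{<:Real}, s::Int) = adaptive_avg_pool2d(x, s, s)

x = reshape(collect(1.0:16.0), 4, 4)
pooled = adaptive_avg_pool2d(x, 2)   # each output cell averages one 2 x 2 block
```

When H is not a multiple of out_h, neighbouring bins overlap by one row, so every input row contributes to at least one output.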

Defined in src/operator/contrib/adaptive_avg_pooling.cc:L213

Arguments

  • data::NDArray-or-SymbolicNode: Input data
  • output_size::Shape(tuple), optional, default=[]: int (output size) or a tuple of int for output (height, width).

source

# MXNet.mx._contrib_BatchNormWithReLUMethod.

_contrib_BatchNormWithReLU(data, gamma, beta, moving_mean, moving_var, eps, momentum, fix_gamma, use_global_stats, output_mean_var, axis, cudnn_off, min_calib_range, max_calib_range)

Batch normalization with ReLU fusion.

An extended batch normalization operator that can fuse a ReLU activation.

Defined in src/operator/contrib/batch_norm_relu.cc:L249

Arguments

  • data::NDArray-or-SymbolicNode: Input data to batch normalization
  • gamma::NDArray-or-SymbolicNode: gamma array
  • beta::NDArray-or-SymbolicNode: beta array
  • moving_mean::NDArray-or-SymbolicNode: running mean of input
  • moving_var::NDArray-or-SymbolicNode: running variance of input
  • eps::double, optional, default=0.0010000000474974513: Epsilon to prevent division by 0. Must be no less than CUDNN_BN_MIN_EPSILON defined in cudnn.h when using cuDNN (usually 1e-5)
  • momentum::float, optional, default=0.899999976: Momentum for moving average
  • fix_gamma::boolean, optional, default=1: Fix gamma while training
  • use_global_stats::boolean, optional, default=0: Whether to use the global moving statistics instead of the local batch statistics. This turns batch normalization into a scale-shift operator.
  • output_mean_var::boolean, optional, default=0: Output the mean and inverse std
  • axis::int, optional, default='1': The axis that holds the channel dimension
  • cudnn_off::boolean, optional, default=0: Do not select CUDNN operator, if available
  • min_calib_range::float or None, optional, default=None: The minimum scalar value in the form of float32 obtained through calibration. If present, it is used by the quantized batch norm op to calculate the primitive scale. Note: this calibration range applies to the batch norm output.
  • max_calib_range::float or None, optional, default=None: The maximum scalar value in the form of float32 obtained through calibration. If present, it is used by the quantized batch norm op to calculate the primitive scale. Note: this calibration range applies to the batch norm output.
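The fused computation can be sketched in plain Julia for a C x N matrix (one row per channel). This follows the textbook batch-norm definitions with the moving-average update moving = momentum * moving + (1 - momentum) * batch; using the biased batch variance is an assumption, and batchnorm_relu is a hypothetical helper, not the MXNet kernel, which works on N-D tensors along axis.

```julia
using Statistics: mean, var

# Sketch of batch normalization fused with ReLU for a C x N matrix
# (one row per channel, one column per sample).
function batchnorm_relu(x::AbstractMatrix, gamma, beta, moving_mean, moving_var;
                        eps=1e-3, momentum=0.9, fix_gamma=true, training=true)
    g = fix_gamma ? ones(length(gamma)) : gamma          # fix_gamma => gamma treated as 1
    if training
        mu = vec(mean(x; dims=2))                          # per-channel batch mean
        sigma2 = vec(var(x; dims=2, corrected=false))      # biased batch variance (assumption)
        # Exponential moving averages, used later at inference time
        new_mean = momentum .* moving_mean .+ (1 - momentum) .* mu
        new_var = momentum .* moving_var .+ (1 - momentum) .* sigma2
    else
        mu, sigma2 = moving_mean, moving_var                # use_global_stats-like behaviour
        new_mean, new_var = moving_mean, moving_var
    end
    xhat = (x .- mu) ./ sqrt.(sigma2 .+ eps)
    y = max.(0, g .* xhat .+ beta)                          # fused ReLU
    return y, new_mean, new_var
end

x = [1.0 3.0; 2.0 2.0]          # 2 channels x 2 samples
y, m, v = batchnorm_relu(x, ones(2), zeros(2), zeros(2), ones(2))
```

Fusing the ReLU means the normalized values are clipped at zero in the same pass instead of in a separate activation operator.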

source

# MXNet.mx._contrib_BilinearResize2DMethod.

_contrib_BilinearResize2D(data, like, height, width, scale_height, scale_width, mode, align_corners)

Perform 2D resizing (upsampling or downsampling) for 4D input using bilinear interpolation.

The expected input is a 4-D NDArray in NCHW layout, and the output has shape (N x C x height x width). Bilinear interpolation performs linear interpolation first in one direction and then again in the other direction. See Bilinear interpolation on Wikipedia (https://en.wikipedia.org/wiki/Bilinear_interpolation) for more details.
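The two-pass idea can be sketched for a single channel. This assumes the align-corners mapping src = dst * (in - 1) / (out - 1) (0-based), which corresponds to align_corners = true; bilinear_resize is a hypothetical helper that illustrates the technique, not MXNet's kernel.

```julia
# Sketch of bilinear resizing of one H x W channel with aligned corners.
function bilinear_resize(x::AbstractMatrix{<:Real}, out_h::Int, out_w::Int)
    H, W = size(x)
    scale(n_in, n_out) = n_out > 1 ? (n_in - 1) / (n_out - 1) : 0.0
    sh, sw = scale(H, out_h), scale(W, out_w)
    y = zeros(Float64, out_h, out_w)
    for i in 0:out_h-1, j in 0:out_w-1
        r = i * sh;  c = j * sw                      # fractional source position
        r0 = floor(Int, r);  c0 = floor(Int, c)
        r1 = min(r0 + 1, H - 1);  c1 = min(c0 + 1, W - 1)
        dr = r - r0;  dc = c - c0
        # First interpolate along the width in the two bracketing rows...
        top = (1 - dc) * x[r0+1, c0+1] + dc * x[r0+1, c1+1]
        bot = (1 - dc) * x[r1+1, c0+1] + dc * x[r1+1, c1+1]
        # ...then interpolate between those rows along the height.
        y[i+1, j+1] = (1 - dr) * top + dr * bot
    end
    return y
end

x = [0.0 1.0; 2.0 3.0]
up = bilinear_resize(x, 3, 3)   # the new middle row/column are averages of their neighbours
```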

Defined in src/operator/contrib/bilinear_resize.cc:L219

Arguments

  • data::NDArray-or-SymbolicNode: Input data
  • like::NDArray-or-SymbolicNode: Resize data to the shape of this input
  • height::int, optional, default='1': output height (required, but ignored if scale_height is defined or mode is not "size")
  • width::int, optional, default='1': output width (required, but ignored if scale_width is defined or mode is not "size")
  • scale_height::float or None, optional, default=None: sampling scale of the height (optional, used in modes "scale" and "odd_scale")
  • scale_width::float or None, optional, default=None: sampling scale of the width (optional, used in modes "scale" and "odd_scale")
  • mode::{'like', 'odd_scale', 'size', 'to_even_down', 'to_even_up', 'to_odd_down', 'to_odd_up'},optional, default='size': resizing mode.
    • "size": the output height equals the parameter height if scale_height is not defined, or the input height multiplied by scale_height otherwise. The same applies to the width.
    • "odd_scale": if the original height or width is odd, the resulting height is calculated as result_h = (original_h - 1) * scale + 1. For scale > 1, the resulting shape matches a deconvolution with kernel = (1, 1) and stride = (scale_height, scale_width); for scale < 1, it matches a convolution with kernel = (1, 1) and stride = (int(1 / scale_height), int(1 / scale_width)).
    • "like": resize the first input to the height and width of the second input.
    • "to_even_down": resize the input to the nearest lower even height and width (if the original height is odd, the resulting height = original height - 1).
    • "to_even_up": resize the input to the nearest higher even height and width (if the original height is odd, the resulting height = original height + 1).
    • "to_odd_down": resize the input to the nearest lower odd height and width (if the original height is even, the resulting height = original height - 1).
    • "to_odd_up": resize the input to the nearest higher odd height and width (if the original height is even, the resulting height = original height + 1).
  • align_corners::boolean, optional, default=1: With align_corners = True, the corner pixels of the input and output are aligned, so the input and output pixels are not proportionally aligned and the output values can depend on the input size.
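The even/odd rounding modes and the "odd_scale" rule for odd sizes reduce to simple integer arithmetic. The helpers below are hypothetical and encode only the nearest-even and nearest-odd rules named in the mode list, applied to one dimension (height and width are handled the same way).

```julia
# Output size along one dimension for the even/odd rounding modes.
function resized_dim(mode::Symbol, n::Int)
    mode == :to_even_down && return iseven(n) ? n : n - 1
    mode == :to_even_up   && return iseven(n) ? n : n + 1
    mode == :to_odd_down  && return isodd(n) ? n : n - 1
    mode == :to_odd_up    && return isodd(n) ? n : n + 1
    throw(ArgumentError("unsupported mode: $mode"))
end

# "odd_scale" for an odd input size n and an integer scale: (n - 1) * scale + 1
odd_scale_dim(n::Int, scale::Int) =
    isodd(n) ? (n - 1) * scale + 1 : throw(ArgumentError("rule stated here only for odd n"))
```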

source

# MXNet.mx._contrib_CTCLossMethod.

_contrib_CTCLoss(data, label, data_lengths, label_lengths, use_data_lengths, use_label_lengths, blank_label)

_contrib_CTCLoss is an alias of CTCLoss.

Connectionist Temporal Classification Loss.

Note: The existing alias contrib_CTCLoss is deprecated.

The shapes of the inputs and outputs:

  • data: (sequence_length, batch_size, alphabet_size)
  • label: (batch_size, label_sequence_length)
  • out: (batch_size)

The data tensor consists of sequences of activation vectors (without softmax applied), with the i-th channel in the last dimension corresponding to the i-th label, for i between 0 and alphabet_size-1 (i.e. always 0-indexed). The alphabet size should include one additional value reserved for the blank label. When blank_label is "first", the 0-th channel is reserved for the activation of the blank label; when it is "last", the (alphabet_size-1)-th channel is reserved for the blank label.

label is an index matrix of integers. When blank_label is "first", the value 0 is reserved for the blank label and should not appear in this matrix. When blank_label is "last", the value (alphabet_size-1) is reserved for the blank label.

If a sequence of labels is shorter than label_sequence_length, pad the end of the sequence with the special padding value to bring it to the correct length. The padding value is 0 when blank_label is "first", and -1 otherwise.

For example, suppose the vocabulary is [a, b, c], and one batch contains the three sequences 'ba', 'cbb', and 'abac'. When blank_label is "first", we can index the labels as {'a': 1, 'b': 2, 'c': 3} and reserve the 0-th channel of the data tensor for the blank label. The resulting padded label tensor is:

[[2, 1, 0, 0], [3, 2, 2, 0], [1, 2, 1, 3]]

When blank_label is "last", we can index the labels as {'a': 0, 'b': 1, 'c': 2} and reserve channel index 3 of the data tensor for the blank label. The resulting padded label tensor is:

[[1, 0, -1, -1], [2, 1, 1, -1], [0, 1, 0, 2]]
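Building these padded matrices can be automated. The helper below is hypothetical (not an MXNet.jl function); it reproduces both examples by shifting token indices by one and padding with 0 for "first", or using 0-based indices and padding with -1 for "last".

```julia
# Hypothetical helper that builds the padded CTC label matrix (one row per example).
#   blank_label = :first -> tokens are 1..K, padding value 0
#   blank_label = :last  -> tokens are 0..K-1, padding value -1
function ctc_label_matrix(seqs::Vector{String}, vocab::Vector{Char}; blank_label::Symbol=:first)
    blank_label in (:first, :last) || throw(ArgumentError("blank_label must be :first or :last"))
    offset, pad = blank_label == :first ? (1, 0) : (0, -1)
    index = Dict(c => i - 1 + offset for (i, c) in enumerate(vocab))
    maxlen = maximum(length, seqs)
    labels = fill(pad, length(seqs), maxlen)          # start fully padded
    for (row, s) in enumerate(seqs), (col, c) in enumerate(s)
        labels[row, col] = index[c]                   # overwrite with token indices
    end
    return labels
end

seqs = ["ba", "cbb", "abac"]
first_labels = ctc_label_matrix(seqs, ['a', 'b', 'c'])
last_labels = ctc_label_matrix(seqs, ['a', 'b', 'c']; blank_label=:last)
```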

out is a list of CTC loss values, one per example in the batch.

See Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks, A. Graves et al. for more information on the definition and the algorithm.

Defined in src/operator/nn/ctc_loss.cc:L100

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • label::NDArray-or-SymbolicNode: Ground-truth labels for the loss.
  • data_lengths::NDArray-or-SymbolicNode: Lengths of data for each of the samples. Only required when use_data_lengths is true.
  • label_lengths::NDArray-or-SymbolicNode: Lengths of labels for each of the samples. Only required when use_label_lengths is true.
  • use_data_lengths::boolean, optional, default=0: Whether the data lengths are given by data_lengths. If false, the lengths are equal to the max sequence length.
  • use_label_lengths::boolean, optional, default=0: Whether the label lengths are given by label_lengths or derived from padding_mask. If false, the lengths are derived from the first occurrence of the padding_mask value. The value of padding_mask is 0 when the first CTC label is reserved for blank, and -1 when the last label is reserved for blank. See blank_label.
  • blank_label::{'first', 'last'},optional, default='first': Set the label that is reserved for the blank label. If "first", the 0-th label is reserved, label values for tokens in the vocabulary are between 1 and alphabet_size-1, and the padding mask is 0. If "last", the last label value alphabet_size-1 is reserved for the blank label instead, label values for tokens in the vocabulary are between 0 and alphabet_size-2, and the padding mask is -1.

source

# MXNet.mx._contrib_DeformableConvolutionMethod.

_contrib_DeformableConvolution(data, offset, weight, bias, kernel, stride, dilate, pad, num_filter, num_group, num_deformable_group, workspace, no_bias, layout)

Compute 2-D deformable convolution on 4-D input.

The deformable convolution operation is described in https://arxiv.org/abs/1703.06211

For 2-D deformable convolution, the shapes are

  • data: (batch_size, channel, height, width)
  • offset: (batch_size, num_deformable_group * kernel[0] * kernel[1] * 2, height, width)
  • weight: (num_filter, channel, kernel[0], kernel[1])
  • bias: (num_filter,)
  • out: (batch_size, num_filter, out_height, out_width).

Define

f(x, k, p, s, d) = floor((x + 2p - d(k - 1) - 1) / s) + 1

Then

out_height = f(height, kernel[0], pad[0], stride[0], dilate[0])
out_width = f(width, kernel[1], pad[1], stride[1], dilate[1])
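The output-size rule translates directly into integer arithmetic. In this sketch, fld performs the floor division; the shape helper is hypothetical, and because Julia tuples are 1-based, kernel[1] in the code corresponds to kernel[0] in the formula.

```julia
# f(x, k, p, s, d) = floor((x + 2p - d(k - 1) - 1) / s) + 1
conv_out_dim(x, k, p, s, d) = fld(x + 2p - d * (k - 1) - 1, s) + 1

# Hypothetical helper: output shape (batch, num_filter, out_height, out_width) for NCHW input.
function deformable_conv_out_shape(batch, height, width, num_filter;
                                   kernel, stride=(1, 1), pad=(0, 0), dilate=(1, 1))
    out_h = conv_out_dim(height, kernel[1], pad[1], stride[1], dilate[1])
    out_w = conv_out_dim(width,  kernel[2], pad[2], stride[2], dilate[2])
    return (batch, num_filter, out_h, out_w)
end
```

For example, a 3 x 3 kernel with pad 1 and stride 1 preserves a 32 x 32 input, while stride 2 halves it to 16 x 16.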

If no_bias is set to true, the bias term is ignored.

The default data layout is NCHW, namely (batch_size, channel, height, width).

If num_group is larger than 1, denoted by g, the input data is split evenly into g parts along the channel axis, and weight is split evenly along the first dimension. The convolution is computed on the i-th part of the data with the i-th weight part, and the output is obtained by concatenating all g results.

If num_deformable_group is larger than 1, denoted by dg, the input offset is split evenly into dg parts along the channel axis, and data is likewise split evenly into dg parts along the channel axis. The deformable convolution is then computed by applying the i-th part of the offset to the i-th part of the data.

Both weight and bias are learnable parameters.

Defined in src/operator/contrib/deformable_convolution.cc:L83

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the DeformableConvolutionOp.
  • offset::NDArray-or-SymbolicNode: Input offset to the DeformableConvolutionOp.
  • weight::NDArray-or-SymbolicNode: Weight matrix.
  • bias::NDArray-or-SymbolicNode: Bias parameter.
  • kernel::Shape(tuple), required: Convolution kernel size: (h, w) or (d, h, w)
  • stride::Shape(tuple), optional, default=[]: Convolution stride: (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • dilate::Shape(tuple), optional, default=[]: Convolution dilate: (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • pad::Shape(tuple), optional, default=[]: Zero pad for convolution: (h, w) or (d, h, w). Defaults to no padding.
  • num_filter::int, required: Convolution filter(channel) number
  • num_group::int, optional, default='1': Number of group partitions.
  • num_deformable_group::int, optional, default='1': Number of deformable group partitions.
  • workspace::long (non-negative), optional, default=1024: Maximum temporary workspace allowed for convolution (MB).
  • no_bias::boolean, optional, default=0: Whether to disable bias parameter.
  • layout::{None, 'NCDHW', 'NCHW', 'NCW'},optional, default='None': Set layout for input, output and weight. Empty for default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.

source

# MXNet.mx._contrib_DeformablePSROIPoolingMethod.

_contrib_DeformablePSROIPooling(data, rois, trans, spatial_scale, output_dim, group_size, pooled_size, part_size, sample_per_part, trans_std, no_trans)

Performs deformable position-sensitive region-of-interest pooling on inputs. The DeformablePSROIPooling operation is described in https://arxiv.org/abs/1703.06211. After DeformablePSROIPooling, batch_size changes to the number of region bounding boxes.

Arguments

  • data::SymbolicNode: Input data to the pooling operator, a 4-D feature map
  • rois::SymbolicNode: Bounding box coordinates, a 2-D array of [[batch_index, x1, y1, x2, y2]]. (x1, y1) and (x2, y2) are the top-left and bottom-right corners of the designated region of interest. batch_index indicates the index of the corresponding image in the input data
  • trans::SymbolicNode: transition parameter
  • spatial_scale::float, required: Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal of total stride in convolutional layers
  • output_dim::int, required: fix output dim
  • group_size::int, required: fix group size
  • pooled_size::int, required: fix pooled size
  • part_size::int, optional, default='0': fix part size
  • sample_per_part::int, optional, default='1': fix samples per part
  • trans_std::float, optional, default=0: fix transition std
  • no_trans::boolean, optional, default=0: Whether to disable trans parameter.

source

# MXNet.mx._contrib_ModulatedDeformableConvolutionMethod.

_contrib_ModulatedDeformableConvolution(data, offset, mask, weight, bias, kernel, stride, dilate, pad, num_filter, num_group, num_deformable_group, workspace, no_bias, im2col_step, layout)

Compute 2-D modulated deformable convolution on 4-D input.

The modulated deformable convolution operation is described in https://arxiv.org/abs/1811.11168

For 2-D modulated deformable convolution, the shapes are

  • data: (batch_size, channel, height, width)
  • offset: (batch_size, num_deformable_group * kernel[0] * kernel[1] * 2, height, width)
  • mask: (batch_size, num_deformable_group * kernel[0] * kernel[1], height, width)
  • weight: (num_filter, channel, kernel[0], kernel[1])
  • bias: (num_filter,)
  • out: (batch_size, num_filter, out_height, out_width).
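The channel counts in the offset and mask shapes follow from the kernel size: each of the kernel[0] * kernel[1] sampling positions per deformable group gets a two-component offset (hence the factor 2) and one scalar modulation weight. The helpers below are hypothetical; Julia tuples are 1-based, so kernel[1] and kernel[2] here correspond to kernel[0] and kernel[1] above.

```julia
# Expected channel count of `offset`: two offset components per sampling position.
offset_channels(num_deformable_group, kernel) = num_deformable_group * kernel[1] * kernel[2] * 2

# Expected channel count of `mask`: one modulation scalar per sampling position.
mask_channels(num_deformable_group, kernel) = num_deformable_group * kernel[1] * kernel[2]
```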

Define

f(x, k, p, s, d) = floor((x + 2p - d(k - 1) - 1) / s) + 1

Then

out_height = f(height, kernel[0], pad[0], stride[0], dilate[0])
out_width = f(width, kernel[1], pad[1], stride[1], dilate[1])

If no_bias is set to true, the bias term is ignored.

The default data layout is NCHW, namely (batch_size, channel, height, width).

If num_group is larger than 1, denoted by g, the input data is split evenly into g parts along the channel axis, and weight is split evenly along the first dimension. The convolution is computed on the i-th part of the data with the i-th weight part, and the output is obtained by concatenating all g results.

If num_deformable_group is larger than 1, denoted by dg, the input offset is split evenly into dg parts along the channel axis, and out is likewise split evenly into dg parts along the channel axis. The deformable convolution is then computed by applying the i-th part of the offset to the i-th part of out.

Both weight and bias are learnable parameters.

Defined in src/operator/contrib/modulated_deformable_convolution.cc:L83

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the ModulatedDeformableConvolutionOp.
  • offset::NDArray-or-SymbolicNode: Input offset to ModulatedDeformableConvolutionOp.
  • mask::NDArray-or-SymbolicNode: Input mask to the ModulatedDeformableConvolutionOp.
  • weight::NDArray-or-SymbolicNode: Weight matrix.
  • bias::NDArray-or-SymbolicNode: Bias parameter.
  • kernel::Shape(tuple), required: Convolution kernel size: (h, w) or (d, h, w)
  • stride::Shape(tuple), optional, default=[]: Convolution stride: (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • dilate::Shape(tuple), optional, default=[]: Convolution dilate: (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • pad::Shape(tuple), optional, default=[]: Zero pad for convolution: (h, w) or (d, h, w). Defaults to no padding.
  • num_filter::int (non-negative), required: Convolution filter(channel) number
  • num_group::int (non-negative), optional, default=1: Number of group partitions.
  • num_deformable_group::int (non-negative), optional, default=1: Number of deformable group partitions.
  • workspace::long (non-negative), optional, default=1024: Maximum temporary workspace allowed for convolution (MB).
  • no_bias::boolean, optional, default=0: Whether to disable bias parameter.
  • im2col_step::int (non-negative), optional, default=64: Maximum number of images per im2col computation. The total batch size should be divisible by this value or smaller than it. If you run out of memory, try a smaller value.
  • layout::{None, 'NCDHW', 'NCHW', 'NCW'},optional, default='None': Set layout for input, output and weight. Empty for default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.

source

# MXNet.mx._contrib_MultiBoxDetectionMethod.

_contrib_MultiBoxDetection(cls_prob, loc_pred, anchor, clip, threshold, background_id, nms_threshold, force_suppress, variances, nms_topk)

Convert multibox detection predictions.

Arguments

  • cls_prob::NDArray-or-SymbolicNode: Class probabilities.
  • loc_pred::NDArray-or-SymbolicNode: Location regression predictions.
  • anchor::NDArray-or-SymbolicNode: Multibox prior anchor boxes
  • clip::boolean, optional, default=1: Clip out-of-boundary boxes.
  • threshold::float, optional, default=0.00999999978: Threshold to be a positive prediction.
  • background_id::int, optional, default='0': Background id.
  • nms_threshold::float, optional, default=0.5: Non-maximum suppression threshold.
  • force_suppress::boolean, optional, default=0: Suppress all detections regardless of class_id.
  • variances::tuple of <float>, optional, default=[0.1,0.1,0.2,0.2]: Variances to be decoded from box regression output.
  • nms_topk::int, optional, default='-1': Keep maximum top k detections before nms, -1 for no limit.

source

# MXNet.mx._contrib_MultiBoxPriorMethod.

_contrib_MultiBoxPrior(data, sizes, ratios, clip, steps, offsets)

Generate prior (anchor) boxes from data, sizes and ratios.

Arguments

  • data::NDArray-or-SymbolicNode: Input data.
  • sizes::tuple of <float>, optional, default=[1]: List of sizes of the generated prior boxes.
  • ratios::tuple of <float>, optional, default=[1]: List of aspect ratios of the generated prior boxes.
  • clip::boolean, optional, default=0: Whether to clip out-of-boundary boxes.
  • steps::tuple of <float>, optional, default=[-1,-1]: Priorbox step across y and x, -1 for auto calculation.
  • offsets::tuple of <float>, optional, default=[0.5,0.5]: Priorbox center offsets, y and x respectively

source

# MXNet.mx._contrib_MultiBoxTargetMethod.

_contrib_MultiBoxTarget(anchor, label, cls_pred, overlap_threshold, ignore_label, negative_mining_ratio, negative_mining_thresh, minimum_negative_samples, variances)

Compute Multibox training targets

Arguments

  • anchor::NDArray-or-SymbolicNode: Generated anchor boxes.
  • label::NDArray-or-SymbolicNode: Object detection labels.
  • cls_pred::NDArray-or-SymbolicNode: Class predictions.
  • overlap_threshold::float, optional, default=0.5: Anchor-GT overlap threshold to be regarded as a positive match.
  • ignore_label::float, optional, default=-1: Label for ignored anchors.
  • negative_mining_ratio::float, optional, default=-1: Max negative to positive samples ratio, use -1 to disable mining
  • negative_mining_thresh::float, optional, default=0.5: Threshold used for negative mining.
  • minimum_negative_samples::int, optional, default='0': Minimum number of negative samples.
  • variances::tuple of <float>, optional, default=[0.1,0.1,0.2,0.2]: Variances to be encoded in box regression target.

source

# MXNet.mx._contrib_MultiProposalMethod.

_contrib_MultiProposal(cls_prob, bbox_pred, im_info, rpn_pre_nms_top_n, rpn_post_nms_top_n, threshold, rpn_min_size, scales, ratios, feature_stride, output_score, iou_loss)

Generate region proposals via RPN

Arguments

  • cls_prob::NDArray-or-SymbolicNode: Score of how likely each proposal is an object.
  • bbox_pred::NDArray-or-SymbolicNode: Predicted bounding-box deltas from anchors for proposals.
  • im_info::NDArray-or-SymbolicNode: Image size and scale.
  • rpn_pre_nms_top_n::int, optional, default='6000': Number of top scoring boxes to keep before applying NMS to RPN proposals
  • rpn_post_nms_top_n::int, optional, default='300': Number of top scoring boxes to keep after applying NMS to RPN proposals
  • threshold::float, optional, default=0.699999988: NMS value, below which to suppress.
  • rpn_min_size::int, optional, default='16': Minimum height or width in proposal
  • scales::tuple of <float>, optional, default=[4,8,16,32]: Used to generate anchor windows by enumerating scales
  • ratios::tuple of <float>, optional, default=[0.5,1,2]: Used to generate anchor windows by enumerating ratios
  • feature_stride::int, optional, default='16': The size of the receptive field of each unit in the RPN convolution layer, for example the product of all strides prior to this layer.
  • output_score::boolean, optional, default=0: Add score to outputs.
  • iou_loss::boolean, optional, default=0: Whether to use IoU loss.

source

# MXNet.mx._contrib_PSROIPoolingMethod.

_contrib_PSROIPooling(data, rois, spatial_scale, output_dim, pooled_size, group_size)

Performs region-of-interest pooling on inputs. Resize bounding box coordinates by spatial_scale and crop input feature maps accordingly. The cropped feature maps are pooled by max pooling to a fixed-size output indicated by pooled_size. After PSROIPooling, batch_size changes to the number of region bounding boxes.

Arguments

  • data::SymbolicNode: Input data to the pooling operator, 4D feature maps.
  • rois::SymbolicNode: Bounding box coordinates, a 2D array of [[batch_index, x1, y1, x2, y2]]. (x1, y1) and (x2, y2) are the top-left and bottom-right corners of the designated region of interest. batch_index indicates the index of the corresponding image in the input data.
  • spatial_scale::float, required: Ratio of input feature map height (or width) to raw image height (or width). Equals the reciprocal of the total stride in convolutional layers.
  • output_dim::int, required: Fixed output dimension.
  • pooled_size::int, required: Fixed pooled size.
  • group_size::int, optional, default='0': Fixed group size.

source

# MXNet.mx._contrib_ProposalMethod.

_contrib_Proposal(cls_prob, bbox_pred, im_info, rpn_pre_nms_top_n, rpn_post_nms_top_n, threshold, rpn_min_size, scales, ratios, feature_stride, output_score, iou_loss)

Generate region proposals via RPN

Arguments

  • cls_prob::NDArray-or-SymbolicNode: Score of how likely each proposal is an object.
  • bbox_pred::NDArray-or-SymbolicNode: Predicted bounding-box deltas from anchors for proposals.
  • im_info::NDArray-or-SymbolicNode: Image size and scale.
  • rpn_pre_nms_top_n::int, optional, default='6000': Number of top scoring boxes to keep before applying NMS to RPN proposals
  • rpn_post_nms_top_n::int, optional, default='300': Number of top scoring boxes to keep after applying NMS to RPN proposals
  • threshold::float, optional, default=0.699999988: NMS value, below which to suppress.
  • rpn_min_size::int, optional, default='16': Minimum height or width in proposal
  • scales::tuple of <float>, optional, default=[4,8,16,32]: Used to generate anchor windows by enumerating scales
  • ratios::tuple of <float>, optional, default=[0.5,1,2]: Used to generate anchor windows by enumerating ratios
  • feature_stride::int, optional, default='16': The size of the receptive field of each unit in the RPN convolution layer, for example the product of all strides prior to this layer.
  • output_score::boolean, optional, default=0: Add score to outputs.
  • iou_loss::boolean, optional, default=0: Whether to use IoU loss.

source

# MXNet.mx._contrib_ROIAlignMethod.

_contrib_ROIAlign(data, rois, pooled_size, spatial_scale, sample_ratio, position_sensitive, aligned)

This operator takes a 4D feature map as an input array and region proposals as rois. It then aligns the feature map over sub-regions of the input and produces a fixed-size output array. This operator is typically used in Faster R-CNN & Mask R-CNN networks. If an RoI's batch_id is less than 0, the RoI is ignored and the corresponding output is set to 0.

Unlike ROI pooling, ROI Align removes the harsh quantization and properly aligns the extracted features with the input. RoIAlign computes the value of each sampling point by bilinear interpolation from the nearby grid points on the feature map. No quantization is performed on any coordinates involved in the RoI, its bins, or the sampling points. Bilinear interpolation is used to compute the exact values of the input features at four regularly sampled locations in each RoI bin. The feature map can then be aggregated by average pooling.

References

He, Kaiming, et al. "Mask R-CNN." ICCV, 2017

Defined in src/operator/contrib/roi_align.cc:L558
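The bilinear sampling described above can be sketched in plain Python. This is an illustration under stated assumptions, not MXNet's kernel. The feature map is a list of rows. Points farther than one pixel outside the map contribute 0, and nearby points are clamped to the border. Each bin averages a 2x2 grid of samples, matching the four regularly sampled locations mentioned above. bilinear_interpolate and roi_bin_average are hypothetical names.

```python
import math


def bilinear_interpolate(feature, y, x):
    """Sample a 2D feature map (list of rows) at fractional coordinates (y, x)."""
    height, width = len(feature), len(feature[0])
    if y < -1.0 or y > height or x < -1.0 or x > width:
        return 0.0  # assumed convention: far outside the map contributes nothing
    y = min(max(y, 0.0), height - 1)
    x = min(max(x, 0.0), width - 1)
    y0, x0 = int(math.floor(y)), int(math.floor(x))
    y1, x1 = min(y0 + 1, height - 1), min(x0 + 1, width - 1)
    ly, lx = y - y0, x - x0      # fractional distances to the low neighbors
    hy, hx = 1.0 - ly, 1.0 - lx  # weights of the low neighbors
    return (hy * hx * feature[y0][x0] + hy * lx * feature[y0][x1]
            + ly * hx * feature[y1][x0] + ly * lx * feature[y1][x1])


def roi_bin_average(feature, y_start, x_start, bin_h, bin_w, samples=2):
    """Average samples x samples regularly spaced bilinear samples in one bin."""
    total = 0.0
    for iy in range(samples):
        yy = y_start + (iy + 0.5) * bin_h / samples
        for ix in range(samples):
            xx = x_start + (ix + 0.5) * bin_w / samples
            total += bilinear_interpolate(feature, yy, xx)
    return total / (samples * samples)
```

No coordinate is rounded anywhere in this sketch, which is the property that distinguishes RoIAlign from ROI pooling.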

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the pooling operator, 4D feature maps.
  • rois::NDArray-or-SymbolicNode: Bounding box coordinates, a 2D array. If batch_id is less than 0, the RoI is ignored.
  • pooled_size::Shape(tuple), required: ROI Align output roi feature map height and width: (h, w)
  • spatial_scale::float, required: Ratio of input feature map height (or width) to raw image height (or width). Equals the reciprocal of the total stride in convolutional layers.
  • sample_ratio::int, optional, default='-1': Optional sampling ratio of ROI align, using adaptive size by default.
  • position_sensitive::boolean, optional, default=0: Whether to perform position-sensitive RoI pooling. PSRoIPooling was first proposed in R-FCN; it can reduce the number of input channels by a factor of ph*pw, where (ph, pw) is the pooled_size.
  • aligned::boolean, optional, default=0: Center-aligned ROIAlign introduced in Detectron2. To enable, set aligned to True.

source

# MXNet.mx._contrib_RROIAlignMethod.

_contrib_RROIAlign(data, rois, pooled_size, spatial_scale, sampling_ratio)

Performs Rotated ROI Align on the input array.

This operator takes a 4D feature map as an input array and region proposals as rois. It then aligns the feature map over sub-regions of the input and produces a fixed-size output array.

Unlike ROI Align, RROI Align uses rotated rois, which makes it suitable for text detection. RRoIAlign computes the value of each sampling point by bilinear interpolation from the nearby grid points on the rotated feature map. No quantization is performed on any coordinates involved in the RoI, its bins, or the sampling points. Bilinear interpolation is used to compute the exact values of the input features at four regularly sampled locations in each RoI bin. The feature map can then be aggregated by average pooling.

References

Ma, Jianqi, et al. "Arbitrary-Oriented Scene Text Detection via Rotation Proposals." IEEE Transactions on Multimedia, 2018.

Defined in src/operator/contrib/rroi_align.cc:L273

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the pooling operator, 4D feature maps.
  • rois::NDArray-or-SymbolicNode: Bounding box coordinates, a 2D array
  • pooled_size::Shape(tuple), required: RROI align output shape (h,w)
  • spatial_scale::float, required: Ratio of input feature map height (or width) to raw image height (or width). Equals the reciprocal of total stride in convolutional layers
  • sampling_ratio::int, optional, default='-1': Optional sampling ratio of RROI align, using adaptive size by default.

source

# MXNet.mx._contrib_SparseEmbeddingMethod.

_contrib_SparseEmbedding(data, weight, input_dim, output_dim, dtype, sparse_grad)

Maps integer indices to vector representations (embeddings).

note:: $contrib.SparseEmbedding$ is deprecated; use $Embedding$ instead.

This operator maps words to real-valued vectors in a high-dimensional space, called word embeddings. These embeddings can capture semantic and syntactic properties of the words. For example, it has been noted that in the learned embedding spaces, similar words tend to be close to each other and dissimilar words far apart.

For an input array of shape (d1, ..., dK), the output array has shape (d1, ..., dK, output_dim). All input values should be integers in the range [0, input_dim).

If input_dim is ip0 and output_dim is op0, then the embedding weight matrix must have shape (ip0, op0).

The storage type of the gradient will be row_sparse.

.. Note::

`SparseEmbedding` is designed for the use case where `input_dim` is very large (e.g. 100k).
The operator is available on both CPU and GPU.
When `deterministic` is set to `True`, the accumulation of gradients follows a
deterministic order if a feature appears multiple times in the input. However, the
accumulation is usually slower when the order is enforced on GPU.
When the operator is used on the GPU, the recommended value for `deterministic` is `True`.

Examples::

input_dim = 4
output_dim = 5

// Each row in weight matrix y represents a word. So, y = (w0,w1,w2,w3)
y = [[  0.,   1.,   2.,   3.,   4.],
     [  5.,   6.,   7.,   8.,   9.],
     [ 10.,  11.,  12.,  13.,  14.],
     [ 15.,  16.,  17.,  18.,  19.]]

// Input array x represents n-grams (2-gram). So, x = [(w1,w3), (w0,w2)]
x = [[ 1.,  3.],
     [ 0.,  2.]]

// Mapped input x to its vector representation y.
SparseEmbedding(x, y, 4, 5) = [[[  5.,   6.,   7.,   8.,   9.],
                                [ 15.,  16.,  17.,  18.,  19.]],

                               [[  0.,   1.,   2.,   3.,   4.],
                                [ 10.,  11.,  12.,  13.,  14.]]]

Defined in src/operator/tensor/indexing_op.cc:L674
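Below is a minimal plain-Python sketch of the forward lookup. It reproduces the example above. It ignores gradients, sparse storage and dtype. embedding_lookup is a hypothetical helper, not the MXNet API.

```python
def embedding_lookup(indices, weight):
    """Replace every integer index in a nested list with its row of `weight`.

    An input of shape (d1, ..., dK) yields shape (d1, ..., dK, output_dim).
    """
    if isinstance(indices, list):
        return [embedding_lookup(item, weight) for item in indices]
    idx = int(indices)
    if not 0 <= idx < len(weight):
        raise IndexError(f"index {idx} outside [0, {len(weight)})")
    return list(weight[idx])


# weight has input_dim = 4 rows and output_dim = 5 columns, as in the example.
weight = [[float(5 * r + c) for c in range(5)] for r in range(4)]
x = [[1.0, 3.0], [0.0, 2.0]]
out = embedding_lookup(x, weight)
```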

Arguments

  • data::NDArray-or-SymbolicNode: The input array to the embedding operator.
  • weight::NDArray-or-SymbolicNode: The embedding weight matrix.
  • input_dim::int, required: Vocabulary size of the input indices.
  • output_dim::int, required: Dimension of the embedding vectors.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'},optional, default='float32': Data type of weight.
  • sparse_grad::boolean, optional, default=0: Compute row sparse gradient in the backward calculation. If set to True, the grad's storage type is row_sparse.

source

# MXNet.mx._contrib_SyncBatchNormMethod.

_contrib_SyncBatchNorm(data, gamma, beta, moving_mean, moving_var, eps, momentum, fix_gamma, use_global_stats, output_mean_var, ndev, key)

Batch normalization.

Normalizes a data batch by mean and variance, and applies a scale $gamma$ as well as an offset $beta$. The standard BN [1] implementation only normalizes the data within each device, while SyncBN normalizes the input over the whole mini-batch across devices. We follow the sync-once implementation described in [2].

Assume the input has more than one dimension and we normalize along axis 1. We first compute the mean and variance along this axis:

.. math::

\text{data\_mean}[i] = \text{mean}(\text{data}[:,i,:,\ldots]) \\
\text{data\_var}[i] = \text{var}(\text{data}[:,i,:,\ldots])

Then compute the normalized output, which has the same shape as the input, as follows:

.. math::

\text{out}[:,i,:,\ldots] = \frac{\text{data}[:,i,:,\ldots] - \text{data\_mean}[i]}{\sqrt{\text{data\_var}[i] + \epsilon}} \cdot \gamma[i] + \beta[i]

Both mean and var return a scalar by treating the input as a vector.

Assume the input has size k on axis 1. Then both $gamma$ and $beta$ have shape (k,). If $output_mean_var$ is set to true, the operator also outputs $data_mean$ and $data_var$, which are needed for the backward pass.

Besides the inputs and the outputs, this operator accepts two auxiliary states, $moving_mean$ and $moving_var$, which are k-length vectors. They are global statistics for the whole dataset, which are updated by::

moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
moving_var = moving_var * momentum + data_var * (1 - momentum)

If $use_global_stats$ is set to be true, then $moving_mean$ and $moving_var$ are used instead of $data_mean$ and $data_var$ to compute the output. It is often used during inference.
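The statistics above can be sketched in plain Python as a single process, for a 2-D (N, C) input only. The sketch assumes the sync-once idea works as follows. Each device reduces its shard to a sample count, a per-channel sum and a per-channel sum of squares, and these partial results are combined once. It does not model MXNet's device communication or the key argument, and all function names are hypothetical. Passing the moving statistics instead of the batch statistics to normalize corresponds to use_global_stats.

```python
import math


def channel_stats(device_batches):
    """Combine per-device partial sums into global per-channel mean and variance.

    device_batches: one list of samples per device; each sample has shape (C,).
    """
    num_channels = len(device_batches[0][0])
    count = 0
    sums = [0.0] * num_channels
    sq_sums = [0.0] * num_channels
    for batch in device_batches:  # each device contributes its partials once
        count += len(batch)
        for sample in batch:
            for c, v in enumerate(sample):
                sums[c] += v
                sq_sums[c] += v * v
    mean = [s / count for s in sums]
    var = [sq / count - m * m for sq, m in zip(sq_sums, mean)]  # biased variance
    return mean, var


def normalize(sample, mean, var, gamma, beta, eps=1e-3):
    """out = (x - mean) / sqrt(var + eps) * gamma + beta, per channel."""
    return [(v - m) / math.sqrt(s + eps) * g + b
            for v, m, s, g, b in zip(sample, mean, var, gamma, beta)]


def update_moving(moving, batch_stat, momentum=0.9):
    """moving = moving * momentum + batch_stat * (1 - momentum)."""
    return [mv * momentum + bs * (1 - momentum)
            for mv, bs in zip(moving, batch_stat)]
```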

Both $gamma$ and $beta$ are learnable parameters. But if $fix_gamma$ is true, then set $gamma$ to 1 and its gradient to 0.

References

[1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." ICML 2015.

[2] Zhang, Hang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." CVPR 2018.

Defined in src/operator/contrib/sync_batch_norm.cc:L96

Arguments

  • data::NDArray-or-SymbolicNode: Input data to batch normalization
  • gamma::NDArray-or-SymbolicNode: gamma array
  • beta::NDArray-or-SymbolicNode: beta array
  • moving_mean::NDArray-or-SymbolicNode: running mean of input
  • moving_var::NDArray-or-SymbolicNode: running variance of input
  • eps::float, optional, default=0.00100000005: Epsilon to prevent division by 0.
  • momentum::float, optional, default=0.899999976: Momentum for moving average
  • fix_gamma::boolean, optional, default=1: Fix gamma while training
  • use_global_stats::boolean, optional, default=0: Whether to use global moving statistics instead of local batch statistics. This turns batch norm into a scale-and-shift operator.
  • output_mean_var::boolean, optional, default=0: Whether to output the mean and var in addition to the normalized data.
  • ndev::int, optional, default='1': The count of GPU devices
  • key::string, required: Hash key for synchronization. Use the same hash key for the same layer; Block.prefix is typically used, as in gluon.nn.contrib.SyncBatchNorm.

source

# MXNet.mx._contrib_allcloseMethod.

_contrib_allclose(a, b, rtol, atol, equal_nan)

This operator implements numpy.allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):

.. math::

f(a, b) = |a - b| \le \mathrm{atol} + \mathrm{rtol} \cdot |b|

where a and b are input tensors of equal type and shape, and atol and rtol are the absolute and relative tolerances (by default, rtol=1e-05 and atol=1e-08). The result is True only if the inequality holds for every element.

Examples::

a = [1e10, 1e-7], b = [1.00001e10, 1e-8]
y = allclose(a, b)
y = False

a = [1e10, 1e-8], b = [1.00001e10, 1e-9]
y = allclose(a, b)
y = True

Defined in src/operator/contrib/allclose_op.cc:L55
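The elementwise test can be written in plain Python as below, for flat lists. The NaN handling follows the equal_nan description in the argument list, including its default of True. allclose here is a hypothetical stand-in, not the MXNet call.

```python
import math


def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=True):
    """Return True if |x - y| <= atol + rtol * |y| for every pair (x, y)."""
    if len(a) != len(b):
        raise ValueError("a and b must have the same shape")
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            # Two NaNs match only when equal_nan is set; NaN vs. a number never matches.
            if equal_nan and math.isnan(x) and math.isnan(y):
                continue
            return False
        if abs(x - y) > atol + rtol * abs(y):
            return False
    return True
```

Note the asymmetry. The tolerance scales with |b| only, so allclose(a, b) and allclose(b, a) can differ near the boundary.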

Arguments

  • a::NDArray-or-SymbolicNode: Input array a
  • b::NDArray-or-SymbolicNode: Input array b
  • rtol::float, optional, default=9.99999975e-06: Relative tolerance.
  • atol::float, optional, default=9.99999994e-09: Absolute tolerance.
  • equal_nan::boolean, optional, default=1: Whether to compare NaNs as equal. If True, NaNs in a are considered equal to NaNs in b in the output array.

source

# MXNet.mx._contrib_arange_likeMethod.

_contrib_arange_like(data, start, step, repeat, ctx, axis)

Return an array with evenly spaced values. If axis is not given, the output will have the same shape as the input array. Otherwise, the output will be a 1-D array with size of the specified axis in input shape.

Examples::

x = [[0.14883883 0.7772398  0.94865847 0.7225052 ]
     [0.23729339 0.6112595  0.66538996 0.5132841 ]
     [0.30822644 0.9912457  0.15502319 0.7043658 ]]

out = mx.nd.contrib.arange_like(x, start=0)

[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]]
 <NDArray 3x4 @cpu(0)>

out = mx.nd.contrib.arange_like(x, start=0, axis=-1)

[0. 1. 2. 3.]
<NDArray 4 @cpu(0)>
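Below is a plain-Python sketch of the values arange_like produces. It assumes the element at flat position i equals start + (i // repeat) * step. For simplicity it takes the input's shape tuple rather than an array. The function name is hypothetical.

```python
def arange_like(shape, start=0.0, step=1.0, repeat=1, axis=None):
    """Evenly spaced values laid out in `shape` (a non-empty tuple).

    If axis is given, return a 1-D list of length shape[axis] instead;
    negative axes count from the end, as with Python indexing.
    """
    if axis is not None:
        shape = (shape[axis],)
    total = 1
    for dim in shape:
        total *= dim
    # Each value is repeated `repeat` times before advancing by `step`.
    flat = [start + (i // repeat) * step for i in range(total)]

    def build(values, dims):
        """Nest a flat list into the given dimensions, row-major."""
        if len(dims) == 1:
            return values
        chunk = len(values) // dims[0]
        return [build(values[k * chunk:(k + 1) * chunk], dims[1:])
                for k in range(dims[0])]

    return build(flat, list(shape))
```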

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • start::double, optional, default=0: Start of interval. The interval includes this value. The default start value is 0.
  • step::double, optional, default=1: Spacing between values.
  • repeat::int, optional, default='1': The number of times each element is repeated. E.g., with repeat=3, the element a is repeated three times: a, a, a.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • axis::int or None, optional, default='None': Arange elements according to the size of a certain axis of the input array. Negative values count from the last axis. If not provided, elements are arranged according to the full input shape.

source

# MXNet.mx._contrib_backward_gradientmultiplierMethod.

_contrib_backward_gradientmultiplier(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._contrib_backward_hawkesllMethod.

_contrib_backward_hawkesll()

Arguments

source

# MXNet.mx._contrib_backward_index_copyMethod.

_contrib_backward_index_copy()

Arguments

source

# MXNet.mx._contrib_backward_quadraticMethod.

_contrib_backward_quadratic()

Arguments

source

# MXNet.mx._contrib_bipartite_matchingMethod.

_contrib_bipartite_matching(data, is_ascend, threshold, topk)

Compute bipartite matching. The matching is performed on a score matrix with shape [B, N, M]:

  • B: batch_size
  • N: number of rows to match
  • M: number of columns as reference to be matched against.

Returns:

  • x: matched column indices; -1 indicates non-matched elements in rows.
  • y: matched row indices.

Note::

Zero gradients are back-propagated in this op for now.

Example::

s = [[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]]
x, y = bipartite_matching(s, threshold=1e-12, is_ascend=False)
x = [1, -1, 0]
y = [2, 0]

Defined in src/operator/contrib/bounding_box.cc:L182
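Below is a plain-Python sketch for one batch element. It assumes a greedy strategy: visit all scores in sorted order, and accept a (row, column) pair while both are still unmatched and the score passes threshold. It reproduces the example above. Tie-breaking and batching in the real operator are not modeled, and the function name is hypothetical.

```python
def bipartite_matching(scores, threshold, is_ascend=False, topk=-1):
    """Greedily match rows to columns of an N x M score matrix.

    Returns (x, y): x[r] is the column matched to row r (or -1),
    y[c] is the row matched to column c (or -1).
    """
    n_rows, n_cols = len(scores), len(scores[0])
    entries = [(scores[r][c], r, c) for r in range(n_rows) for c in range(n_cols)]
    # Best scores first: descending by default, ascending if is_ascend.
    entries.sort(key=lambda e: e[0], reverse=not is_ascend)
    x = [-1] * n_rows
    y = [-1] * n_cols
    matched = 0
    for score, r, c in entries:
        if 0 <= topk <= matched:
            break  # topk limit reached (topk = -1 means no limit)
        # Entries are sorted, so once one fails the threshold, all later ones do.
        fails = score > threshold if is_ascend else score < threshold
        if fails:
            break
        if x[r] == -1 and y[c] == -1:
            x[r] = c
            y[c] = r
            matched += 1
    return x, y
```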

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • is_ascend::boolean, optional, default=0: Use ascend order for scores instead of descending. Please set threshold accordingly.
  • threshold::float, required: Ignore matching when score < thresh if is_ascend=false, or when score > thresh if is_ascend=true.
  • topk::int, optional, default='-1': Limit the number of matches to topk, set -1 for no limit

source

# MXNet.mx._contrib_boolean_maskMethod.

_contrib_boolean_mask(data, index, axis)

Given an n-d NDArray data and a 1-d NDArray index, the operator produces an n-d NDArray out whose shape cannot be determined in advance. out contains the rows of data for which the corresponding element in index is non-zero.

data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
index = mx.nd.array([0, 1, 0])
out = mx.nd.contrib.boolean_mask(data, index)
out

[[4. 5. 6.]]
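For the default axis=0, the selection reduces to the list filter below. This is a plain-Python sketch with a hypothetical function name. The real operator works on NDArrays and also supports other axes.

```python
def boolean_mask(data, index):
    """Keep the rows (axis 0) of `data` whose entry in `index` is non-zero."""
    if len(data) != len(index):
        raise ValueError("index length must match the first dimension of data")
    return [row for row, keep in zip(data, index) if keep != 0]
```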

Defined in src/operator/contrib/boolean_mask.cc:L195

Arguments

  • data::NDArray-or-SymbolicNode: Data
  • index::NDArray-or-SymbolicNode: Mask
  • axis::int, optional, default='0': An integer that represents the axis in NDArray to mask from.

source

# MXNet.mx._contrib_box_decodeMethod.

_contrib_box_decode(data, anchors, std0, std1, std2, std3, clip, format)

Decode bounding box training targets with normalized center offsets. Input bounding boxes use either corner type (x_{min}, y_{min}, x_{max}, y_{max}) or center type (x, y, width, height).

Defined in src/operator/contrib/bounding_box.cc:L233

Arguments

  • data::NDArray-or-SymbolicNode: (B, N, 4) predicted bbox offset
  • anchors::NDArray-or-SymbolicNode: (1, N, 4) encoded in corner or center
  • std0::float, optional, default=1: value to be divided from the 1st encoded values
  • std1::float, optional, default=1: value to be divided from the 2nd encoded values
  • std2::float, optional, default=1: value to be divided from the 3rd encoded values
  • std3::float, optional, default=1: value to be divided from the 4th encoded values
  • clip::float, optional, default=-1: If larger than 0, bounding box target will be clipped to this value.
  • format::{'center', 'corner'},optional, default='center': The box encoding type.

"corner" means boxes are encoded as [xmin, ymin, xmax, ymax], "center" means boxes are encoded as [x, y, width, height].

source

# MXNet.mx._contrib_box_encodeMethod.

_contrib_box_encode(samples, matches, anchors, refs, means, stds)

Encode bounding box training targets with normalized center offsets. Input bounding boxes use corner type (x_{min}, y_{min}, x_{max}, y_{max}).

Defined in src/operator/contrib/bounding_box.cc:L210

Arguments

  • samples::NDArray-or-SymbolicNode: (B, N) value +1 (positive), -1 (negative), 0 (ignore)
  • matches::NDArray-or-SymbolicNode: (B, N) value range [0, M)
  • anchors::NDArray-or-SymbolicNode: (B, N, 4) encoded in corner
  • refs::NDArray-or-SymbolicNode: (B, M, 4) encoded in corner
  • means::NDArray-or-SymbolicNode: (4,) Mean value to be subtracted from encoded values
  • stds::NDArray-or-SymbolicNode: (4,) Std value to be divided from encoded values
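The plain-Python sketch below shows a conventional center-offset parameterization, the one popularized by Faster R-CNN and SSD. It also includes the inverse, which plays the role of box_decode. The sketch assumes the operators follow this scheme, with means subtracted and stds divided per coordinate as the arguments describe. This has not been checked against MXNet's kernels. Clipping and batching are omitted, and all names are hypothetical.

```python
import math


def corner_to_center(box):
    xmin, ymin, xmax, ymax = box
    return ((xmin + xmax) / 2, (ymin + ymax) / 2, xmax - xmin, ymax - ymin)


def center_to_corner(box):
    x, y, w, h = box
    return (x - w / 2, y - h / 2, x + w / 2, y + h / 2)


def encode(anchor, ref, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
    """Encode a corner-format ref box relative to a corner-format anchor."""
    ax, ay, aw, ah = corner_to_center(anchor)
    gx, gy, gw, gh = corner_to_center(ref)
    # Center shift in anchor units, then log-scale width/height ratios.
    raw = ((gx - ax) / aw, (gy - ay) / ah, math.log(gw / aw), math.log(gh / ah))
    return tuple((t - m) / s for t, m, s in zip(raw, means, stds))


def decode(offset, anchor, stds=(1.0, 1.0, 1.0, 1.0)):
    """Invert encode() with zero means; returns a corner-format box."""
    ax, ay, aw, ah = corner_to_center(anchor)
    tx, ty, tw, th = (t * s for t, s in zip(offset, stds))
    return center_to_corner((tx * aw + ax, ty * ah + ay,
                             math.exp(tw) * aw, math.exp(th) * ah))
```

Dividing by small stds (for example 0.1 and 0.2) scales the targets up, so the regression losses on center and size terms are of comparable magnitude.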

source

# MXNet.mx._contrib_box_iouMethod.

_contrib_box_iou(lhs, rhs, format)

Bounding box overlap of two arrays. The overlap is defined as Intersection-over-Union (IoU).

  • lhs: (a_1, a_2, ..., a_n, 4) array
  • rhs: (b_1, b_2, ..., b_n, 4) array
  • output: (a_1, a_2, ..., a_n, b_1, b_2, ..., b_n) array

Note::

Zero gradients are back-propagated in this op for now.

Example::

x = [[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 0.5, 0.5]]
y = [[0.25, 0.25, 0.75, 0.75]]
box_iou(x, y, format='corner') = [[0.1428], [0.1428]]
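Below is a plain-Python sketch of corner-format IoU for flat lists of boxes, without the extra leading dimensions. It reproduces the example values: 1/7 is about 0.1428. The function names are hypothetical.

```python
def box_iou_corner(a, b):
    """IoU of two corner-format boxes [xmin, ymin, xmax, ymax]."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))  # overlap width (0 if disjoint)
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))  # overlap height
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def box_iou(lhs, rhs):
    """Pairwise IoU matrix with shape len(lhs) x len(rhs)."""
    return [[box_iou_corner(a, b) for b in rhs] for a in lhs]
```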

Defined in src/operator/contrib/bounding_box.cc:L136

Arguments

  • lhs::NDArray-or-SymbolicNode: The first input
  • rhs::NDArray-or-SymbolicNode: The second input
  • format::{'center', 'corner'},optional, default='corner': The box encoding type.

"corner" means boxes are encoded as [xmin, ymin, xmax, ymax], "center" means boxes are encoded as [x, y, width, height].

source

# MXNet.mx._contrib_box_nmsMethod.

_contrib_box_nms(data, overlap_thresh, valid_thresh, topk, coord_start, score_index, id_index, background_id, force_suppress, in_format, out_format)

Apply non-maximum suppression to input.

The output is sorted in descending order of score. Boxes that overlap a higher-scoring box by more than overlap_thresh, as well as background boxes, are removed and filled with -1. The corresponding positions are recorded for backward propagation.

During back-propagation, the gradient is copied to the original position according to the input index. For positions that have been suppressed, in_grad is set to 0. In summary, gradients stick to their boxes and are either moved or discarded according to their original index in the input.

Input requirements::

  1. The input tensor has at least 2 dimensions, (n, k); any higher dims are regarded as batch, e.g. (a, b, c, d, n, k) == (a*b*c*d, n, k)
  2. n is the number of boxes in each batch
  3. k is the width of each box item.

By default, a box is [id, score, xmin, ymin, xmax, ymax, ...], additional elements are allowed.

  • id_index: optional, use -1 to ignore. Useful with force_suppress=False, which skips suppression between highly overlapped boxes of different classes (e.g., one is an apple while the other is a car).
  • background_id: optional, default=-1, class id for background boxes. Useful when id_index >= 0, in which case boxes with the background id are filtered out before nms.
  • coord_start: required, default=2, the starting index of the 4 coordinates. Two formats are supported:

    • corner: [xmin, ymin, xmax, ymax]
    • center: [x, y, width, height]
  • score_index: required, default=1, box score/confidence. When two boxes overlap with IOU > overlap_thresh, the one with the smaller score is suppressed.
  • in_format and out_format: default='corner', specify in/out box formats.

Examples::

x = [[0, 0.5, 0.1, 0.1, 0.2, 0.2], [1, 0.4, 0.1, 0.1, 0.2, 0.2],
     [0, 0.3, 0.1, 0.1, 0.14, 0.14], [2, 0.6, 0.5, 0.5, 0.7, 0.8]]
box_nms(x, overlap_thresh=0.1, coord_start=2, score_index=1, id_index=0,
        force_suppress=True, in_format='corner', out_format='corner') =
    [[2, 0.6, 0.5, 0.5, 0.7, 0.8], [0, 0.5, 0.1, 0.1, 0.2, 0.2],
     [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]
out_grad = [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
            [0.3, 0.3, 0.3, 0.3, 0.3, 0.3], [0.4, 0.4, 0.4, 0.4, 0.4, 0.4]]

exe.backward

in_grad = [[0.2, 0.2, 0.2, 0.2, 0.2, 0.2], [0, 0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]

Defined in src/operator/contrib/bounding_box.cc:L94
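The greedy plain-Python sketch below mimics the forward and backward behavior shown in the example. It assumes corner-format coordinates. It ignores background_id, topk and format conversion. The function names are hypothetical, not MXNet calls.

```python
def _iou(a, b):
    """IoU of two corner-format boxes [xmin, ymin, xmax, ymax]."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def box_nms(boxes, overlap_thresh=0.5, valid_thresh=0.0, coord_start=2,
            score_index=1, id_index=-1, force_suppress=False):
    """Greedy NMS over box rows; suppressed slots are filled with -1.

    Returns (out, source), where source[i] is the input row that output
    row i came from, or -1 for a filler row.
    """
    def coords(row):
        return row[coord_start:coord_start + 4]

    # Keep boxes above valid_thresh, highest score first.
    order = sorted((i for i, row in enumerate(boxes)
                    if row[score_index] > valid_thresh),
                   key=lambda i: boxes[i][score_index], reverse=True)
    kept = []
    for i in order:
        suppressed = False
        for j in kept:
            same_class = (force_suppress or id_index < 0
                          or boxes[i][id_index] == boxes[j][id_index])
            if same_class and _iou(coords(boxes[i]), coords(boxes[j])) > overlap_thresh:
                suppressed = True
                break
        if not suppressed:
            kept.append(i)
    width = len(boxes[0])
    out = [list(boxes[i]) for i in kept]
    source = list(kept)
    while len(out) < len(boxes):  # pad with -1 rows to keep the input shape
        out.append([-1] * width)
        source.append(-1)
    return out, source


def box_nms_backward(out_grad, source):
    """Route each output row's gradient back to its input row; others get 0."""
    in_grad = [[0.0] * len(row) for row in out_grad]
    for out_row, src in enumerate(source):
        if src >= 0:
            in_grad[src] = list(out_grad[out_row])
    return in_grad
```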

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • overlap_thresh::float, optional, default=0.5: Overlapping(IoU) threshold to suppress object with smaller score.
  • valid_thresh::float, optional, default=0: Filter input boxes to those whose scores are greater than valid_thresh.
  • topk::int, optional, default='-1': Apply nms to the topk boxes with descending scores, -1 for no restriction.
  • coord_start::int, optional, default='2': Start index of the consecutive 4 coordinates.
  • score_index::int, optional, default='1': Index of the scores/confidence of boxes.
  • id_index::int, optional, default='-1': Optional, index of the class categories, -1 to disable.
  • background_id::int, optional, default='-1': Optional, id of the background class which will be ignored in nms.
  • force_suppress::boolean, optional, default=0: Optional; if set to false and id_index is provided, nms is only applied among boxes belonging to the same category.
  • in_format::{'center', 'corner'},optional, default='corner': The input box encoding type.

"corner" means boxes are encoded as [xmin, ymin, xmax, ymax], "center" means boxes are encoded as [x, y, width, height].

  • out_format::{'center', 'corner'},optional, default='corner': The output box encoding type.

"corner" means boxes are encoded as [xmin, ymin, xmax, ymax], "center" means boxes are encoded as [x, y, width, height].

source

# MXNet.mx._contrib_box_non_maximum_suppressionMethod.

_contrib_box_non_maximum_suppression(data, overlap_thresh, valid_thresh, topk, coord_start, score_index, id_index, background_id, force_suppress, in_format, out_format)

_contrib_box_non_maximum_suppression is an alias of _contrib_box_nms.

Apply non-maximum suppression to input.

The output is sorted in descending order of score. Boxes that overlap a higher-scoring box by more than overlap_thresh, as well as background boxes, are removed and filled with -1. The corresponding positions are recorded for backward propagation.

During back-propagation, the gradient is copied to the original position according to the input index. For positions that have been suppressed, in_grad is set to 0. In summary, gradients stick to their boxes and are either moved or discarded according to their original index in the input.

Input requirements::

  1. The input tensor has at least 2 dimensions, (n, k); any higher dims are regarded as batch, e.g. (a, b, c, d, n, k) == (a*b*c*d, n, k)
  2. n is the number of boxes in each batch
  3. k is the width of each box item.

By default, a box is [id, score, xmin, ymin, xmax, ymax, ...], additional elements are allowed.

  • id_index: optional, use -1 to ignore. Useful with force_suppress=False, which skips suppression between highly overlapped boxes of different classes (e.g., one is an apple while the other is a car).
  • background_id: optional, default=-1, class id for background boxes. Useful when id_index >= 0, in which case boxes with the background id are filtered out before nms.
  • coord_start: required, default=2, the starting index of the 4 coordinates. Two formats are supported:

    • corner: [xmin, ymin, xmax, ymax]
    • center: [x, y, width, height]
  • score_index: required, default=1, box score/confidence. When two boxes overlap with IOU > overlap_thresh, the one with the smaller score is suppressed.
  • in_format and out_format: default='corner', specify in/out box formats.

Examples::

x = [[0, 0.5, 0.1, 0.1, 0.2, 0.2], [1, 0.4, 0.1, 0.1, 0.2, 0.2],
     [0, 0.3, 0.1, 0.1, 0.14, 0.14], [2, 0.6, 0.5, 0.5, 0.7, 0.8]]
box_nms(x, overlap_thresh=0.1, coord_start=2, score_index=1, id_index=0,
        force_suppress=True, in_format='corner', out_format='corner') =
    [[2, 0.6, 0.5, 0.5, 0.7, 0.8], [0, 0.5, 0.1, 0.1, 0.2, 0.2],
     [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]
out_grad = [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
            [0.3, 0.3, 0.3, 0.3, 0.3, 0.3], [0.4, 0.4, 0.4, 0.4, 0.4, 0.4]]

exe.backward

in_grad = [[0.2, 0.2, 0.2, 0.2, 0.2, 0.2], [0, 0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]

Defined in src/operator/contrib/bounding_box.cc:L94

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • overlap_thresh::float, optional, default=0.5: Overlapping(IoU) threshold to suppress object with smaller score.
  • valid_thresh::float, optional, default=0: Filter input boxes to those whose scores are greater than valid_thresh.
  • topk::int, optional, default='-1': Apply nms to the topk boxes with descending scores, -1 for no restriction.
  • coord_start::int, optional, default='2': Start index of the consecutive 4 coordinates.
  • score_index::int, optional, default='1': Index of the scores/confidence of boxes.
  • id_index::int, optional, default='-1': Optional, index of the class categories, -1 to disable.
  • background_id::int, optional, default='-1': Optional, id of the background class which will be ignored in nms.
  • force_suppress::boolean, optional, default=0: Optional; if set to false and id_index is provided, nms is only applied among boxes belonging to the same category.
  • in_format::{'center', 'corner'},optional, default='corner': The input box encoding type.

"corner" means boxes are encoded as [xmin, ymin, xmax, ymax], "center" means boxes are encoded as [x, y, width, height].

  • out_format::{'center', 'corner'},optional, default='corner': The output box encoding type.

"corner" means boxes are encoded as [xmin, ymin, xmax, ymax], "center" means boxes are encoded as [x, y, width, height].

source

# MXNet.mx._contrib_calibrate_entropyMethod.

_contrib_calibrate_entropy(hist, hist_edges, num_quantized_bins)

Provides calibrated min/max values for the input histogram.

.. Note:: This operator only supports forward propagation. DO NOT use it in training.

Defined in src/operator/quantization/calibrate.cc:L196

Arguments

  • hist::NDArray-or-SymbolicNode: An ndarray/symbol of type float32
  • hist_edges::NDArray-or-SymbolicNode: An ndarray/symbol of type float32
  • num_quantized_bins::int, optional, default='255': The number of quantized bins.

source

# MXNet.mx._contrib_count_sketchMethod.

_contrib_count_sketch(data, h, s, out_dim, processing_batch_size)

Applies CountSketch to the input: maps d-dimensional data to k-dimensional data.

.. note:: count_sketch is only available on GPU.

Assume the input data has shape (N, d), the sign hash table s has shape (N, d), the index hash table h has shape (N, d), and the mapping dimension is out_dim = k. Each element in s is either +1 or -1, and each element in h is a random integer from 0 to k-1. The operator then computes:

.. math:: out[h[i]] += data[i] * s[i]

Example::

out_dim = 5
x = [[1.2, 2.5, 3.4], [3.2, 5.7, 6.6]]
h = [[0, 3, 4]]
s = [[1, -1, 1]]
mx.contrib.ndarray.count_sketch(data=x, h=h, s=s, out_dim=5) = [[1.2, 0, 0, -2.5, 3.4],
                                                                [3.2, 0, 0, -5.7, 6.6]]
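The update rule out[h[i]] += data[i] * s[i] can be sketched in plain Python. As in the example, the sketch assumes that a single hash row h and a single sign row s are shared by every row of data. The function name is hypothetical. Input positions that hash to the same bucket are summed into it.

```python
def count_sketch(data, h, s, out_dim):
    """Project each row of data (length d) to a row of length out_dim.

    h[j] is the output bucket of input position j and s[j] its sign (+1 or -1).
    """
    out = []
    for row in data:
        sketch = [0.0] * out_dim
        for j, value in enumerate(row):
            sketch[h[j]] += value * s[j]   # out[h[j]] += data[j] * s[j]
        out.append(sketch)
    return out


x = [[1.2, 2.5, 3.4], [3.2, 5.7, 6.6]]
print(count_sketch(x, h=[0, 3, 4], s=[1, -1, 1], out_dim=5))
# [[1.2, 0.0, 0.0, -2.5, 3.4], [3.2, 0.0, 0.0, -5.7, 6.6]]
```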

Defined in src/operator/contrib/count_sketch.cc:L66

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the CountSketchOp.
  • h::NDArray-or-SymbolicNode: The index vector
  • s::NDArray-or-SymbolicNode: The sign vector
  • out_dim::int, required: The output dimension.
  • processing_batch_size::int, optional, default='32': How many sketch vectors to process at one time.

source

# MXNet.mx._contrib_ctc_lossMethod.

_contrib_ctc_loss(data, label, data_lengths, label_lengths, use_data_lengths, use_label_lengths, blank_label)

_contrib_ctc_loss is an alias of CTCLoss.

Connectionist Temporal Classification Loss.

.. note:: The existing alias contrib_CTCLoss is deprecated.

The shapes of the inputs and outputs:

  • data: (sequence_length, batch_size, alphabet_size)
  • label: (batch_size, label_sequence_length)
  • out: (batch_size)

The data tensor consists of sequences of activation vectors (without softmax applied). The i-th channel in the last dimension corresponds to the i-th label, for i between 0 and alphabet_size-1 (i.e. always 0-indexed). The alphabet size should include one additional value reserved for the blank label. When blank_label is "first", the 0-th channel is reserved for the activation of the blank label. When it is "last", the (alphabet_size-1)-th channel is reserved for the blank label.

label is an index matrix of integers. When blank_label is "first", the value 0 is reserved for the blank label and should not be passed in this matrix. When blank_label is "last", the value (alphabet_size-1) is reserved for the blank label.

If a sequence of labels is shorter than label_sequence_length, pad the end of the sequence with the special padding value so that it has the correct length. The padding value is 0 when blank_label is "first", and -1 otherwise.

For example, suppose the vocabulary is [a, b, c], and one batch contains the three sequences 'ba', 'cbb', and 'abac'. When blank_label is "first", we can index the labels as {'a': 1, 'b': 2, 'c': 3} and reserve the 0-th channel of the data tensor for the blank label. The resulting padded label tensor is::

[[2, 1, 0, 0], [3, 2, 2, 0], [1, 2, 1, 3]]

When blank_label is "last", we can index the labels as {'a': 0, 'b': 1, 'c': 2} and reserve channel index 3 of the data tensor for the blank label. The resulting padded label tensor is::

[[1, 0, -1, -1], [2, 1, 1, -1], [0, 1, 0, 2]]
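Both padded label tensors can be reproduced with a small plain-Python helper. This is an illustration, not part of the MXNet API. The helper name ctc_label_matrix is hypothetical, and the vocabulary is assumed to be an ordered sequence of tokens. With K vocabulary tokens, the data tensor would need alphabet_size = K + 1 channels.

```python
def ctc_label_matrix(sequences, vocab, blank_label="first"):
    """Build the padded integer label matrix described above.

    With blank_label="first", token ids start at 1 (0 is the blank and also the
    padding value). With "last", token ids start at 0 and the padding value is -1.
    """
    if blank_label == "first":
        offset, pad = 1, 0
    elif blank_label == "last":
        offset, pad = 0, -1
    else:
        raise ValueError("blank_label must be 'first' or 'last'")
    index = {tok: i + offset for i, tok in enumerate(vocab)}
    width = max(len(seq) for seq in sequences)   # label_sequence_length
    return [[index[tok] for tok in seq] + [pad] * (width - len(seq))
            for seq in sequences]


seqs = ["ba", "cbb", "abac"]
print(ctc_label_matrix(seqs, "abc", "first"))  # [[2, 1, 0, 0], [3, 2, 2, 0], [1, 2, 1, 3]]
print(ctc_label_matrix(seqs, "abc", "last"))   # [[1, 0, -1, -1], [2, 1, 1, -1], [0, 1, 0, 2]]
```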

out is a list of CTC loss values, one per example in the batch.

See Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks, A. Graves et al. for more information on the definition and the algorithm.

Defined in src/operator/nn/ctc_loss.cc:L100

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • label::NDArray-or-SymbolicNode: Ground-truth labels for the loss.
  • data_lengths::NDArray-or-SymbolicNode: Lengths of data for each of the samples. Only required when use_data_lengths is true.
  • label_lengths::NDArray-or-SymbolicNode: Lengths of labels for each of the samples. Only required when use_label_lengths is true.
  • use_data_lengths::boolean, optional, default=0: Whether the data lengths are decided by data_lengths. If false, the lengths are equal to the max sequence length.
  • use_label_lengths::boolean, optional, default=0: Whether the label lengths are decided by label_lengths, or derived from padding_mask. If false, the lengths are derived from the first occurrence of the value of padding_mask. The value of padding_mask is 0 when the first CTC label is reserved for blank, and -1 when the last label is reserved for blank. See blank_label.
  • blank_label::{'first', 'last'}, optional, default='first': Set the label that is reserved for the blank label. If "first", the 0-th label is reserved, label values for tokens in the vocabulary are between 1 and alphabet_size-1, and the padding mask is 0. If "last", the last label value alphabet_size-1 is reserved for the blank label instead, label values for tokens in the vocabulary are between 0 and alphabet_size-2, and the padding mask is -1.

source

# MXNet.mx._contrib_dequantizeMethod.

_contrib_dequantize(data, min_range, max_range, out_type)

Dequantizes the input tensor into a float tensor. min_range and max_range are scalar floats that specify the range for the output data.

When the input data type is uint8, the output is calculated using the following equation:

out[i] = in[i] * (max_range - min_range) / 255.0,

When the input data type is int8, the output is calculated using the following equation, which keeps the quantized value zero-centered:

out[i] = in[i] * MaxAbs(min_range, max_range) / 127.0,
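Both equations can be transcribed directly into plain Python. This is a sketch of the formulas exactly as written above, not of the operator's internals. The helper names are hypothetical, and MaxAbs(min_range, max_range) is read as max(|min_range|, |max_range|).

```python
def dequantize_uint8(values, min_range, max_range):
    # uint8 case as documented: out = in * (max_range - min_range) / 255.0
    return [v * (max_range - min_range) / 255.0 for v in values]


def dequantize_int8(values, min_range, max_range):
    # int8 case: zero-centred, scaled by the larger magnitude of the two bounds
    max_abs = max(abs(min_range), abs(max_range))
    return [v * max_abs / 127.0 for v in values]


print(dequantize_uint8([0, 255], 0.0, 2.0))        # [0.0, 2.0]
print(dequantize_int8([-127, 0, 127], -1.5, 3.0))  # [-3.0, 0.0, 3.0]
```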

.. Note:: This operator only supports forward propagation. DO NOT use it in training.

Defined in src/operator/quantization/dequantize.cc:L80

Arguments

  • data::NDArray-or-SymbolicNode: An ndarray/symbol of type uint8 or int8
  • min_range::NDArray-or-SymbolicNode: The minimum scalar value possibly produced for the input in float32
  • max_range::NDArray-or-SymbolicNode: The maximum scalar value possibly produced for the input in float32
  • out_type::{'float32'},optional, default='float32': Output data type.

source

# MXNet.mx._contrib_dgl_adjacencyMethod.

_contrib_dgl_adjacency(data)

This operator converts a CSR matrix whose values are edge Ids to an adjacency matrix whose values are ones. The values of the output CSR matrix are always float32.

Example:

.. code:: python

x = [[ 1, 0, 0 ],
     [ 0, 2, 0 ],
     [ 0, 0, 3 ]]
dgl_adjacency(x) =
    [[ 1, 0, 0 ],
     [ 0, 1, 0 ],
     [ 0, 0, 1 ]]

Defined in src/operator/contrib/dgl_graph.cc:L1424

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray

source

# MXNet.mx._contrib_dgl_csr_neighbor_non_uniform_sampleMethod.

_contrib_dgl_csr_neighbor_non_uniform_sample(csr_matrix, probability, seed_arrays, num_args, num_hops, num_neighbor, max_num_vertices)

Note: _contrib_dgl_csr_neighbor_non_uniform_sample takes a variable number of positional inputs. So instead of calling it as _contrib_dgl_csr_neighbor_non_uniform_sample([x, y, z], num_args=3), call it as _contrib_dgl_csr_neighbor_non_uniform_sample(x, y, z), and num_args will be determined automatically.

This operator samples a subgraph from a CSR graph using non-uniform probabilities. The operator is designed for DGL.

The operator outputs four sets of NDArrays to represent the sampled results. The number of NDArrays in each set equals the number of seed NDArrays, i.e. the number of inputs minus two (the CSR matrix and the probability vector):

  1. a set of 1D NDArrays containing the sampled vertices,
  2. a set of CSRNDArrays representing the sampled edges,
  3. a set of 1D NDArrays with the probability that vertices are sampled,
  4. a set of 1D NDArrays indicating the layer where a vertex is sampled.

The first set of 1D NDArrays has a length of max_num_vertices+1. The last element in each NDArray indicates the actual number of vertices in a subgraph. The third and fourth sets of NDArrays have a length of max_num_vertices, and the valid number of vertices is the same as in the first set.

Example:

.. code:: python

shape = (5, 5)
prob = mx.nd.array([0.9, 0.8, 0.2, 0.4, 0.1], dtype=np.float32)
data_np = np.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], dtype=np.int64)
indices_np = np.array([1,2,3,4,0,2,3,4,0,1,3,4,0,1,2,4,0,1,2,3], dtype=np.int64)
indptr_np = np.array([0,4,8,12,16,20], dtype=np.int64)
a = mx.nd.sparse.csr_matrix((data_np, indices_np, indptr_np), shape=shape)
seed = mx.nd.array([0,1,2,3,4], dtype=np.int64)
out = mx.nd.contrib.dgl_csr_neighbor_non_uniform_sample(a, prob, seed, num_args=3, num_hops=1, num_neighbor=2, max_num_vertices=5)

out[0] [0 1 2 3 4 5]

out[1].asnumpy() array([[ 0, 1, 2, 0, 0], [ 5, 0, 6, 0, 0], [ 9, 10, 0, 0, 0], [13, 14, 0, 0, 0], [ 0, 18, 19, 0, 0]])

out[2] [0.9 0.8 0.2 0.4 0.1]

out[3] [0 0 0 0 0]

Defined in src/operator/contrib/dgl_graph.cc:L911

Arguments

  • csr_matrix::NDArray-or-SymbolicNode: csr matrix
  • probability::NDArray-or-SymbolicNode: probability vector
  • seed_arrays::NDArray-or-SymbolicNode[]: seed vertices
  • num_args::int, required: Number of input NDArrays.
  • num_hops::long, optional, default=1: Number of hops.
  • num_neighbor::long, optional, default=2: Number of neighbors.
  • max_num_vertices::long, optional, default=100: Max number of vertices.

source

# MXNet.mx._contrib_dgl_csr_neighbor_uniform_sampleMethod.

_contrib_dgl_csr_neighbor_uniform_sample(csr_matrix, seed_arrays, num_args, num_hops, num_neighbor, max_num_vertices)

Note: _contrib_dgl_csr_neighbor_uniform_sample takes a variable number of positional inputs. So instead of calling it as _contrib_dgl_csr_neighbor_uniform_sample([x, y, z], num_args=3), call it as _contrib_dgl_csr_neighbor_uniform_sample(x, y, z), and num_args will be determined automatically.

This operator samples subgraphs from a CSR graph with uniform probability. The operator is designed for DGL.

The operator outputs three sets of NDArrays to represent the sampled results. The number of NDArrays in each set equals the number of seed NDArrays, i.e. the number of inputs minus one (the CSR matrix):

  1. a set of 1D NDArrays containing the sampled vertices,
  2. a set of CSRNDArrays representing the sampled edges,
  3. a set of 1D NDArrays indicating the layer where a vertex is sampled.

The first set of 1D NDArrays has a length of max_num_vertices+1. The last element in each NDArray indicates the actual number of vertices in a subgraph. The third set of NDArrays has a length of max_num_vertices, and the valid number of vertices is the same as in the first set.

Example:

.. code:: python

shape = (5, 5)
data_np = np.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], dtype=np.int64)
indices_np = np.array([1,2,3,4,0,2,3,4,0,1,3,4,0,1,2,4,0,1,2,3], dtype=np.int64)
indptr_np = np.array([0,4,8,12,16,20], dtype=np.int64)
a = mx.nd.sparse.csr_matrix((data_np, indices_np, indptr_np), shape=shape)
a.asnumpy()
seed = mx.nd.array([0,1,2,3,4], dtype=np.int64)
out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5)

out[0] [0 1 2 3 4 5]

out[1].asnumpy() array([[ 0, 1, 0, 3, 0], [ 5, 0, 0, 7, 0], [ 9, 0, 0, 11, 0], [13, 0, 15, 0, 0], [17, 0, 19, 0, 0]])

out[2] [0 0 0 0 0]

Defined in src/operator/contrib/dgl_graph.cc:L801

Arguments

  • csr_matrix::NDArray-or-SymbolicNode: csr matrix
  • seed_arrays::NDArray-or-SymbolicNode[]: seed vertices
  • num_args::int, required: Number of input NDArrays.
  • num_hops::long, optional, default=1: Number of hops.
  • num_neighbor::long, optional, default=2: Number of neighbors.
  • max_num_vertices::long, optional, default=100: Max number of vertices.

source

# MXNet.mx._contrib_dgl_graph_compactMethod.

_contrib_dgl_graph_compact(graph_data, num_args, return_mapping, graph_sizes)

Note: _contrib_dgl_graph_compact takes a variable number of positional inputs. So instead of calling it as _contrib_dgl_graph_compact([x, y, z], num_args=3), call it as _contrib_dgl_graph_compact(x, y, z), and num_args will be determined automatically.

This operator compacts a CSR matrix generated by dgl_csr_neighbor_uniform_sample and dgl_csr_neighbor_non_uniform_sample. The CSR matrices generated by these two operators may have many empty rows at the end and many empty columns. This operator removes these empty rows and empty columns.

Example:

.. code:: python

shape = (5, 5)
data_np = np.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], dtype=np.int64)
indices_np = np.array([1,2,3,4,0,2,3,4,0,1,3,4,0,1,2,4,0,1,2,3], dtype=np.int64)
indptr_np = np.array([0,4,8,12,16,20], dtype=np.int64)
a = mx.nd.sparse.csr_matrix((data_np, indices_np, indptr_np), shape=shape)
seed = mx.nd.array([0,1,2,3,4], dtype=np.int64)
out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=6)
subgv = out[0]
subg = out[1]
compact = mx.nd.contrib.dgl_graph_compact(subg, subgv, graph_sizes=(subgv[-1].asnumpy()[0]), return_mapping=False)

compact.asnumpy() array([[0, 0, 0, 1, 0], [2, 0, 3, 0, 0], [0, 4, 0, 0, 5], [0, 6, 0, 0, 7], [8, 9, 0, 0, 0]])

Defined in src/operator/contrib/dgl_graph.cc:L1613

Arguments

  • graph_data::NDArray-or-SymbolicNode[]: Input graphs and input vertex Ids.
  • num_args::int, required: Number of input arguments.
  • return_mapping::boolean, required: Return mapping of vid and eid between the subgraph and the parent graph.
  • graph_sizes::tuple of <long>, required: the number of vertices in each graph.

source

# MXNet.mx._contrib_dgl_subgraphMethod.

_contrib_dgl_subgraph(graph, data, num_args, return_mapping)

Note: _contrib_dgl_subgraph takes a variable number of positional inputs. So instead of calling it as _contrib_dgl_subgraph([x, y, z], num_args=3), call it as _contrib_dgl_subgraph(x, y, z), and num_args will be determined automatically.

This operator constructs an induced subgraph for a given set of vertices from a graph. The operator accepts multiple sets of vertices as input. For each set of vertices, it returns a pair of CSR matrices if return_mapping is True: the first matrix contains edges with new edge Ids, the second matrix contains edges with the original edge Ids.

Example:

.. code:: python

 x=[[1, 0, 0, 2],
   [3, 0, 4, 0],
   [0, 5, 0, 0],
   [0, 6, 7, 0]]
 v = [0, 1, 2]
 dgl_subgraph(x, v, return_mapping=True) =
   [[1, 0, 0],
    [2, 0, 3],
    [0, 4, 0]],
   [[1, 0, 0],
    [3, 0, 4],
    [0, 5, 0]]
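For a dense adjacency matrix, the pair of matrices in the example can be reproduced with a small plain-Python sketch. The renumbering rule is inferred from the example, not taken from the operator's source: new edge Ids 1, 2, ... are assigned in row-major order. The real operator works on CSR input, and the helper name is hypothetical.

```python
def induced_subgraph(adj, vertices):
    """Return (new_ids, old_ids) for the subgraph induced by `vertices`.

    adj is a dense matrix whose non-zero entries are edge Ids. old_ids keeps the
    parent graph's edge Ids; new_ids renumbers the kept edges 1, 2, ... row by row.
    """
    n = len(vertices)
    old_ids = [[adj[u][v] for v in vertices] for u in vertices]
    new_ids = [[0] * n for _ in range(n)]
    next_id = 1
    for i in range(n):
        for j in range(n):
            if old_ids[i][j] != 0:
                new_ids[i][j] = next_id
                next_id += 1
    return new_ids, old_ids


x = [[1, 0, 0, 2],
     [3, 0, 4, 0],
     [0, 5, 0, 0],
     [0, 6, 7, 0]]
print(induced_subgraph(x, [0, 1, 2]))
# ([[1, 0, 0], [2, 0, 3], [0, 4, 0]], [[1, 0, 0], [3, 0, 4], [0, 5, 0]])
```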

Defined in src/operator/contrib/dgl_graph.cc:L1171

Arguments

  • graph::NDArray-or-SymbolicNode: Input graph where we sample vertices.
  • data::NDArray-or-SymbolicNode[]: The input arrays that include data arrays and states.
  • num_args::int, required: Number of input arguments, including all symbol inputs.
  • return_mapping::boolean, required: Return mapping of vid and eid between the subgraph and the parent graph.

source

# MXNet.mx._contrib_div_sqrt_dimMethod.

_contrib_div_sqrt_dim(data)

Rescale the input by the square root of the channel dimension.

out = data / sqrt(data.shape[-1])
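For a 2-D input, the formula amounts to the following plain-Python sketch (the helper name is hypothetical). In scaled dot-product attention, this is the scaling applied to the query projections, as in the equivalent code of the interleaved attention operators.

```python
import math


def div_sqrt_dim(rows):
    """Divide every element of a 2-D list by sqrt(length of the last axis)."""
    scale = math.sqrt(len(rows[0]))
    return [[v / scale for v in row] for row in rows]


print(div_sqrt_dim([[2.0, 4.0, 6.0, 8.0]]))  # [[1.0, 2.0, 3.0, 4.0]]
```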

Defined in src/operator/contrib/transformer.cc:L832

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._contrib_edge_idMethod.

_contrib_edge_id(data, u, v)

This operator implements the edge_id function for a graph stored in a CSR matrix (the values of the CSR matrix store the edge Ids of the graph). output[i] = input[u[i], v[i]] if there is an edge between u[i] and v[i]; otherwise output[i] is -1. Both u and v should be 1D vectors.

Example:

.. code:: python

  x = [[ 1, 0, 0 ],
       [ 0, 2, 0 ],
       [ 0, 0, 3 ]]
  u = [ 0, 0, 1, 1, 2, 2 ]
  v = [ 0, 1, 1, 2, 0, 2 ]
  edge_id(x, u, v) = [ 1, -1, 2, -1, -1, 3 ]
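The lookup can be sketched in plain Python on the raw CSR arrays (indptr, indices, values). This is an illustrative helper, not the operator's implementation. It scans each row linearly, whereas a real implementation could binary-search the sorted column indices. The function name is hypothetical.

```python
def edge_id(indptr, indices, values, u, v):
    """Look up edge Ids in a CSR adjacency matrix; -1 where no edge exists."""
    out = []
    for src, dst in zip(u, v):
        row_cols = indices[indptr[src]:indptr[src + 1]]   # columns stored in row src
        row_vals = values[indptr[src]:indptr[src + 1]]
        out.append(row_vals[row_cols.index(dst)] if dst in row_cols else -1)
    return out


# CSR form of x = [[1, 0, 0], [0, 2, 0], [0, 0, 3]]
indptr, indices, values = [0, 1, 2, 3], [0, 1, 2], [1, 2, 3]
print(edge_id(indptr, indices, values, [0, 0, 1, 1, 2, 2], [0, 1, 1, 2, 0, 2]))
# [1, -1, 2, -1, -1, 3]
```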

The storage type of the edge_id output depends on the storage types of the inputs:

  • edge_id(csr, default, default) = default
  • default and rsp inputs are not supported

Defined in src/operator/contrib/dgl_graph.cc:L1352

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • u::NDArray-or-SymbolicNode: u ndarray
  • v::NDArray-or-SymbolicNode: v ndarray

source

# MXNet.mx._contrib_fftMethod.

_contrib_fft(data, compute_size)

Applies a 1D FFT to the input.

.. note:: fft is only available on GPU.

Currently two input data shapes are accepted: (N, d) or (N1, N2, N3, d), and the data can only be real numbers. The output has shape (N, 2d) or (N1, N2, N3, 2d), in the format [real0, imag0, real1, imag1, ...].

Example::

data = np.random.normal(0, 1, (3, 4))
out = mx.contrib.ndarray.fft(data=mx.nd.array(data, ctx=mx.gpu(0)))
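The interleaved output layout can be illustrated with a naive pure-Python DFT that runs on the CPU. The sketch assumes the usual unnormalized forward DFT convention, X_k = sum_n x_n exp(-2 pi i k n / d). It only shows how a row of length d becomes a row of length 2d. The function name is hypothetical.

```python
import cmath


def fft_interleaved(rows):
    """Naive 1-D DFT of each real row, packed as [re0, im0, re1, im1, ...]."""
    out = []
    for row in rows:
        d = len(row)
        packed = []
        for k in range(d):
            z = sum(x * cmath.exp(-2j * cmath.pi * k * n / d) for n, x in enumerate(row))
            packed.extend([z.real, z.imag])
        out.append(packed)
    return out


print(fft_interleaved([[1.0, 2.0, 3.0, 4.0]]))
# approximately [[10, 0, -2, 2, -2, 0, -2, -2]]
```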

Defined in src/operator/contrib/fft.cc:L55

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the FFTOp.
  • compute_size::int, optional, default='128': Maximum size of sub-batch to be forwarded at one time

source

# MXNet.mx._contrib_getnnzMethod.

_contrib_getnnz(data, axis)

Number of stored values for a sparse tensor, including explicit zeros.

This operator only supports CSR matrix on CPU.

Defined in src/operator/contrib/nnz.cc:L176

Arguments

  • data::NDArray-or-SymbolicNode: Input
  • axis::int or None, optional, default='None': Select between the number of values across the whole matrix, in each column, or in each row.
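The three modes of axis can be sketched in plain Python on the raw CSR arrays. The mapping is an assumption: axis=None counts the whole matrix, axis=0 counts per column, and axis=1 counts per row. It follows the order in the description above, which matches SciPy's getnnz convention. The helper name is hypothetical. Because explicit zeros count as stored values, the sketch never inspects the data array.

```python
def getnnz(indptr, indices, num_cols, axis=None):
    """Stored-value counts of a CSR matrix, explicit zeros included."""
    if axis is None:
        return indptr[-1]                    # whole matrix
    if axis == 0:
        counts = [0] * num_cols              # one count per column
        for col in indices[:indptr[-1]]:
            counts[col] += 1
        return counts
    if axis == 1:
        return [indptr[i + 1] - indptr[i] for i in range(len(indptr) - 1)]  # per row
    raise ValueError("axis must be None, 0 or 1")


# 3x3 matrix: row 0 stores columns 0 and 2; row 2 stores columns 0 and 1 (one may be an explicit zero)
indptr, indices = [0, 2, 2, 4], [0, 2, 0, 1]
print(getnnz(indptr, indices, 3), getnnz(indptr, indices, 3, 0), getnnz(indptr, indices, 3, 1))
# 4 [2, 1, 1] [2, 0, 2]
```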

source

# MXNet.mx._contrib_gradientmultiplierMethod.

_contrib_gradientmultiplier(data, scalar, is_int)

This operator implements the gradient multiplier function. In the forward pass, it acts as an identity transform. During backpropagation, it multiplies the gradient from the subsequent layer by a scalar factor lambda and passes it to the preceding layer.

Defined in src/operator/contrib/gradient_multiplier_op.cc:L78

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._contrib_group_adagrad_updateMethod.

_contrib_group_adagrad_update(weight, grad, history, lr, rescale_grad, clip_gradient, epsilon)

Update function for Group AdaGrad optimizer.

Based on Adaptive Subgradient Methods for Online Learning and Stochastic Optimization (available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf), but uses only a single learning rate for every row of the parameter array.

Updates are applied by::

grad = clip(grad * rescale_grad, clip_gradient)
history += mean(square(grad), axis=1, keepdims=True)
div = grad / sqrt(history + float_stable_eps)
weight -= div * lr
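The update rule above can be sketched in plain Python for a dense 2-D weight. This is an illustration, not the operator. It ignores the lazy sparse update, the helper name is hypothetical, and epsilon plays the role of float_stable_eps. history holds one value per row, which is why each row shares a single adaptive learning rate.

```python
import math


def group_adagrad_update(weight, grad, history, lr, rescale_grad=1.0,
                         clip_gradient=-1.0, epsilon=1e-5):
    """In-place Group AdaGrad step on 2-D lists; history has one entry per row."""
    for r, (w_row, g_row) in enumerate(zip(weight, grad)):
        g = [x * rescale_grad for x in g_row]
        if clip_gradient > 0:                          # clipping is off when <= 0
            g = [max(min(x, clip_gradient), -clip_gradient) for x in g]
        history[r] += sum(x * x for x in g) / len(g)   # mean(square(grad)) over the row
        denom = math.sqrt(history[r] + epsilon)
        for c, x in enumerate(g):
            w_row[c] -= lr * x / denom                 # weight -= div * lr
    return weight, history


print(group_adagrad_update([[1.0, 1.0]], [[2.0, 2.0]], [0.0], lr=0.1, epsilon=0.0))
# ([[0.9, 0.9]], [4.0])
```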

Weights are updated lazily if the gradient is sparse.

Note that non-zero values for the weight decay option are not supported.

Defined in src/operator/contrib/optimizer_op.cc:L70

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • history::NDArray-or-SymbolicNode: History
  • lr::float, required: Learning rate
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • epsilon::float, optional, default=9.99999975e-06: Epsilon for numerical stability

source

# MXNet.mx._contrib_hawkesllMethod.

_contrib_hawkesll(lda, alpha, beta, state, lags, marks, valid_length, max_time)

Computes the log likelihood of a univariate Hawkes process.

The log likelihood is calculated on point process observations represented as ragged matrices for lags (interarrival times w.r.t. the previous point) and marks (identifiers for the process ID). Each mark is considered independent, i.e., the operator computes the joint likelihood of a set of Hawkes processes determined by the conditional intensity:

.. math::

\lambda_k^*(t) = \lambda_k + \alpha_k \sum_{\{t_i < t,\, y_i = k\}} \beta_k \exp(-\beta_k (t - t_i))

where \lambda_k specifies the background intensity lda, \alpha_k specifies the branching ratio alpha, and \beta_k is the delay density parameter beta.

lags and marks are two NDArrays of shape (N, T) that represent the point process observations. The first dimension corresponds to the batch index, and the second to the sequence. These are "left-aligned" ragged matrices (the first index of the second dimension is the beginning of every sequence). The length of each sequence is given by valid_length, of shape (N,), where valid_length[i] is the number of valid points in lags[i, :] and marks[i, :].

max_time is the length of the observation period of the point process. That is, specifying max_time[i] = 5 computes the likelihood of the i-th sample as observed on the time interval (0, 5]. Naturally, the sum of all valid lags[i, :valid_length[i]] must be less than or equal to 5.

The input $state$ specifies the memory of the Hawkes process. Invoking the memoryless property of exponential decays, we compute the memory as

.. math::

s_k(t) = \sum_{t_i < t} \exp(-\beta_k (t - t_i)).

The state to be provided is s_k(0), which carries the added intensity due to past events before the current batch. The function returns s_k(T), where T is max_time[i] for the i-th sample.

Example::

from mxnet import nd
import numpy as np

# batch size, sequence length and number of processes (implied by the arrays below)
N, T, K = 4, 4, 3

# define the Hawkes process parameters
lda = nd.array([1.5, 2.0, 3.0]).tile((N, 1))
alpha = nd.array([0.2, 0.3, 0.4])  # branching ratios should be < 1
beta = nd.array([1.0, 2.0, 3.0])

# the "data", or observations
ia_times = nd.array([[6, 7, 8, 9], [1, 2, 3, 4], [3, 4, 5, 6], [8, 9, 10, 11]])
marks = nd.zeros((N, T)).astype(np.int32)

# starting "state" of the process
states = nd.zeros((N, K))

valid_length = nd.array([1, 2, 3, 4])  # number of valid points in each sequence
max_time = nd.ones((N,)) * 100.0  # length of the observation period

A = nd.contrib.hawkesll(lda, alpha, beta, states, ia_times, marks, valid_length, max_time)
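As a reference for the formulas above, the following plain-Python sketch evaluates the log likelihood of a single sequence with exponential kernels. The result is the sum of the log intensities at the observed points, minus the integral of each intensity over (0, max_time]. The sketch makes two assumptions: the history is empty (s_k(0) = 0), and only the valid points are passed in. The function name hawkes_loglik is hypothetical, and it is not guaranteed to match the operator's numerics.

```python
import math


def hawkes_loglik(lda, alpha, beta, lags, marks, max_time):
    """Log likelihood of one sequence of K independent exponential-kernel Hawkes processes.

    Assumes an empty history (s_k(0) = 0) and that lags/marks hold only valid points.
    """
    K = len(lda)
    memory = [0.0] * K               # s_k(t) for each mark k
    times = [[] for _ in range(K)]   # event times per mark, for the compensator
    t = 0.0
    loglik = 0.0
    for lag, k in zip(lags, marks):
        # memoryless update: decay every s_j from the previous event to the current one
        memory = [m * math.exp(-beta[j] * lag) for j, m in enumerate(memory)]
        t += lag
        # log of lambda_k^*(t) = lambda_k + alpha_k * beta_k * s_k(t)
        loglik += math.log(lda[k] + alpha[k] * beta[k] * memory[k])
        memory[k] += 1.0
        times[k].append(t)
    # subtract the compensator: integral of each intensity over (0, max_time]
    for k in range(K):
        loglik -= lda[k] * max_time
        loglik -= alpha[k] * sum(1.0 - math.exp(-beta[k] * (max_time - ti)) for ti in times[k])
    return loglik


print(hawkes_loglik([1.0], [0.5], [2.0], lags=[1.0], marks=[0], max_time=2.0))  # about -2.4323
```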

References:

  • Bacry, E., Mastromatteo, I., & Muzy, J. F. (2015). Hawkes processes in finance. Market Microstructure and Liquidity , 1(01), 1550005.

Defined in src/operator/contrib/hawkes_ll.cc:L83

Arguments

  • lda::NDArray-or-SymbolicNode: Shape (N, K) The intensity for each of the K processes, for each sample
  • alpha::NDArray-or-SymbolicNode: Shape (K,) The infectivity factor (branching ratio) for each process
  • beta::NDArray-or-SymbolicNode: Shape (K,) The decay parameter for each process
  • state::NDArray-or-SymbolicNode: Shape (N, K) the Hawkes state for each process
  • lags::NDArray-or-SymbolicNode: Shape (N, T) the interarrival times
  • marks::NDArray-or-SymbolicNode: Shape (N, T) the marks (process ids)
  • valid_length::NDArray-or-SymbolicNode: The number of valid points in the process
  • max_time::NDArray-or-SymbolicNode: the length of the interval where the processes were sampled

source

# MXNet.mx._contrib_ifftMethod.

_contrib_ifft(data, compute_size)

Applies a 1D inverse FFT to the input.

.. note:: ifft is only available on GPU.

Currently two input data shapes are accepted: (N, d) or (N1, N2, N3, d). The data is in the format [real0, imag0, real1, imag1, ...], so the last dimension must be an even number. The output has shape (N, d/2) or (N1, N2, N3, d/2) and contains only the real part of the result.

Example::

data = np.random.normal(0, 1, (3, 4))
out = mx.contrib.ndarray.ifft(data=mx.nd.array(data, ctx=mx.gpu(0)))

Defined in src/operator/contrib/ifft.cc:L57

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the IFFTOp.
  • compute_size::int, optional, default='128': Maximum size of sub-batch to be forwarded at one time

source

# MXNet.mx._contrib_index_arrayMethod.

_contrib_index_array(data, axes)

Returns an array of indexes of the input array.

For an input array with shape (d_1, d_2, ..., d_n), index_array returns a (d_1, d_2, ..., d_n, n) array idx, where idx[i_1, i_2, ..., i_n, :] = [i_1, i_2, ..., i_n].

Additionally, when the parameter axes is specified, idx will be a (d_1, d_2, ..., d_n, m) array, where m is the length of axes, and the following equality will hold: idx[i_1, i_2, ..., i_n, j] = i_{axes[j]}.

Examples::

x = mx.nd.ones((3, 2))

mx.nd.contrib.index_array(x) = [[[0 0]
                                 [0 1]]

                                [[1 0]
                                 [1 1]]

                                [[2 0]
                                 [2 1]]]

x = mx.nd.ones((3, 2, 2))

mx.nd.contrib.index_array(x, axes=(1, 0)) = [[[[0 0]
                                               [0 0]]

                                              [[1 0]
                                               [1 0]]]


                                             [[[0 1]
                                               [0 1]]

                                              [[1 1]
                                               [1 1]]]


                                             [[[0 2]
                                               [0 2]]

                                              [[1 2]
                                               [1 2]]]]
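A plain-Python sketch of the same mapping uses itertools.product to enumerate every index tuple. To keep it short, it returns a dictionary from each index tuple to idx[i_1, ..., i_n, :] rather than a nested array. The function name is hypothetical. Negative axes are wrapped modulo n.

```python
from itertools import product


def index_array(shape, axes=None):
    """Map every index tuple of `shape` to its coordinates along `axes`."""
    n = len(shape)
    axes = list(range(n)) if axes is None else [a % n for a in axes]  # negative axes allowed
    return {idx: [idx[a] for a in axes] for idx in product(*(range(d) for d in shape))}


print(index_array((3, 2))[(2, 1)])                     # [2, 1]
print(index_array((3, 2, 2), axes=(1, 0))[(2, 1, 0)])  # [1, 2]
```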

Defined in src/operator/contrib/index_array.cc:L118

Arguments

  • data::NDArray-or-SymbolicNode: Input data
  • axes::Shape or None, optional, default=None: The axes to include in the index array. Supports negative values.

source

# MXNet.mx._contrib_index_copyMethod.

_contrib_index_copy(old_tensor, index_vector, new_tensor)

Copies the elements of new_tensor into old_tensor.

This operator copies the elements by selecting the indices in the order given in index. The output is a new tensor containing the remaining elements of old_tensor and the copied elements of new_tensor. For example, if index[i] == j, then the i-th row of new_tensor is copied to the j-th row of the output.

The index must be a vector, and its size must equal the 0-th dimension of new_tensor. Also, the 0-th dimension of old_tensor must be >= the 0-th dimension of new_tensor, or an error will be raised.

Examples::

x = mx.nd.zeros((5,3))
t = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
index = mx.nd.array([0,4,2])

mx.nd.contrib.index_copy(x, index, t)

[[1. 2. 3.]
 [0. 0. 0.]
 [7. 8. 9.]
 [0. 0. 0.]
 [4. 5. 6.]]
<NDArray 5x3 @cpu(0)>

Defined in src/operator/contrib/index_copy.cc:L183

Arguments

  • old_tensor::NDArray-or-SymbolicNode: Old tensor
  • index_vector::NDArray-or-SymbolicNode: Index vector
  • new_tensor::NDArray-or-SymbolicNode: New tensor to be copied
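The row semantics of index_copy can be sketched in plain Python on lists of rows. The helper is hypothetical and returns a new list rather than an NDArray. It raises the same two shape errors described above.

```python
def index_copy(old, index, new):
    """Return a copy of `old` with row new[i] written to row index[i]."""
    if len(index) != len(new):
        raise ValueError("index must have one entry per row of new")
    if len(old) < len(new):
        raise ValueError("old must have at least as many rows as new")
    out = [list(row) for row in old]   # rows not named in index keep their old values
    for i, j in enumerate(index):
        out[j] = list(new[i])
    return out


x = [[0.0] * 3 for _ in range(5)]
t = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(index_copy(x, [0, 4, 2], t))
# [[1, 2, 3], [0.0, 0.0, 0.0], [7, 8, 9], [0.0, 0.0, 0.0], [4, 5, 6]]
```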

source

# MXNet.mx._contrib_interleaved_matmul_encdec_qkMethod.

_contrib_interleaved_matmul_encdec_qk(queries, keys_values, heads)

Computes the matrix multiplication between the projections of queries and keys in multi-head attention used as encoder-decoder attention.

The inputs must be a tensor of projections of queries with the layout (seq_length, batch_size, num_heads * head_dim)

and a tensor of interleaved projections of keys and values with the layout (seq_length, batch_size, num_heads * head_dim * 2).

The equivalent code would be:

q_proj = mx.nd.reshape(queries, shape=(0, 0, num_heads, -1))
q_proj = mx.nd.transpose(q_proj, axes=(1, 2, 0, 3))
q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True)
q_proj = mx.nd.contrib.div_sqrt_dim(q_proj)
tmp = mx.nd.reshape(keys_values, shape=(0, 0, num_heads, 2, -1))
k_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3))
k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)

Defined in src/operator/contrib/transformer.cc:L753

Arguments

  • queries::NDArray-or-SymbolicNode: Queries
  • keys_values::NDArray-or-SymbolicNode: Keys and values interleaved
  • heads::int, required: Set number of heads

source

# MXNet.mx._contrib_interleaved_matmul_encdec_valattMethod.

_contrib_interleaved_matmul_encdec_valatt(keys_values, attention, heads)

Computes the matrix multiplication between the projections of values and the attention weights in multi-head attention used as encoder-decoder attention.

The inputs must be a tensor of interleaved projections of keys and values with the layout (seq_length, batch_size, num_heads * head_dim * 2)

and the attention weights with the layout (batch_size * num_heads, seq_length, seq_length).

The equivalent code would be:

tmp = mx.nd.reshape(keys_values, shape=(0, 0, num_heads, 2, -1))
v_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3))
v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
output = mx.nd.batch_dot(attention, v_proj)
output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
output = mx.nd.transpose(output, axes=(0, 2, 1, 3))
output = mx.nd.reshape(output, shape=(0, 0, -1))

Defined in src/operator/contrib/transformer.cc:L799

Arguments

  • keys_values::NDArray-or-SymbolicNode: Keys and values interleaved
  • attention::NDArray-or-SymbolicNode: Attention maps
  • heads::int, required: Set number of heads

source

# MXNet.mx._contrib_interleaved_matmul_selfatt_qkMethod.

_contrib_interleaved_matmul_selfatt_qk(queries_keys_values, heads)

Computes the matrix multiplication between the projections of queries and keys in multi-head attention used as self-attention.

The input must be a single tensor of interleaved projections of queries, keys and values with the layout (seq_length, batch_size, num_heads * head_dim * 3).

The equivalent code would be:

tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
q_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3))
q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True)
q_proj = mx.nd.contrib.div_sqrt_dim(q_proj)
k_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3))
k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)
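To make the interleaved layout concrete, here is a plain-Python sketch of what the self-attention query/key product computes for one batch element. The last axis of each position's vector is read as (num_heads, 3, head_dim), with slot 0 holding the query, slot 1 the key and slot 2 the value, matching the reshape in the equivalent code. Each score is the query/key dot product divided by sqrt(head_dim), as div_sqrt_dim does. The function name is hypothetical. Per the equivalent code, the operator returns these scores as a (batch_size * num_heads, seq_length, seq_length) array.

```python
import math


def selfatt_scores(qkv, heads):
    """Scaled query/key scores for one batch element.

    qkv[t] is the interleaved projection vector of position t, laid out as
    (heads, 3, head_dim). Returns scores[h][i][j] = dot(q_i, k_j) / sqrt(head_dim).
    """
    head_dim = len(qkv[0]) // (3 * heads)

    def part(vec, h, slot):
        start = (h * 3 + slot) * head_dim
        return vec[start:start + head_dim]

    scale = math.sqrt(head_dim)
    return [[[sum(a * b for a, b in zip(part(qi, h, 0), part(kj, h, 1))) / scale
              for kj in qkv]
             for qi in qkv]
            for h in range(heads)]


print(selfatt_scores([[1.0, 2.0, 0.0], [3.0, 4.0, 0.0]], heads=1))
# [[[2.0, 4.0], [6.0, 12.0]]]
```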

Defined in src/operator/contrib/transformer.cc:L665

Arguments

  • queries_keys_values::NDArray-or-SymbolicNode: Interleaved queries, keys and values
  • heads::int, required: Set number of heads

source

# MXNet.mx._contrib_interleaved_matmul_selfatt_valattMethod.

_contrib_interleaved_matmul_selfatt_valatt(queries_keys_values, attention, heads)

Computes the matrix multiplication between the projections of values and the attention weights in multi-head attention used as self-attention.

The inputs must be a tensor of interleaved projections of queries, keys and values with the layout (seq_length, batch_size, num_heads * head_dim * 3)

and the attention weights with the layout (batch_size * num_heads, seq_length, seq_length).

The equivalent code would be:

tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
v_proj = mx.nd.transpose(tmp[:,:,:,2,:], axes=(1, 2, 0, 3))
v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
output = mx.nd.batch_dot(attention, v_proj)
output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
output = mx.nd.transpose(output, axes=(0, 2, 1, 3))
output = mx.nd.reshape(output, shape=(0, 0, -1))

Defined in src/operator/contrib/transformer.cc:L709

Arguments

  • queries_keys_values::NDArray-or-SymbolicNode: Queries, keys and values interleaved
  • attention::NDArray-or-SymbolicNode: Attention maps
  • heads::int, required: Set number of heads

source

# MXNet.mx._contrib_intgemm_fully_connectedMethod.

_contrib_intgemm_fully_connected(data, weight, scaling, bias, num_hidden, no_bias, flatten, out_type)

Multiply matrices using 8-bit integers. data * weight.

Input tensor arguments are: data weight [scaling] [bias]

data: either float32, or int8 prepared using intgemm_prepare_data.

weight: must be prepared using intgemm_prepare_weight.

scaling: present if and only if out_type is float32. If present, it multiplies the result before the bias is added. Typically:
  • scaling = (max passed to intgemm_prepare_weight)/127.0 if data is in float32
  • scaling = (max passed to intgemm_prepare_data)/127.0 * (max passed to intgemm_prepare_weight)/127.0 if data is in int8

bias: present if and only if !no_bias. This is added to the output after scaling and has the same number of columns as the output.

out_type: type of the output.

Defined in src/operator/contrib/intgemm/intgemm_fully_connected_op.cc:L283

Arguments

  • data::NDArray-or-SymbolicNode: First argument to multiplication. Tensor of float32 (quantized on the fly) or int8 from intgemm_prepare_data. If you use a different quantizer, be sure to ban -128. The last dimension must be a multiple of 64.
  • weight::NDArray-or-SymbolicNode: Second argument to multiplication. Tensor of int8 from intgemm_prepare_weight. The last dimension must be a multiple of 64. The product of non-last dimensions must be a multiple of 8.
  • scaling::NDArray-or-SymbolicNode: Scaling factor to apply if output type is float32.
  • bias::NDArray-or-SymbolicNode: Bias term.
  • num_hidden::int, required: Number of hidden nodes of the output.
  • no_bias::boolean, optional, default=0: Whether to disable bias parameter.
  • flatten::boolean, optional, default=1: Whether to collapse all but the first axis of the input data tensor.
  • out_type::{'float32', 'int32'},optional, default='float32': Output data type.
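The arithmetic behind scaling can be illustrated without intgemm. The following is a minimal pure-Python sketch, not the operator's implementation: it assumes int8 data (so both maxima enter scaling), round-to-nearest quantization with -128 banned, and a plain row-major weight layout instead of intgemm's packed, CPU-dependent format. The helper names max_abs, quantize and int8_fully_connected are hypothetical.

```python
# Hypothetical sketch: int8 x int8 -> integer products rescaled by `scaling`.
# Not intgemm; weights are kept as plain rows (one row per output column).


def max_abs(values):
    """Largest absolute value in a flat list."""
    return max(abs(v) for v in values)


def quantize(values, maxabs):
    """Map maxabs to 127, round to nearest, and clamp to [-127, 127] (no -128)."""
    return [max(-127, min(127, round(v * 127.0 / maxabs))) for v in values]


def int8_fully_connected(data_q, weight_q, scaling, bias=None):
    """Compute data * weight^T with integer accumulation, then scale and add bias."""
    out = []
    for row in data_q:
        out_row = []
        for j, w_row in enumerate(weight_q):
            acc = sum(a * b for a, b in zip(row, w_row))  # exact integer dot product
            value = acc * scaling  # back to float units
            if bias is not None:
                value += bias[j]  # bias is added after scaling
            out_row.append(value)
        out.append(out_row)
    return out


data = [[0.5, -1.0, 0.25, 2.0]]
weight = [[1.0, 0.5, -0.5, 0.0], [0.2, -0.4, 0.8, 1.0]]

d_max = max_abs([v for row in data for v in row])
w_max = max_abs([v for row in weight for v in row])
data_q = [quantize(row, d_max) for row in data]
weight_q = [quantize(row, w_max) for row in weight]

# int8 data: scaling = (data max / 127) * (weight max / 127)
scaling = (d_max / 127.0) * (w_max / 127.0)
result = int8_fully_connected(data_q, weight_q, scaling)
print(result)  # approximately the float result [[-0.125, 2.7]]
```

The small gap to the exact float product is the quantization error; it shrinks as the value distribution fills more of the [-127, 127] range.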

source

# MXNet.mx._contrib_intgemm_maxabsoluteMethod.

_contrib_intgemm_maxabsolute(data)

Quickly computes the maximum absolute value of a float32 tensor on a CPU. The tensor's total size must be a multiple of 16, and its data must be aligned to a multiple of 64 bytes. In other words, mxnet.nd.contrib.intgemm_maxabsolute(arr) == arr.abs().max().

Defined in src/operator/contrib/intgemm/max_absolute_op.cc:L101

Arguments

  • data::NDArray-or-SymbolicNode: Tensor to compute maximum absolute value of

source

# MXNet.mx._contrib_intgemm_prepare_dataMethod.

_contrib_intgemm_prepare_data(data, maxabs)

This operator quantizes float32 values to int8 while also banning -128.

It is suitable for preparing a data matrix for use by intgemm's C = data * weights operation.

The float32 values are scaled such that maxabs maps to 127. Typically maxabs = maxabsolute(A), where A is the data matrix.

Defined in src/operator/contrib/intgemm/prepare_data_op.cc:L112

Arguments

  • data::NDArray-or-SymbolicNode: Activation matrix to be prepared for multiplication.
  • maxabs::NDArray-or-SymbolicNode: Maximum absolute value to be used for scaling. The values will be multiplied by 127.0 / maxabs.
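The scaling rule can be sketched in plain Python. This is an illustration only: it assumes round-to-nearest and ignores the size and alignment requirements and the SIMD code paths of the real operators. max_absolute and prepare_data are hypothetical names.

```python
def max_absolute(values):
    """max(|x|) over all elements, the quantity intgemm_maxabsolute returns."""
    return max(abs(v) for v in values)


def prepare_data(values, maxabs):
    """Quantize float values to int8 so that maxabs maps to 127.

    Values are multiplied by 127.0 / maxabs, rounded (assumed: to nearest),
    and clamped to [-127, 127], so -128 is never produced.
    """
    scale = 127.0 / maxabs
    return [max(-127, min(127, round(v * scale))) for v in values]


a = [0.1, -2.0, 1.0, 3.5]
m = max_absolute(a)
q = prepare_data(a, m)
print(m, q)  # 3.5 [4, -73, 36, 127]

# A maxabs smaller than the data saturates at -127 instead of reaching -128.
print(prepare_data([-5.0], 3.5))  # [-127]
```

Keeping the range symmetric at [-127, 127] is what "banning -128" means here.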

source

# MXNet.mx._contrib_intgemm_prepare_weightMethod.

_contrib_intgemm_prepare_weight(weight, maxabs, already_quantized)

This operator converts a weight matrix in column-major format to intgemm's internal fast representation of weight matrices. MXNet customarily stores weight matrices in column-major (transposed) format. This operator is not meant to be fast; it is meant to be run offline to quantize a model.

In other words, it prepares weight for the operation C = data * weight^T.

If the provided weight matrix is float32, it will be quantized first. The quantization function is (int8_t)(127.0 / maxabs * weight), where maxabs is provided as argument 1 (the weight matrix is argument 0). Then the matrix will be rearranged into the CPU-dependent format.

If the provided weight matrix is already int8, it will only be rearranged into the CPU-dependent format. This way one can quantize with intgemm_prepare_data (which only quantizes), store the result to disk in a consistent format, and then convert it to the CPU-dependent format at load time with intgemm_prepare_weight.

The internal representation depends on register length. So AVX512, AVX2, and SSSE3 have different formats. AVX512BW and AVX512VNNI have the same representation.

Defined in src/operator/contrib/intgemm/prepare_weight_op.cc:L153

Arguments

  • weight::NDArray-or-SymbolicNode: Parameter matrix to be prepared for multiplication.
  • maxabs::NDArray-or-SymbolicNode: Maximum absolute value for scaling. The weights will be multiplied by 127.0 / maxabs.
  • already_quantized::boolean, optional, default=0: Is the weight matrix already quantized?

source

# MXNet.mx._contrib_intgemm_take_weightMethod.

_contrib_intgemm_take_weight(weight, indices)

Index a weight matrix stored in intgemm's weight format. The indices select the outputs of matrix multiplication, not the inner dot product dimension.

Defined in src/operator/contrib/intgemm/take_weight_op.cc:L128

Arguments

  • weight::NDArray-or-SymbolicNode: Tensor already in intgemm weight format to select from
  • indices::NDArray-or-SymbolicNode: indices to select on the 0th dimension of weight

source

# MXNet.mx._contrib_quadraticMethod.

_contrib_quadratic(data, a, b, c)

This operator implements the quadratic function

f(x) = a*x^2 + b*x + c,

where x is an input tensor and all operations in the function are element-wise.

Example:

    x = [[1, 2], [3, 4]]
    y = quadratic(data=x, a=1, b=2, c=3)
    y = [[6, 11], [18, 27]]

The storage type of the quadratic output depends on the storage types of the inputs:

  • quadratic(csr, a, b, 0) = csr
  • quadratic(default, a, b, c) = default

Defined in src/operator/contrib/quadratic_op.cc:L50

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • a::float, optional, default=0: Coefficient of the quadratic term in the quadratic function.
  • b::float, optional, default=0: Coefficient of the linear term in the quadratic function.
  • c::float, optional, default=0: Constant term in the quadratic function.
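As a quick check of the example above, here is an element-wise plain-Python version of f(x) on nested lists. quadratic_ref is a hypothetical reference helper, not the MXNet operator. It also shows why the CSR rule requires c = 0: only then does f(0) = 0, so the implicit zeros of a sparse input stay zero.

```python
def quadratic_ref(x, a=0.0, b=0.0, c=0.0):
    """Apply f(v) = a*v**2 + b*v + c to every element of a 2-D nested list."""
    return [[a * v * v + b * v + c for v in row] for row in x]


y = quadratic_ref([[1, 2], [3, 4]], a=1, b=2, c=3)
print(y)  # [[6, 11], [18, 27]]

# With c = 0, zeros map to zeros, so a sparse input can keep its sparsity.
print(quadratic_ref([[0, 0]], a=1, b=2, c=0))  # [[0, 0]]
```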

source

# MXNet.mx._contrib_quantizeMethod.

_contrib_quantize(data, min_range, max_range, out_type)

Quantizes an input tensor from float to out_type, with user-specified min_range and max_range.

min_range and max_range are scalar floats that specify the range for the input data.

When out_type is uint8, the output is calculated using the following equation:

out[i] = (in[i] - min_range) * range(OUTPUT_TYPE) / (max_range - min_range) + 0.5,

where range(T) = numeric_limits<T>::max() - numeric_limits<T>::min().

When out_type is int8, the output is calculated using the following equation, which keeps the quantized values centered at zero:

out[i] = sign(in[i]) * min(abs(in[i]) * scale + 0.5f, quantized_range),

where quantized_range = MinAbs(max(int8), min(int8)) and scale = quantized_range / MaxAbs(min_range, max_range).

Note: This operator only supports forward propagation. DO NOT use it in training.

Defined in src/operator/quantization/quantize.cc:L73

Arguments

  • data::NDArray-or-SymbolicNode: An ndarray/symbol of type float32
  • min_range::NDArray-or-SymbolicNode: The minimum scalar value possibly produced for the input
  • max_range::NDArray-or-SymbolicNode: The maximum scalar value possibly produced for the input
  • out_type::{'int8', 'uint8'},optional, default='uint8': Output data type.
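The two equations above can be transcribed into plain Python to see how a symmetric input range maps onto each output type. This sketch assumes that the final float-to-integer conversion truncates toward zero like a C cast, and it does not handle inputs outside [min_range, max_range]. The helper names are hypothetical.

```python
def _sign(v):
    """Return -1, 0 or 1."""
    return (v > 0) - (v < 0)


def quantize_uint8(x, min_range, max_range):
    """out = (in - min_range) * 255 / (max_range - min_range) + 0.5, then truncated."""
    q_range = 255  # numeric_limits<uint8>::max() - numeric_limits<uint8>::min()
    return [int((v - min_range) * q_range / (max_range - min_range) + 0.5) for v in x]


def quantize_int8(x, min_range, max_range):
    """Zero-centred: sign(in) * min(|in| * scale + 0.5, 127), then truncated."""
    quantized_range = 127  # MinAbs(max(int8), min(int8))
    scale = quantized_range / max(abs(min_range), abs(max_range))
    return [int(_sign(v) * min(abs(v) * scale + 0.5, quantized_range)) for v in x]


x = [-1.0, -0.5, 0.0, 0.5, 1.0]
print(quantize_uint8(x, -1.0, 1.0))  # [0, 64, 128, 191, 255]
print(quantize_int8(x, -1.0, 1.0))   # [-127, -64, 0, 64, 127]
```

Note that 0.0 lands on 128 in the uint8 encoding but on exactly 0 in the zero-centred int8 encoding.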

source

# MXNet.mx._contrib_quantize_asymMethod.

_contrib_quantize_asym(data, min_calib_range, max_calib_range)

Quantizes an input tensor from float to uint8_t. The output scale and shift are scalar floats that specify the quantization parameters for the input data. The output is calculated using the following equation:

out[i] = in[i] * scale + shift + 0.5,

where scale = uint8_range / (max_range - min_range) and shift = numeric_limits<T>::max - max_range * scale.

Note: This operator only supports forward propagation. DO NOT use it in training.

Defined in src/operator/quantization/quantize_asym.cc:L115

Arguments

  • data::NDArray-or-SymbolicNode: An ndarray/symbol of type float32
  • min_calib_range::float or None, optional, default=None: The minimum scalar value in the form of float32. If present, it will be used to quantize the fp32 data.
  • max_calib_range::float or None, optional, default=None: The maximum scalar value in the form of float32. If present, it will be used to quantize the fp32 data.
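A plain-Python sketch of the asymmetric equation, assuming that min_range and max_range in the formula are the calibrated bounds and that the result is truncated like a C cast. quantize_asym is a hypothetical helper, not the operator.

```python
def quantize_asym(x, min_range, max_range):
    """Return (quantized values, scale, shift) following the uint8 equation above."""
    uint8_range = 255.0  # numeric_limits<uint8_t>::max() - min()
    scale = uint8_range / (max_range - min_range)
    shift = 255 - max_range * scale  # numeric_limits<uint8_t>::max - max_range * scale
    out = [int(v * scale + shift + 0.5) for v in x]  # truncation, as in a C cast
    return out, scale, shift


q, scale, shift = quantize_asym([-2.0, 0.0, 2.0, 6.0], -2.0, 6.0)
print(q, scale, shift)  # [0, 64, 128, 255] 31.875 63.75
```

Since shift works out to -min_range * scale, min_range maps to 0 and max_range to 255, which uses the whole uint8 range even when the data is not centred at zero.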

source

# MXNet.mx._contrib_quantize_v2Method.

_contrib_quantize_v2(data, out_type, min_calib_range, max_calib_range)

Quantizes an input tensor from float to out_type, with user-specified min_calib_range and max_calib_range or the input range collected at runtime.

Output min_range and max_range are scalar floats that specify the range for the input data.

When out_type is uint8, the output is calculated using the following equation:

out[i] = (in[i] - min_range) * range(OUTPUT_TYPE) / (max_range - min_range) + 0.5,

where range(T) = numeric_limits<T>::max() - numeric_limits<T>::min().

When out_type is int8, the output is calculated using the following equation, which keeps the quantized values centered at zero:

out[i] = sign(in[i]) * min(abs(in[i]) * scale + 0.5f, quantized_range),

where quantized_range = MinAbs(max(int8), min(int8)) and scale = quantized_range / MaxAbs(min_range, max_range).

When out_type is auto, the output type is determined automatically from min_calib_range, if present: if min_calib_range < 0.0f, the output type is int8; otherwise it is uint8. If min_calib_range is not present, the output type is int8.

Note: This operator only supports forward propagation. DO NOT use it in training.

Defined in src/operator/quantization/quantize_v2.cc:L90

Arguments

  • data::NDArray-or-SymbolicNode: An ndarray/symbol of type float32
  • out_type::{'auto', 'int8', 'uint8'},optional, default='int8': Output data type. auto can be specified to automatically determine the output type according to min_calib_range.
  • min_calib_range::float or None, optional, default=None: The minimum scalar value in the form of float32. If present, it will be used to quantize the fp32 data into int8 or uint8.
  • max_calib_range::float or None, optional, default=None: The maximum scalar value in the form of float32. If present, it will be used to quantize the fp32 data into int8 or uint8.
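The selection rule for out_type='auto' is small enough to state as code. resolve_out_type is a hypothetical helper that mirrors the prose above; it is not an MXNet function.

```python
def resolve_out_type(out_type="int8", min_calib_range=None):
    """Pick the concrete output type following the rules described above."""
    if out_type != "auto":
        return out_type  # explicit 'int8' or 'uint8' is used as-is
    if min_calib_range is None:
        return "int8"  # no calibration data: fall back to int8
    # Negative values need a signed type; non-negative data fits uint8.
    return "int8" if min_calib_range < 0.0 else "uint8"


print(resolve_out_type("auto", -0.3))  # int8
print(resolve_out_type("auto", 0.0))   # uint8
print(resolve_out_type("auto"))        # int8
```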

source

# MXNet.mx._contrib_quantized_actMethod.

_contrib_quantized_act(data, min_data, max_data, act_type)

Activation operator for input and output data of type int8. The input and output data come with min and max thresholds for quantizing the float32 data into int8.

Note: This operator only supports forward propagation. DO NOT use it in training. This operator only supports relu.

Defined in src/operator/quantization/quantized_activation.cc:L90

Arguments

  • data::NDArray-or-SymbolicNode: Input data.
  • min_data::NDArray-or-SymbolicNode: Minimum value of data.
  • max_data::NDArray-or-SymbolicNode: Maximum value of data.
  • act_type::{'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required: Activation function to be applied.

source

# MXNet.mx._contrib_quantized_batch_normMethod.

_contrib_quantized_batch_norm(data, gamma, beta, moving_mean, moving_var, min_data, max_data, eps, momentum, fix_gamma, use_global_stats, output_mean_var, axis, cudnn_off, min_calib_range, max_calib_range)

BatchNorm operator for input and output data of type int8. The input and output data come with min and max thresholds for quantizing the float32 data into int8.

Note: This operator only supports forward propagation. DO NOT use it in training.

Defined in src/operator/quantization/quantized_batch_norm.cc:L94

Arguments

  • data::NDArray-or-SymbolicNode: Input data.
  • gamma::NDArray-or-SymbolicNode: gamma.
  • beta::NDArray-or-SymbolicNode: beta.
  • moving_mean::NDArray-or-SymbolicNode: moving_mean.
  • moving_var::NDArray-or-SymbolicNode: moving_var.
  • min_data::NDArray-or-SymbolicNode: Minimum value of data.
  • max_data::NDArray-or-SymbolicNode: Maximum value of data.
  • eps::double, optional, default=0.0010000000474974513: Epsilon to prevent division by 0. Must be no less than CUDNN_BN_MIN_EPSILON defined in cudnn.h when using cudnn (usually 1e-5)
  • momentum::float, optional, default=0.899999976: Momentum for moving average
  • fix_gamma::boolean, optional, default=1: Fix gamma while training
  • use_global_stats::boolean, optional, default=0: Whether to use global moving statistics instead of local batch-norm. This turns batch-norm into a scale-and-shift operator.
  • output_mean_var::boolean, optional, default=0: Output the mean and inverse std
  • axis::int, optional, default='1': The axis that holds the channel dimension
  • cudnn_off::boolean, optional, default=0: Do not select CUDNN operator, if available
  • min_calib_range::float or None, optional, default=None: The minimum scalar value in the form of float32 obtained through calibration. If present, it will be used by the quantized batch norm op to calculate the primitive scale. Note: this calib_range is used to calibrate the batch norm output.
  • max_calib_range::float or None, optional, default=None: The maximum scalar value in the form of float32 obtained through calibration. If present, it will be used by the quantized batch norm op to calculate the primitive scale. Note: this calib_range is used to calibrate the batch norm output.

source

# MXNet.mx._contrib_quantized_concatMethod.

_contrib_quantized_concat(data, num_args, dim)

Note: _contrib_quantized_concat takes a variable number of positional inputs. So instead of calling it as _contrib_quantized_concat([x, y, z], num_args=3), one should call it as _contrib_quantized_concat(x, y, z), and num_args will be determined automatically.

Joins input arrays along a given axis.

The dimensions of the input arrays should be the same except along the axis on which they will be concatenated. The dimension of the output array along the concatenated axis will be equal to the sum of the corresponding dimensions of the input arrays. Inputs with different min/max values will be rescaled using the largest [min, max] pair. If any input holds int8, the output will be int8; otherwise, the output will be uint8.

Defined in src/operator/quantization/quantized_concat.cc:L107

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to concatenate
  • num_args::int, required: Number of inputs to be concatenated.
  • dim::int, optional, default='1': The dimension along which to concatenate.

source

# MXNet.mx._contrib_quantized_convMethod.

_contrib_quantized_conv(data, weight, bias, min_data, max_data, min_weight, max_weight, min_bias, max_bias, kernel, stride, dilate, pad, num_filter, num_group, workspace, no_bias, cudnn_tune, cudnn_off, layout)

Convolution operator for input, weight and bias of data type int8, accumulating in type int32 for the output. For each argument, two more arguments of type float32 must be provided, representing the thresholds for quantizing the argument from float32 to int8. The final outputs contain the convolution result in int32, and min and max thresholds representing the thresholds for quantizing the float32 output into int32.

Note: This operator only supports forward propagation. DO NOT use it in training.

Defined in src/operator/quantization/quantized_conv.cc:L187

Arguments

  • data::NDArray-or-SymbolicNode: Input data.
  • weight::NDArray-or-SymbolicNode: weight.
  • bias::NDArray-or-SymbolicNode: bias.
  • min_data::NDArray-or-SymbolicNode: Minimum value of data.
  • max_data::NDArray-or-SymbolicNode: Maximum value of data.
  • min_weight::NDArray-or-SymbolicNode: Minimum value of weight.
  • max_weight::NDArray-or-SymbolicNode: Maximum value of weight.
  • min_bias::NDArray-or-SymbolicNode: Minimum value of bias.
  • max_bias::NDArray-or-SymbolicNode: Maximum value of bias.
  • kernel::Shape(tuple), required: Convolution kernel size: (w,), (h, w) or (d, h, w)
  • stride::Shape(tuple), optional, default=[]: Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • dilate::Shape(tuple), optional, default=[]: Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • pad::Shape(tuple), optional, default=[]: Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding.
  • num_filter::int (non-negative), required: Convolution filter(channel) number
  • num_group::int (non-negative), optional, default=1: Number of group partitions.
  • workspace::long (non-negative), optional, default=1024: Maximum temporary workspace allowed (MB) in convolution. This parameter has two usages. When CUDNN is not used, it determines the effective batch size of the convolution kernel. When CUDNN is used, it controls the maximum temporary storage used for tuning the best CUDNN kernel when the limited_workspace strategy is used.
  • no_bias::boolean, optional, default=0: Whether to disable bias parameter.
  • cudnn_tune::{None, 'fastest', 'limited_workspace', 'off'},optional, default='None': Whether to pick convolution algo by running performance test.
  • cudnn_off::boolean, optional, default=0: Turn off cudnn for this layer.
  • layout::{None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'},optional, default='None': Set layout for input, output and weight. Empty for default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d. NHWC and NDHWC are only supported on GPU.

source

# MXNet.mx._contrib_quantized_elemwise_addMethod.

_contrib_quantized_elemwise_add(lhs, rhs, lhs_min, lhs_max, rhs_min, rhs_max)

elemwise_add operator for two int8 inputs (lhs and rhs), accumulating in type int32 for the output. For each argument, two more arguments of type float32 must be provided, representing the thresholds for quantizing the argument from float32 to int8. The final outputs contain the result in int32, and min and max thresholds representing the thresholds for quantizing the float32 output into int32.

Note: This operator only supports forward propagation. DO NOT use it in training.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • lhs_min::NDArray-or-SymbolicNode: Minimum value of first input.
  • lhs_max::NDArray-or-SymbolicNode: Maximum value of first input.
  • rhs_min::NDArray-or-SymbolicNode: Minimum value of second input.
  • rhs_max::NDArray-or-SymbolicNode: Maximum value of second input.

source

# MXNet.mx._contrib_quantized_elemwise_mulMethod.

_contrib_quantized_elemwise_mul(lhs, rhs, lhs_min, lhs_max, rhs_min, rhs_max, min_calib_range, max_calib_range, enable_float_output)

Multiplies int8 arguments element-wise.

Defined in src/operator/quantization/quantized_elemwise_mul.cc:L221

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input
  • lhs_min::NDArray-or-SymbolicNode: Minimum value of first input.
  • lhs_max::NDArray-or-SymbolicNode: Maximum value of first input.
  • rhs_min::NDArray-or-SymbolicNode: Minimum value of second input.
  • rhs_max::NDArray-or-SymbolicNode: Maximum value of second input.
  • min_calib_range::float or None, optional, default=None: The minimum scalar value in the form of float32 obtained through calibration. If present, it will be used to requantize the int8 output data.
  • max_calib_range::float or None, optional, default=None: The maximum scalar value in the form of float32 obtained through calibration. If present, it will be used to requantize the int8 output data.
  • enable_float_output::boolean, optional, default=0: Whether to enable float32 output

source

# MXNet.mx._contrib_quantized_embeddingMethod.

_contrib_quantized_embedding(data, weight, min_weight, max_weight, input_dim, output_dim, dtype, sparse_grad)

Maps integer indices to int8 vector representations (embeddings).

Defined in src/operator/quantization/quantized_indexing_op.cc:L133

Arguments

  • data::NDArray-or-SymbolicNode: The input array to the embedding operator.
  • weight::NDArray-or-SymbolicNode: The embedding weight matrix.
  • min_weight::NDArray-or-SymbolicNode: Minimum value of weight.
  • max_weight::NDArray-or-SymbolicNode: Maximum value of weight.
  • input_dim::int, required: Vocabulary size of the input indices.
  • output_dim::int, required: Dimension of the embedding vectors.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'},optional, default='float32': Data type of weight.
  • sparse_grad::boolean, optional, default=0: Compute row sparse gradient in the backward calculation. If set to True, the grad's storage type is row_sparse.

source

# MXNet.mx._contrib_quantized_flattenMethod.

_contrib_quantized_flatten(data, min_data, max_data)

Arguments

  • data::NDArray-or-SymbolicNode: An ndarray/symbol of type float32
  • min_data::NDArray-or-SymbolicNode: The minimum scalar value possibly produced for the data
  • max_data::NDArray-or-SymbolicNode: The maximum scalar value possibly produced for the data

source

# MXNet.mx._contrib_quantized_fully_connectedMethod.

_contrib_quantized_fully_connected(data, weight, bias, min_data, max_data, min_weight, max_weight, min_bias, max_bias, num_hidden, no_bias, flatten)

Fully Connected operator for input, weight and bias of data type int8, accumulating in type int32 for the output. For each argument, two more arguments of type float32 must be provided, representing the thresholds for quantizing the argument from float32 to int8. The final outputs contain the fully connected result in int32, and min and max thresholds representing the thresholds for quantizing the float32 output into int32.

Note: This operator only supports forward propagation. DO NOT use it in training.

Defined in src/operator/quantization/quantized_fully_connected.cc:L312

Arguments

  • data::NDArray-or-SymbolicNode: Input data.
  • weight::NDArray-or-SymbolicNode: weight.
  • bias::NDArray-or-SymbolicNode: bias.
  • min_data::NDArray-or-SymbolicNode: Minimum value of data.
  • max_data::NDArray-or-SymbolicNode: Maximum value of data.
  • min_weight::NDArray-or-SymbolicNode: Minimum value of weight.
  • max_weight::NDArray-or-SymbolicNode: Maximum value of weight.
  • min_bias::NDArray-or-SymbolicNode: Minimum value of bias.
  • max_bias::NDArray-or-SymbolicNode: Maximum value of bias.
  • num_hidden::int, required: Number of hidden nodes of the output.
  • no_bias::boolean, optional, default=0: Whether to disable bias parameter.
  • flatten::boolean, optional, default=1: Whether to collapse all but the first axis of the input data tensor.

source

# MXNet.mx._contrib_quantized_poolingMethod.

_contrib_quantized_pooling(data, min_data, max_data, kernel, pool_type, global_pool, cudnn_off, pooling_convention, stride, pad, p_value, count_include_pad, layout)

Pooling operator for input and output data of type int8. The input and output data come with min and max thresholds for quantizing the float32 data into int8.

Note: This operator only supports forward propagation. DO NOT use it in training. This operator only supports pool_type of avg or max.

Defined in src/operator/quantization/quantized_pooling.cc:L186

Arguments

  • data::NDArray-or-SymbolicNode: Input data.
  • min_data::NDArray-or-SymbolicNode: Minimum value of data.
  • max_data::NDArray-or-SymbolicNode: Maximum value of data.
  • kernel::Shape(tuple), optional, default=[]: Pooling kernel size: (y, x) or (d, y, x)
  • pool_type::{'avg', 'lp', 'max', 'sum'},optional, default='max': Pooling type to be applied.
  • global_pool::boolean, optional, default=0: Ignore kernel size, do global pooling based on current input feature map.
  • cudnn_off::boolean, optional, default=0: Turn off cudnn pooling and use MXNet pooling operator.
  • pooling_convention::{'full', 'same', 'valid'},optional, default='valid': Pooling convention to be applied.
  • stride::Shape(tuple), optional, default=[]: Stride: for pooling (y, x) or (d, y, x). Defaults to 1 for each dimension.
  • pad::Shape(tuple), optional, default=[]: Pad for pooling: (y, x) or (d, y, x). Defaults to no padding.
  • p_value::int or None, optional, default='None': Value of p for Lp pooling, can be 1 or 2, required for Lp Pooling.
  • count_include_pad::boolean or None, optional, default=None: Only used for AvgPool; specifies whether to count padding elements in the average calculation. For example, with a 5x5 kernel on a 3x3 corner of an image, the sum of the 9 valid elements will be divided by 25 if this is set to true, or by 9 if it is set to false. Defaults to true.
  • layout::{None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC', 'NWC'},optional, default='None': Set layout for input and output. Empty for default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.
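The 5x5-kernel example in the count_include_pad description can be checked with a little arithmetic. The plain-Python helper below is hypothetical and only reproduces the divisor choice for a single pooling window; it is not the quantized pooling kernel.

```python
def corner_average(window, kernel_h, kernel_w, count_include_pad=True):
    """Average of the valid elements under a kernel that overlaps zero padding.

    `window` holds only the valid (non-padding) elements; padded positions count as 0.
    """
    total = sum(sum(row) for row in window)
    valid = sum(len(row) for row in window)
    divisor = kernel_h * kernel_w if count_include_pad else valid
    return total / divisor


corner = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]  # 3x3 valid corner
print(corner_average(corner, 5, 5, count_include_pad=True))   # 0.36
print(corner_average(corner, 5, 5, count_include_pad=False))  # 1.0
```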

source

# MXNet.mx._contrib_quantized_rnnMethod.

_contrib_quantized_rnn(data, parameters, state, state_cell, data_scale, data_shift, state_size, num_layers, bidirectional, mode, p, state_outputs, projection_size, lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan, use_sequence_length)

RNN operator for input data of type uint8. The weight of each gate is converted to int8, while the bias is accumulated in type float32. The hidden state and cell state are in type float32. For the input data, two more arguments of type float32 must be provided, representing the thresholds for quantizing the argument from float32 to uint8. The final outputs contain the recurrent result in float32. Quantization is only supported for vanilla LSTM networks.

Note: This operator only supports forward propagation. DO NOT use it in training.

Defined in src/operator/quantization/quantized_rnn.cc:L298

Arguments

  • data::NDArray-or-SymbolicNode: Input data.
  • parameters::NDArray-or-SymbolicNode: weight.
  • state::NDArray-or-SymbolicNode: initial hidden state of the RNN
  • state_cell::NDArray-or-SymbolicNode: initial cell state for LSTM networks (only for LSTM)
  • data_scale::NDArray-or-SymbolicNode: quantization scale of data.
  • data_shift::NDArray-or-SymbolicNode: quantization shift of data.
  • state_size::int (non-negative), required: size of the state for each layer
  • num_layers::int (non-negative), required: number of stacked layers
  • bidirectional::boolean, optional, default=0: whether to use bidirectional recurrent layers
  • mode::{'gru', 'lstm', 'rnn_relu', 'rnn_tanh'}, required: the type of RNN to compute
  • p::float, optional, default=0: drop rate of the dropout on the outputs of each RNN layer, except the last layer.
  • state_outputs::boolean, optional, default=0: Whether to have the states as symbol outputs.
  • projection_size::int or None, optional, default='None': Size of the projection.
  • lstm_state_clip_min::double or None, optional, default=None: Minimum clip value of LSTM states. This option must be used together with lstm_state_clip_max.
  • lstm_state_clip_max::double or None, optional, default=None: Maximum clip value of LSTM states. This option must be used together with lstm_state_clip_min.
  • lstm_state_clip_nan::boolean, optional, default=0: Whether to stop NaN from propagating in state by clipping it to min/max. If clipping range is not specified, this option is ignored.
  • use_sequence_length::boolean, optional, default=0: If set to true, this layer takes in an extra input parameter sequence_length to specify variable length sequence

source

# MXNet.mx._contrib_requantizeMethod.

_contrib_requantize(data, min_range, max_range, out_type, min_calib_range, max_calib_range)

Given data that is quantized in int32 and the corresponding thresholds, requantizes the data into int8 using min and max thresholds either calculated at runtime or obtained from calibration. Pre-calculating the min and max thresholds through calibration is highly recommended, since it saves runtime in the operator and improves inference accuracy.

Note: This operator only supports forward propagation. DO NOT use it in training.

Defined in src/operator/quantization/requantize.cc:L59

Arguments

  • data::NDArray-or-SymbolicNode: A ndarray/symbol of type int32
  • min_range::NDArray-or-SymbolicNode: The original minimum scalar value in the form of float32 used for quantizing data into int32.
  • max_range::NDArray-or-SymbolicNode: The original maximum scalar value in the form of float32 used for quantizing data into int32.
  • out_type::{'auto', 'int8', 'uint8'},optional, default='int8': Output data type. auto can be specified to automatically determine the output type according to min_calib_range.
  • min_calib_range::float or None, optional, default=None: The minimum scalar value in the form of float32 obtained through calibration. If present, it will be used to requantize the int32 data into int8.
  • max_calib_range::float or None, optional, default=None: The maximum scalar value in the form of float32 obtained through calibration. If present, it will be used to requantize the int32 data into int8.
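One way to picture requantization is the plain-Python sketch below. It assumes the operator reuses the zero-centred conversion described for the quantize operators: int32 values are first mapped back to float, with MaxAbs(min_range, max_range) corresponding to the int32 maximum, and then quantized to int8 with the calibrated range. The helper is hypothetical, not MXNet's kernel.

```python
INT32_MAX = 2147483647  # MinAbs(max(int32), min(int32))


def _sign(v):
    """Return -1, 0 or 1."""
    return (v > 0) - (v < 0)


def requantize_int32_to_int8(q32, min_range, max_range, min_calib, max_calib):
    """Sketch: int32 -> float via the original range, then float -> int8 via calibration."""
    # 1. Dequantize: MaxAbs(min_range, max_range) corresponds to INT32_MAX.
    real_in = max(abs(min_range), abs(max_range))
    floats = [q * real_in / INT32_MAX for q in q32]
    # 2. Re-quantize zero-centred into int8 using the calibrated range.
    real_out = max(abs(min_calib), abs(max_calib))
    scale = 127 / real_out
    return [int(_sign(f) * min(abs(f) * scale + 0.5, 127)) for f in floats]


q32 = [0, 268435456, -268435456, INT32_MAX]  # about 0, 1, -1 and 8 in float units
print(requantize_int32_to_int8(q32, -8.0, 8.0, -2.0, 2.0))  # [0, 64, -64, 127]
```

The last value lies outside the calibrated range [-2, 2] and saturates at 127; that is the trade-off calibration makes in exchange for finer resolution inside the range.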

source

# MXNet.mx._contrib_round_steMethod.

_contrib_round_ste(data)

Straight-through-estimator of round().

In the forward pass, it returns the input rounded element-wise to the nearest integer (same as round()).

In the backward pass, it returns gradients of 1 everywhere (instead of 0 everywhere as in round()): d/dx round_ste(x) = 1 vs. d/dx round(x) = 0. This is useful for quantized training.

Reference: Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation.

Example:

    x = round_ste([-1.5, 1.5, -1.9, 1.9, 2.7])
    x.backward()
    x = [-2., 2., -2., 2., 3.]
    x.grad() = [1., 1., 1., 1., 1.]

The storage type of round_ste output depends upon the input storage type:

  • round_ste(default) = default
  • round_ste(row_sparse) = row_sparse
  • round_ste(csr) = csr

Defined in src/operator/contrib/stes_op.cc:L54

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
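The forward/backward split of a straight-through estimator can be written out by hand. This plain-Python sketch uses hypothetical helpers and no autograd, and it reproduces the example above. Rounding ties away from zero is an assumption chosen to match the -1.5 -> -2 and 1.5 -> 2 entries.

```python
import math


def round_ste_forward(xs):
    """Forward: round to the nearest integer; ties go away from zero (assumed)."""
    return [math.copysign(math.floor(abs(v) + 0.5), v) for v in xs]


def round_ste_backward(grad_out):
    """Backward: straight-through, d/dx round_ste(x) = 1, so gradients pass unchanged."""
    return list(grad_out)


x = [-1.5, 1.5, -1.9, 1.9, 2.7]
y = round_ste_forward(x)
g = round_ste_backward([1.0] * len(x))
print(y)  # [-2.0, 2.0, -2.0, 2.0, 3.0]
print(g)  # [1.0, 1.0, 1.0, 1.0, 1.0]
```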

source

# MXNet.mx._contrib_sign_steMethod.

_contrib_sign_ste(data)

Straight-through-estimator of sign().

In the forward pass, it returns the element-wise sign of the input (same as sign()).

In the backward pass, it returns gradients of 1 everywhere (instead of 0 everywhere as in sign()): d/dx sign_ste(x) = 1 vs. d/dx sign(x) = 0. This is useful for quantized training.

Reference: Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation.

Example:

    x = sign_ste([-2, 0, 3])
    x.backward()
    x = [-1., 0., 1.]
    x.grad() = [1., 1., 1.]

The storage type of sign_ste output depends upon the input storage type:

  • sign_ste(default) = default
  • sign_ste(row_sparse) = row_sparse
  • sign_ste(csr) = csr

Defined in src/operator/contrib/stes_op.cc:L79

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
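To see why a gradient of 1 is useful, the sketch below pushes an upstream gradient through sign_ste by hand. It is plain Python with hypothetical helper names; with the true derivative of sign(), the gradient reaching x would be zero almost everywhere and nothing upstream would learn.

```python
def sign_ste(xs):
    """Return sign(x) and a backward function that passes gradients straight through."""
    out = [float((v > 0) - (v < 0)) for v in xs]

    def backward(grad_out):
        # Straight-through: d/dx sign_ste(x) = 1, so dL/dx = dL/dy element-wise.
        return list(grad_out)

    return out, backward


x = [-2.0, 0.0, 3.0]
y, backward = sign_ste(x)
# For the loss L = sum(w_i * y_i), dL/dy = w; the STE hands w straight back to x.
w = [0.5, -1.0, 2.0]
grad_x = backward(w)
print(y)       # [-1.0, 0.0, 1.0]
print(grad_x)  # [0.5, -1.0, 2.0]
```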

source

# MXNet.mx._copyMethod.

_copy(data)

Returns a copy of the input.

From: src/operator/tensor/elemwise_unary_op_basic.cc:244

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._copytoMethod.

_copyto(data)

Arguments

  • data::NDArray: input data

source

# MXNet.mx._crop_assignMethod.

_crop_assign(lhs, rhs, begin, end, step)

_crop_assign is an alias of _slice_assign.

Assign the rhs to a cropped subset of lhs.

Requirements

  • output should be explicitly given and be the same as lhs.
  • lhs and rhs are of the same data type, and on the same device.

From:src/operator/tensor/matrix_op.cc:514

Arguments

  • lhs::NDArray-or-SymbolicNode: Source input
  • rhs::NDArray-or-SymbolicNode: value to assign
  • begin::Shape(tuple), required: starting indices for the slice operation, supports negative indices.
  • end::Shape(tuple), required: ending indices for the slice operation, supports negative indices.
  • step::Shape(tuple), optional, default=[]: step for the slice operation, supports negative values.
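The begin/end/step semantics can be illustrated on nested lists. This hypothetical 2-D helper returns a modified copy instead of writing into lhs as the operator requires. It relies on Python's slice rules for negative indices, on the assumption that they behave like the operator's handling.

```python
def slice_assign_2d(lhs, rhs, begin, end, step=(1, 1)):
    """Return a copy of lhs with the slice [begin:end:step] (per axis) replaced by rhs."""
    out = [row[:] for row in lhs]
    rows = range(len(lhs))[begin[0]:end[0]:step[0]]
    cols = range(len(lhs[0]))[begin[1]:end[1]:step[1]]
    # rhs must have exactly the shape of the selected region.
    if len(rhs) != len(rows) or any(len(r) != len(cols) for r in rhs):
        raise ValueError("rhs shape does not match the slice shape")
    for i, r in enumerate(rows):
        for j, c in enumerate(cols):
            out[r][c] = rhs[i][j]
    return out


lhs = [[0] * 4 for _ in range(3)]
print(slice_assign_2d(lhs, [[1, 2], [3, 4]], begin=(0, 1), end=(2, 3)))
# [[0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]
print(slice_assign_2d(lhs, [[9, 9]], begin=(-1, 0), end=(3, 4), step=(1, 2)))
# [[0, 0, 0, 0], [0, 0, 0, 0], [9, 0, 9, 0]]
```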

source

# MXNet.mx._crop_assign_scalarMethod.

_crop_assign_scalar(data, scalar, begin, end, step)

_crop_assign_scalar is an alias of _slice_assign_scalar.

Assign the scalar to a cropped subset of the input.

Requirements

  • output should be explicitly given and be the same as input

From:src/operator/tensor/matrix_op.cc:540

Arguments

  • data::NDArray-or-SymbolicNode: Source input
  • scalar::double, optional, default=0: The scalar value for assignment.
  • begin::Shape(tuple), required: starting indices for the slice operation, supports negative indices.
  • end::Shape(tuple), required: ending indices for the slice operation, supports negative indices.
  • step::Shape(tuple), optional, default=[]: step for the slice operation, supports negative values.

source

# MXNet.mx._cvcopyMakeBorderMethod.

_cvcopyMakeBorder(src, top, bot, left, right, type, value, values)

Pad image border with OpenCV.

Arguments

  • src::NDArray: source image
  • top::int, required: Top margin.
  • bot::int, required: Bottom margin.
  • left::int, required: Left margin.
  • right::int, required: Right margin.
  • type::int, optional, default='0': Filling type (default=cv2.BORDER_CONSTANT).
  • value::double, optional, default=0: (Deprecated! Use values instead.) Fill with a single value.
  • values::tuple of <double>, optional, default=[]: Fill with value(RGB[A] or gray), up to 4 channels.

source

# MXNet.mx._cvimdecodeMethod.

_cvimdecode(buf, flag, to_rgb)

Decode an image with OpenCV. Note: the image is returned in RGB by default, instead of OpenCV's default BGR.

Arguments

  • buf::NDArray: Buffer containing binary encoded image
  • flag::int, optional, default='1': Convert decoded image to grayscale (0) or color (1).
  • to_rgb::boolean, optional, default=1: Whether to convert decoded image to mxnet's default RGB format (instead of opencv's default BGR).

source

# MXNet.mx._cvimreadMethod.

_cvimread(filename, flag, to_rgb)

Read and decode an image with OpenCV. Note: the image is returned in RGB by default, instead of OpenCV's default BGR.

Arguments

  • filename::string, required: Name of the image file to be loaded.
  • flag::int, optional, default='1': Convert decoded image to grayscale (0) or color (1).
  • to_rgb::boolean, optional, default=1: Whether to convert decoded image to mxnet's default RGB format (instead of opencv's default BGR).

source

# MXNet.mx._cvimresizeMethod.

_cvimresize(src, w, h, interp)

Resize image with OpenCV.

Arguments

  • src::NDArray: source image
  • w::int, required: Width of resized image.
  • h::int, required: Height of resized image.
  • interp::int, optional, default='1': Interpolation method (default=cv2.INTER_LINEAR).

source

# MXNet.mx._div_scalarMethod.

_div_scalar(data, scalar, is_int)

Divide an array with a scalar.

_div_scalar only operates on the data array of the input if the input is sparse.

For example, if an input of shape (100, 100) has only 2 non-zero elements, i.e. input.data = [5, 6], and scalar = nan, the result is output.data = [nan, nan] instead of 10000 nans.

Defined in src/operator/tensor/elemwise_binary_scalar_op_basic.cc:L174
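The difference between dense and sparse division in the example above can be sketched in pure Python. This is only an illustration: the helper names are hypothetical, and a dict mapping flat index to value stands in for the data array of a sparse NDArray.

```python
import math


def div_scalar_dense(dense, scalar):
    """Divide every element, including zeros (0 / nan gives nan)."""
    return [x / scalar for x in dense]


def div_scalar_sparse(stored, scalar):
    """Divide only the explicitly stored entries; implicit zeros stay zero.

    `stored` maps flat index -> value, a hypothetical stand-in for the
    data array of a sparse input.
    """
    return {i: v / scalar for i, v in stored.items()}


nan = float("nan")
dense = [0.0] * 10000          # a flattened (100, 100) input
dense[3], dense[42] = 5.0, 6.0
stored = {3: 5.0, 42: 6.0}     # the same input, keeping only non-zeros

dense_out = div_scalar_dense(dense, nan)
sparse_out = div_scalar_sparse(stored, nan)
print(sum(math.isnan(x) for x in dense_out))  # 10000
print(len(sparse_out))                         # 2
```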

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._equalMethod.

_equal(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._equal_scalarMethod.

_equal_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._eyeMethod.

_eye(N, M, k, ctx, dtype)

Return a 2-D array with ones on the diagonal and zeros elsewhere.
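A minimal pure-Python sketch of the same construction, showing how N, M and k interact. The function is a hypothetical stand-in that returns nested lists, not the MXNet operator, and it ignores ctx and dtype.

```python
def eye(N, M=0, k=0):
    """Hypothetical pure-Python analogue of _eye.

    Builds an N x M matrix (M=0 means M=N) with ones on the k-th diagonal:
    k=0 is the main diagonal, k>0 an upper one, k<0 a lower one.
    """
    M = M or N
    return [[1.0 if j - i == k else 0.0 for j in range(M)] for i in range(N)]


print(eye(2))        # [[1.0, 0.0], [0.0, 1.0]]
print(eye(2, 3, 1))  # [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
```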

Arguments

  • N::long, required: Number of rows in the output.
  • M::long, optional, default=0: Number of columns in the output. If 0, defaults to N.
  • k::long, optional, default=0: Index of the diagonal. 0 (the default) refers to the main diagonal, a positive value refers to an upper diagonal, and a negative value to a lower diagonal.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Target data type.

source

# MXNet.mx._foreachMethod.

_foreach(fn, data, num_args, num_outputs, num_out_data, in_state_locs, in_data_locs, remain_locs)

Note: foreach takes a variable number of positional inputs. So instead of calling it as _foreach([x, y, z], num_args=3), one should call foreach(x, y, z), and num_args will be determined automatically.

Run a for loop over an NDArray with user-defined computation.

From:src/operator/control_flow.cc:1090

Arguments

  • fn::SymbolicNode: Input graph.
  • data::NDArray-or-SymbolicNode[]: The input arrays that include data arrays and states.
  • num_args::int, required: Number of inputs.
  • num_outputs::int, required: The number of outputs of the subgraph.
  • num_out_data::int, required: The number of output data of the subgraph.
  • in_state_locs::tuple of <long>, required: The locations of loop states among the inputs.
  • in_data_locs::tuple of <long>, required: The locations of input data among the inputs.
  • remain_locs::tuple of <long>, required: The locations of remaining data among the inputs.

source

# MXNet.mx._get_ndarray_function_defMethod.

The libmxnet APIs are automatically imported from libmxnet.so. The functions listed here operate on NDArray objects. The arguments to the functions are typically ordered as

  func_name(arg_in1, arg_in2, ..., scalar1, scalar2, ..., arg_out1, arg_out2, ...)

if NDARRAY_ARG_BEFORE_SCALAR is set. If it is not set, the scalars are put before the input arguments:

  func_name(scalar1, scalar2, ..., arg_in1, arg_in2, ..., arg_out1, arg_out2, ...)

If ACCEPT_EMPTY_MUTATE_TARGET is set, an overloaded function without the output arguments will also be defined:

  func_name(arg_in1, arg_in2, ..., scalar1, scalar2, ...)

Upon calling, the output arguments will be automatically initialized with empty NDArrays.

Those functions always return the output arguments. If there is only one output (the typical situation), that object (NDArray) is returned. Otherwise, a tuple containing all the outputs will be returned.

source

# MXNet.mx._grad_addMethod.

_grad_add(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._greaterMethod.

_greater(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._greater_equalMethod.

_greater_equal(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._greater_equal_scalarMethod.

_greater_equal_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._greater_scalarMethod.

_greater_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._histogramMethod.

_histogram(data, bins, bin_cnt, range)

This operator implements the histogram function.

Example::

  x = [[0, 1], [2, 2], [3, 4]]
  histo, bin_edges = histogram(data=x, bin_bounds=[], bin_cnt=5, range=(0,5))
  histo = [1, 1, 2, 1, 1]
  bin_edges = [0., 1., 2., 3., 4., 5.]
  histo, bin_edges = histogram(data=x, bin_bounds=[0., 2.1, 3.])
  histo = [4, 1]

Defined in src/operator/tensor/histogram.cc:L137

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • bins::NDArray-or-SymbolicNode: Input ndarray
  • bin_cnt::int or None, optional, default='None': Number of bins for uniform case
  • range::, optional, default=None: The lower and upper range of the bins. If not provided, range is simply (a.min(), a.max()). Values outside the range are ignored. The first element of the range must be less than or equal to the second. range also affects the automatic bin computation: while the bin width is computed to be optimal based on the actual data within range, the bin count will fill the entire range, including portions containing no data.

source

# MXNet.mx._hypotMethod.

_hypot(lhs, rhs)

Given the "legs" of a right triangle, return its hypotenuse.

Defined in src/operator/tensor/elemwise_binary_op_extended.cc:L78

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._hypot_scalarMethod.

_hypot_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._identity_with_attr_like_rhsMethod.

_identity_with_attr_like_rhs(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input.
  • rhs::NDArray-or-SymbolicNode: Second input.

source

# MXNet.mx._image_adjust_lightingMethod.

_image_adjust_lighting(data, alpha)

Adjust the lighting level of the input, following the AlexNet style.

Defined in src/operator/image/image_random.cc:L254

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • alpha::tuple of <float>, required: The lighting alphas for the R, G, B channels.

source

# MXNet.mx._image_cropMethod.

_image_crop(data, x, y, width, height)

Crop an image NDArray of shape (H x W x C) or (N x H x W x C) to the given size.

Example:

.. code-block:: python
    image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8)
    mx.nd.image.crop(image, 1, 1, 2, 2)
        [[[144  34   4]
          [ 82 157  38]]

         [[156 111 230]
          [177  25  15]]]
        <NDArray 2x2x3 @cpu(0)>
    image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8)
    mx.nd.image.crop(image, 1, 1, 2, 2)
        [[[[ 35 198  50]
           [242  94 168]]

          [[223 119 129]
           [249  14 154]]]


         [[[137 215 106]
           [ 79 174 133]]

          [[116 142 109]
           [ 35 239  50]]]]
        <NDArray 2x2x2x3 @cpu(0)>

Defined in src/operator/image/crop.cc:L65
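The cropping convention (x and width index columns, y and height index rows, channels untouched) can be sketched with nested Python lists. The helper names are hypothetical. Python slicing silently truncates a region that runs past the image border, whereas the operator may reject such a region.

```python
def crop_hwc(image, x, y, width, height):
    """Hypothetical sketch: crop an H x W x C nested list.

    x and y are the left and top boundaries of the cropping area.
    """
    return [row[x:x + width] for row in image[y:y + height]]


def crop_nhwc(batch, x, y, width, height):
    """Apply the same crop to every image of an N x H x W x C batch."""
    return [crop_hwc(img, x, y, width, height) for img in batch]


# 3 x 3 single-channel image whose pixel values are 0..8 in row-major order
image = [[[r * 3 + c] for c in range(3)] for r in range(3)]
print(crop_hwc(image, 1, 0, 2, 2))  # [[[1], [2]], [[4], [5]]]
```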

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • x::int, required: Left boundary of the cropping area.
  • y::int, required: Top boundary of the cropping area.
  • width::int, required: Width of the cropping area.
  • height::int, required: Height of the cropping area.

source

# MXNet.mx._image_flip_left_rightMethod.

_image_flip_left_right(data)

Defined in src/operator/image/image_random.cc:L195

Arguments

  • data::NDArray-or-SymbolicNode: The input.

source

# MXNet.mx._image_flip_top_bottomMethod.

_image_flip_top_bottom(data)

Defined in src/operator/image/image_random.cc:L205

Arguments

  • data::NDArray-or-SymbolicNode: The input.

source

# MXNet.mx._image_normalizeMethod.

_image_normalize(data, mean, std)

Normalize a tensor of shape (C x H x W) or (N x C x H x W) with mean and standard deviation.

Given mean (m_1, ..., m_n) and std (s_1, ..., s_n) for n channels,
this transform normalizes each channel of the input tensor with:

    output[i] = (input[i] - m_i) / s_i

If mean or std is scalar, the same value will be applied to all channels.

Default value for mean is 0.0 and standard deviation is 1.0.

Example:

.. code-block:: python
    image = mx.nd.random.uniform(0, 1, (3, 4, 2))
    normalize(image, mean=(0, 1, 2), std=(3, 2, 1))
        [[[ 0.18293785  0.19761486]
          [ 0.23839645  0.28142193]
          [ 0.20092112  0.28598186]
          [ 0.18162774  0.28241724]]
         [[-0.2881726  -0.18821815]
          [-0.17705294 -0.30780914]
          [-0.2812064  -0.3512327 ]
          [-0.05411351 -0.4716435 ]]
         [[-1.0363373  -1.7273437 ]
          [-1.6165586  -1.5223348 ]
          [-1.208275   -1.1878313 ]
          [-1.4711051  -1.5200229 ]]]
        <NDArray 3x4x2 @cpu(0)>

    image = mx.nd.random.uniform(0, 1, (2, 3, 4, 2))
    normalize(image, mean=(0, 1, 2), std=(3, 2, 1))
        [[[[ 0.18934818  0.13092826]
           [ 0.3085322   0.27869293]
           [ 0.02367868  0.11246539]
           [ 0.0290431   0.2160573 ]]
          [[-0.4898908  -0.31587923]
           [-0.08369008 -0.02142242]
           [-0.11092162 -0.42982462]
           [-0.06499392 -0.06495637]]
          [[-1.0213816  -1.526392  ]
           [-1.2008414  -1.1990893 ]
           [-1.5385206  -1.4795225 ]
           [-1.2194707  -1.3211205 ]]]
         [[[ 0.03942481  0.24021089]
           [ 0.21330701  0.1940066 ]
           [ 0.04778443  0.17912441]
           [ 0.31488964  0.25287187]]
          [[-0.23907584 -0.4470462 ]
           [-0.29266903 -0.2631998 ]
           [-0.3677222  -0.40683383]
           [-0.11288315 -0.13154092]]
          [[-1.5438497  -1.7834496 ]
           [-1.431566   -1.8647819 ]
           [-1.9812102  -1.675859  ]
           [-1.3823645  -1.8503251 ]]]]
        <NDArray 2x3x4x2 @cpu(0)>

Defined in src/operator/image/image_random.cc:L167
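The per-channel formula output[i] = (input[i] - m_i) / s_i, including the scalar broadcast rule, can be sketched in pure Python on a C x H x W nested list. The helper is hypothetical, not the MXNet operator.

```python
def normalize_chw(tensor, mean=0.0, std=1.0):
    """Hypothetical sketch: normalize each channel of a C x H x W nested list."""
    channels = len(tensor)
    # A scalar mean or std is applied to every channel.
    if isinstance(mean, (int, float)):
        mean = [mean] * channels
    if isinstance(std, (int, float)):
        std = [std] * channels
    return [[[(v - m) / s for v in row] for row in channel]
            for channel, m, s in zip(tensor, mean, std)]


tensor = [[[3.0]], [[3.0]], [[3.0]]]  # 3 channels, 1 x 1 pixels
print(normalize_chw(tensor, mean=(0, 1, 2), std=(3, 2, 1)))  # [[[1.0]], [[1.0]], [[1.0]]]
```

For an N x C x H x W batch, the same function would be applied to each image in turn.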

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • mean::tuple of <float>, optional, default=[0,0,0,0]: Sequence of means for each channel. Default value is 0.
  • std::tuple of <float>, optional, default=[1,1,1,1]: Sequence of standard deviations for each channel. Default value is 1.

source

# MXNet.mx._image_random_brightnessMethod.

_image_random_brightness(data, min_factor, max_factor)

Defined in src/operator/image/image_random.cc:L215

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • min_factor::float, required: Minimum factor.
  • max_factor::float, required: Maximum factor.

source

# MXNet.mx._image_random_color_jitterMethod.

_image_random_color_jitter(data, brightness, contrast, saturation, hue)

Defined in src/operator/image/image_random.cc:L246

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • brightness::float, required: How much to jitter brightness.
  • contrast::float, required: How much to jitter contrast.
  • saturation::float, required: How much to jitter saturation.
  • hue::float, required: How much to jitter hue.

source

# MXNet.mx._image_random_contrastMethod.

_image_random_contrast(data, min_factor, max_factor)

Defined in src/operator/image/image_random.cc:L222

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • min_factor::float, required: Minimum factor.
  • max_factor::float, required: Maximum factor.

source

# MXNet.mx._image_random_flip_left_rightMethod.

_image_random_flip_left_right(data)

Defined in src/operator/image/image_random.cc:L200

Arguments

  • data::NDArray-or-SymbolicNode: The input.

source

# MXNet.mx._image_random_flip_top_bottomMethod.

_image_random_flip_top_bottom(data)

Defined in src/operator/image/image_random.cc:L210

Arguments

  • data::NDArray-or-SymbolicNode: The input.

source

# MXNet.mx._image_random_hueMethod.

_image_random_hue(data, min_factor, max_factor)

Defined in src/operator/image/image_random.cc:L238

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • min_factor::float, required: Minimum factor.
  • max_factor::float, required: Maximum factor.

source

# MXNet.mx._image_random_lightingMethod.

_image_random_lighting(data, alpha_std)

Randomly add PCA noise, following the AlexNet style.

Defined in src/operator/image/image_random.cc:L262

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • alpha_std::float, optional, default=0.0500000007: Level of the lighting noise.

source

# MXNet.mx._image_random_saturationMethod.

_image_random_saturation(data, min_factor, max_factor)

Defined in src/operator/image/image_random.cc:L230

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • min_factor::float, required: Minimum factor.
  • max_factor::float, required: Maximum factor.

source

# MXNet.mx._image_resizeMethod.

_image_resize(data, size, keep_ratio, interp)

Resize an image NDArray of shape (H x W x C) or (N x H x W x C) to the given size.

Example:

.. code-block:: python
    image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8)
    mx.nd.image.resize(image, (3, 3))
        [[[124 111 197]
          [158  80 155]
          [193  50 112]]

         [[110 100 113]
          [134 165 148]
          [157 231 182]]

         [[202 176 134]
          [174 191 149]
          [147 207 164]]]
        <NDArray 3x3x3 @cpu(0)>
    image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8)
    mx.nd.image.resize(image, (2, 2))
        [[[[ 59 133  80]
           [187 114 153]]

          [[ 38 142  39]
           [207 131 124]]]


         [[[117 125 136]
           [191 166 150]]

          [[129  63 113]
           [182 109  48]]]]
        <NDArray 2x2x2x3 @cpu(0)>

Defined in src/operator/image/resize.cc:L70

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • size::Shape(tuple), optional, default=[]: Size of new image. Could be (width, height) or (size)
  • keep_ratio::boolean, optional, default=0: Whether to resize the short edge or both edges to size, if size is given as an integer.
  • interp::int, optional, default='1': Interpolation method for resizing. By default uses bilinear interpolation. Options are INTER_NEAREST - a nearest-neighbor interpolation; INTER_LINEAR - a bilinear interpolation; INTER_AREA - resampling using pixel area relation; INTER_CUBIC - a bicubic interpolation over a 4x4 pixel neighborhood; INTER_LANCZOS4 - a Lanczos interpolation over an 8x8 pixel neighborhood. Note that the GPU version only supports bilinear interpolation (1).

source

# MXNet.mx._image_to_tensorMethod.

_image_to_tensor(data)

Converts an image NDArray of shape (H x W x C) or (N x H x W x C) with values in the range [0, 255] to a tensor NDArray of shape (C x H x W) or (N x C x H x W) with values in the range [0, 1].

Example:

.. code-block:: python
    image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8)
    to_tensor(image)
        [[[ 0.85490197  0.72156864]
          [ 0.09019608  0.74117649]
          [ 0.61960787  0.92941177]
          [ 0.96470588  0.1882353 ]]
         [[ 0.6156863   0.73725492]
          [ 0.46666667  0.98039216]
          [ 0.44705883  0.45490196]
          [ 0.01960784  0.8509804 ]]
         [[ 0.39607844  0.03137255]
          [ 0.72156864  0.52941179]
          [ 0.16470589  0.7647059 ]
          [ 0.05490196  0.70588237]]]

    image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8)
    to_tensor(image)
        [[[[0.11764706 0.5803922 ]
           [0.9411765  0.10588235]
           [0.2627451  0.73333335]
           [0.5647059  0.32156864]]
          [[0.7176471  0.14117648]
           [0.75686276 0.4117647 ]
           [0.18431373 0.45490196]
           [0.13333334 0.6156863 ]]
          [[0.6392157  0.5372549 ]
           [0.52156866 0.47058824]
           [0.77254903 0.21568628]
           [0.01568628 0.14901961]]]
         [[[0.6117647  0.38431373]
           [0.6784314  0.6117647 ]
           [0.69411767 0.96862745]
           [0.67058825 0.35686275]]
          [[0.21960784 0.9411765 ]
           [0.44705883 0.43529412]
           [0.09803922 0.6666667 ]
           [0.16862746 0.1254902 ]]
          [[0.6156863  0.9019608 ]
           [0.35686275 0.9019608 ]
           [0.05882353 0.6509804 ]
           [0.20784314 0.7490196 ]]]]
        <NDArray 2x3x4x2 @cpu(0)>

Defined in src/operator/image/image_random.cc:L92
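The conversion combines an axis permutation (H x W x C to C x H x W) with a division by 255. A pure-Python sketch of that transformation on nested lists follows; the helper name is hypothetical.

```python
def to_tensor_hwc(image):
    """Hypothetical sketch: H x W x C values in [0, 255] -> C x H x W floats in [0, 1]."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [[[image[h][w][c] / 255.0 for w in range(width)]
             for h in range(height)]
            for c in range(channels)]


image = [[[0, 51, 255], [255, 0, 102]]]  # H=1, W=2, C=3
print(to_tensor_hwc(image))  # [[[0.0, 1.0]], [[0.2, 0.0]], [[1.0, 0.4]]]
```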

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray

source

# MXNet.mx._imdecodeMethod.

_imdecode(mean, index, x0, y0, x1, y1, c, size)

Decode an image, clip it to (x0, y0, x1, y1), subtract the mean, and write the result to the buffer.

Arguments

  • mean::NDArray-or-SymbolicNode: image mean
  • index::int: buffer position for output
  • x0::int: x0
  • y0::int: y0
  • x1::int: x1
  • y1::int: y1
  • c::int: channel
  • size::int: length of str_img

source

# MXNet.mx._lesserMethod.

_lesser(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._lesser_equalMethod.

_lesser_equal(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._lesser_equal_scalarMethod.

_lesser_equal_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._lesser_scalarMethod.

_lesser_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._linalg_detMethod.

_linalg_det(A)

Compute the determinant of a matrix. Input is a tensor A of dimension n >= 2.

If n=2, A is a square matrix. We compute:

out = det(A)

If n>2, det is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

.. note:: No gradient is propagated backward when A is non-invertible (equivalently, det(A) = 0), because exact zeros are rarely hit in floating-point computation and Jacobi's formula for the determinant gradient is not computationally efficient when A is non-invertible.

Examples::

Single matrix determinant
A = [[1., 4.], [2., 3.]]
det(A) = [-5.]

Batch matrix determinant
A = [[[1., 4.], [2., 3.]],
     [[2., 3.], [1., 4.]]]
det(A) = [-5., 5.]

Defined in src/operator/tensor/la_op.cc:L974
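The batch semantics (one determinant per trailing square matrix) can be checked against the examples with a small pure-Python sketch. It uses textbook Gaussian elimination with partial pivoting; this is an assumption for illustration, not necessarily the method MXNet uses internally, and the function names are hypothetical.

```python
def det(a):
    """Determinant of a square matrix by Gaussian elimination with partial pivoting."""
    m = [[float(v) for v in row] for row in a]
    n = len(m)
    result = 1.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if m[pivot][col] == 0.0:
            return 0.0  # singular matrix
        if pivot != col:
            m[col], m[pivot] = m[pivot], m[col]
            result = -result  # a row swap flips the sign
        result *= m[col][col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n):
                m[r][c] -= factor * m[col][c]
    return result


def batch_det(batch):
    """Batch mode: one determinant per matrix along the leading dimension."""
    return [det(a) for a in batch]


print(batch_det([[[1, 4], [2, 3]], [[2, 3], [1, 4]]]))  # [-5.0, 5.0]
```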

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrix

source

# MXNet.mx._linalg_extractdiagMethod.

_linalg_extractdiag(A, offset)

Extracts the diagonal entries of a square matrix. Input is a tensor A of dimension n >= 2.

If n=2, then A represents a single square matrix whose diagonal elements are extracted as a 1-dimensional tensor.

If n>2, then A represents a batch of square matrices on the trailing two dimensions. The extracted diagonals are returned as an n-1-dimensional tensor.

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix diagonal extraction
A = [[1.0, 2.0],
     [3.0, 4.0]]

extractdiag(A) = [1.0, 4.0]

extractdiag(A, 1) = [2.0]

Batch matrix diagonal extraction
A = [[[1.0, 2.0],
      [3.0, 4.0]],
     [[5.0, 6.0],
      [7.0, 8.0]]]

extractdiag(A) = [[1.0, 4.0],
                  [5.0, 8.0]]

Defined in src/operator/tensor/la_op.cc:L494
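The offset convention (positive selects a diagonal above the main one, negative one below) can be sketched in pure Python for a single square matrix. The function is a hypothetical illustration; batch mode simply applies it to each matrix.

```python
def extractdiag(a, offset=0):
    """Hypothetical sketch: diagonal of a square matrix, shifted by offset."""
    n = len(a)
    if offset >= 0:
        return [a[i][i + offset] for i in range(n - offset)]
    return [a[i - offset][i] for i in range(n + offset)]


A = [[1.0, 2.0],
     [3.0, 4.0]]
print(extractdiag(A))      # [1.0, 4.0]
print(extractdiag(A, 1))   # [2.0]
print(extractdiag(A, -1))  # [3.0]
```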

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrices
  • offset::int, optional, default='0': Offset of the diagonal versus the main diagonal. 0 corresponds to the main diagonal, a negative/positive value to diagonals below/above the main diagonal.

source

# MXNet.mx._linalg_extracttrianMethod.

_linalg_extracttrian(A, offset, lower)

Extracts a triangular sub-matrix from a square matrix. Input is a tensor A of dimension n >= 2.

If n=2, then A represents a single square matrix from which a triangular sub-matrix is extracted as a 1-dimensional tensor.

If n>2, then A represents a batch of square matrices on the trailing two dimensions. The extracted triangular sub-matrices are returned as an n-1-dimensional tensor.

The offset and lower parameters determine the triangle to be extracted:

  • When offset = 0 either the lower or upper triangle with respect to the main diagonal is extracted depending on the value of parameter lower.
  • When offset = k > 0 the upper triangle with respect to the k-th diagonal above the main diagonal is extracted.
  • When offset = k < 0 the lower triangle with respect to the k-th diagonal below the main diagonal is extracted.

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single triangular extraction
A = [[1.0, 2.0],
     [3.0, 4.0]]

extracttrian(A) = [1.0, 3.0, 4.0]
extracttrian(A, lower=False) = [1.0, 2.0, 4.0]
extracttrian(A, 1) = [2.0]
extracttrian(A, -1) = [3.0]

Batch triangular extraction
A = [[[1.0, 2.0],
      [3.0, 4.0]],
     [[5.0, 6.0],
      [7.0, 8.0]]]

extracttrian(A) = [[1.0, 3.0, 4.0],
                   [5.0, 7.0, 8.0]]

Defined in src/operator/tensor/la_op.cc:L604
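The selection rule for offset and lower can be sketched in pure Python. The sketch collects the triangle in row-major order, which reproduces the 2 x 2 examples above; those examples do not pin down the element order MXNet uses for larger matrices, so treat the ordering as an assumption. The function name is hypothetical.

```python
def extracttrian(a, offset=0, lower=True):
    """Hypothetical sketch: triangle of a square matrix, flattened row by row."""
    n = len(a)
    # offset > 0 always means an upper triangle; offset == 0 defers to `lower`.
    upper = offset > 0 or (offset == 0 and not lower)
    return [a[i][j] for i in range(n) for j in range(n)
            if (j - i >= offset if upper else j - i <= offset)]


A = [[1.0, 2.0],
     [3.0, 4.0]]
print(extracttrian(A))               # [1.0, 3.0, 4.0]
print(extracttrian(A, lower=False))  # [1.0, 2.0, 4.0]
print(extracttrian(A, 1))            # [2.0]
print(extracttrian(A, -1))           # [3.0]
```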

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrices
  • offset::int, optional, default='0': Offset of the diagonal versus the main diagonal. 0 corresponds to the main diagonal, a negative/positive value to diagonals below/above the main diagonal.
  • lower::boolean, optional, default=1: Refer to the lower triangular matrix if lower=true, refer to the upper otherwise. Only relevant when offset=0

source

# MXNet.mx._linalg_gelqfMethod.

_linalg_gelqf(A)

LQ factorization for general matrix. Input is a tensor A of dimension n >= 2.

If n=2, we compute the LQ factorization (LAPACK gelqf, followed by orglq). A must have shape (x, y) with x <= y, and must have full rank x. The LQ factorization consists of L with shape (x, x) and Q with shape (x, y), so that:

A = L * Q

Here, L is lower triangular (upper triangle equal to zero) with nonzero diagonal, and Q is row-orthonormal, meaning that

Q * Q^T

is equal to the identity matrix of shape (x, x).

If n>2, gelqf is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single LQ factorization
A = [[1., 2., 3.], [4., 5., 6.]]
Q, L = gelqf(A)
Q = [[-0.26726124, -0.53452248, -0.80178373],
     [0.87287156, 0.21821789, -0.43643578]]
L = [[-3.74165739, 0.],
     [-8.55235974, 1.96396101]]

Batch LQ factorization
A = [[[1., 2., 3.], [4., 5., 6.]],
     [[7., 8., 9.], [10., 11., 12.]]]
Q, L = gelqf(A)
Q = [[[-0.26726124, -0.53452248, -0.80178373],
      [0.87287156, 0.21821789, -0.43643578]],
     [[-0.50257071, -0.57436653, -0.64616234],
      [0.7620735, 0.05862104, -0.64483142]]]
L = [[[-3.74165739, 0.],
      [-8.55235974, 1.96396101]],
     [[-13.92838828, 0.],
      [-19.09768702, 0.52758934]]]

Defined in src/operator/tensor/la_op.cc:L797

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of input matrices to be factorized

source

# MXNet.mx._linalg_gemmMethod.

_linalg_gemm(A, B, C, transpose_a, transpose_b, alpha, beta, axis)

Performs general matrix multiplication and accumulation. Inputs are tensors A, B, C, each of dimension n >= 2 and having the same shape on the leading n-2 dimensions.

If n=2, the BLAS3 function gemm is performed:

out = alpha * op(A) * op(B) + beta * C

Here, alpha and beta are scalar parameters, and op() is either the identity or matrix transposition (depending on transpose_a, transpose_b).

If n>2, gemm is performed separately for a batch of matrices. The column indices of the matrices are given by the last dimensions of the tensors, the row indices by the axis specified with the axis parameter. By default, the trailing two dimensions will be used for matrix encoding.

For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm/swapaxes calls. For example, let A, B, C be 5-dimensional tensors. Then gemm(A, B, C, axis=1) is equivalent to the following, without the overhead of the additional swapaxes operations::

A1 = swapaxes(A, dim1=1, dim2=3)
B1 = swapaxes(B, dim1=1, dim2=3)
C = swapaxes(C, dim1=1, dim2=3)
C = gemm(A1, B1, C)
C = swapaxes(C, dim1=1, dim2=3)

When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use pseudo-float16 precision (float32 math with float16 I/O) in order to use Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix multiply-add
A = [[1.0, 1.0], [1.0, 1.0]]
B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0)
        = [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]]

Batch matrix multiply-add
A = [[[1.0, 1.0]], [[0.1, 0.1]]]
B = [[[1.0, 1.0]], [[0.1, 0.1]]]
C = [[[10.0]], [[0.01]]]
gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0)
        = [[[104.0]], [[0.14]]]

Defined in src/operator/tensor/la_op.cc:L88
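The formula out = alpha * op(A) * op(B) + beta * C, together with the default batch mode over the leading dimension, can be reproduced with a pure-Python sketch. The helpers are hypothetical and only cover the default axis; they do not model the axis parameter or Tensor Core behaviour.

```python
def transpose(m):
    """Swap rows and columns of a 2-D list."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Plain row-by-column product of two 2-D lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def gemm(A, B, C, transpose_a=False, transpose_b=False, alpha=1.0, beta=1.0):
    """out = alpha * op(A) * op(B) + beta * C for one set of 2-D matrices."""
    op_a = transpose(A) if transpose_a else A
    op_b = transpose(B) if transpose_b else B
    product = matmul(op_a, op_b)
    return [[alpha * p + beta * c for p, c in zip(prow, crow)]
            for prow, crow in zip(product, C)]


def batch_gemm(As, Bs, Cs, **kwargs):
    """Batch mode: the trailing two dimensions hold the matrices."""
    return [gemm(a, b, c, **kwargs) for a, b, c in zip(As, Bs, Cs)]


A = [[1.0, 1.0], [1.0, 1.0]]
B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
print(gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0))
# [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]]
```

Passing beta=0.0 gives the alpha * op(A) * op(B) product that _linalg_gemm2 documents.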

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of input matrices
  • B::NDArray-or-SymbolicNode: Tensor of input matrices
  • C::NDArray-or-SymbolicNode: Tensor of input matrices
  • transpose_a::boolean, optional, default=0: Multiply with transposed of first input (A).
  • transpose_b::boolean, optional, default=0: Multiply with transposed of second input (B).
  • alpha::double, optional, default=1: Scalar factor multiplied with A*B.
  • beta::double, optional, default=1: Scalar factor multiplied with C.
  • axis::int, optional, default='-2': Axis corresponding to the matrix rows.

source

# MXNet.mx._linalg_gemm2Method.

_linalg_gemm2(A, B, transpose_a, transpose_b, alpha, axis)

Performs general matrix multiplication. Inputs are tensors A, B, each of dimension n >= 2 and having the same shape on the leading n-2 dimensions.

If n=2, the BLAS3 function gemm is performed:

out = alpha * op(A) * op(B)

Here alpha is a scalar parameter and op() is either the identity or the matrix transposition (depending on transpose_a, transpose_b).

If n>2, gemm is performed separately for a batch of matrices. The column indices of the matrices are given by the last dimensions of the tensors, the row indices by the axis specified with the axis parameter. By default, the trailing two dimensions will be used for matrix encoding.

For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm/swapaxes calls. For example, let A, B be 5-dimensional tensors. Then gemm2(A, B, axis=1) is equivalent to the following, without the overhead of the additional swapaxes operations::

A1 = swapaxes(A, dim1=1, dim2=3)
B1 = swapaxes(B, dim1=1, dim2=3)
C = gemm2(A1, B1)
C = swapaxes(C, dim1=1, dim2=3)

When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use pseudo-float16 precision (float32 math with float16 I/O) in order to use Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix multiply
A = [[1.0, 1.0], [1.0, 1.0]]
B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
gemm2(A, B, transpose_b=True, alpha=2.0)
        = [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]]

Batch matrix multiply
A = [[[1.0, 1.0]], [[0.1, 0.1]]]
B = [[[1.0, 1.0]], [[0.1, 0.1]]]
gemm2(A, B, transpose_b=True, alpha=2.0)
        = [[[4.0]], [[0.04]]]

Defined in src/operator/tensor/la_op.cc:L162

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of input matrices
  • B::NDArray-or-SymbolicNode: Tensor of input matrices
  • transpose_a::boolean, optional, default=0: Multiply with transposed of first input (A).
  • transpose_b::boolean, optional, default=0: Multiply with transposed of second input (B).
  • alpha::double, optional, default=1: Scalar factor multiplied with A*B.
  • axis::int, optional, default='-2': Axis corresponding to the matrix row indices.

source

# MXNet.mx._linalg_inverseMethod.

_linalg_inverse(A)

Compute the inverse of a matrix. Input is a tensor A of dimension n >= 2.

If n=2, A is a square matrix. We compute:

out = A^-1

If n>2, inverse is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix inverse
A = [[1., 4.], [2., 3.]]
inverse(A) = [[-0.6, 0.8], [0.4, -0.2]]

Batch matrix inverse
A = [[[1., 4.], [2., 3.]],
     [[1., 3.], [2., 4.]]]
inverse(A) = [[[-0.6, 0.8], [0.4, -0.2]],
              [[-2., 1.5], [1., -0.5]]]

Defined in src/operator/tensor/la_op.cc:L919
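To check the example values, a pure-Python sketch of a matrix inverse by Gauss-Jordan elimination is shown below. The elimination method is an assumption for illustration, not necessarily what MXNet uses internally, and the function name is hypothetical. Results are floating-point, so compare them with a tolerance.

```python
def inverse(a):
    """Hypothetical sketch: inverse of a square matrix via Gauss-Jordan elimination."""
    n = len(a)
    # Augment [A | I] and reduce the left half to the identity.
    m = [[float(v) for v in row] + [1.0 if i == j else 0.0 for j in range(n)]
         for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if m[pivot][col] == 0.0:
            raise ValueError("matrix is singular")
        m[col], m[pivot] = m[pivot], m[col]
        p = m[col][col]
        m[col] = [v / p for v in m[col]]  # scale the pivot row to a leading 1
        for r in range(n):
            if r != col:
                f = m[r][col]
                m[r] = [rv - f * cv for rv, cv in zip(m[r], m[col])]
    return [row[n:] for row in m]  # the right half now holds A^-1


print(inverse([[1, 4], [2, 3]]))  # approximately [[-0.6, 0.8], [0.4, -0.2]]
```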

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrix

source

# MXNet.mx._linalg_makediagMethod.

_linalg_makediag(A, offset)

Constructs a square matrix with the input as diagonal. Input is a tensor A of dimension n >= 1.

If n=1, then A represents the diagonal entries of a single square matrix. This matrix will be returned as a 2-dimensional tensor. If n>1, then A represents a batch of diagonals of square matrices. The batch of diagonal matrices will be returned as an n+1-dimensional tensor.

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single diagonal matrix construction
A = [1.0, 2.0]

makediag(A)    = [[1.0, 0.0],
                  [0.0, 2.0]]

makediag(A, 1) = [[0.0, 1.0, 0.0],
                  [0.0, 0.0, 2.0],
                  [0.0, 0.0, 0.0]]

Batch diagonal matrix construction
A = [[1.0, 2.0],
     [3.0, 4.0]]

makediag(A) = [[[1.0, 0.0],
                [0.0, 2.0]],
               [[3.0, 0.0],
                [0.0, 4.0]]]
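The offset semantics match Julia's `LinearAlgebra.diagm`, which places a vector on the diagonal selected by `offset => vector`. The sketch below is plain Julia, not the MXNet.jl API. `makediag_sketch` and `batch_makediag` are hypothetical helpers. The batch version assumes one diagonal per column of its input.

```julia
using LinearAlgebra

# Hypothetical helper mirroring makediag(A, offset) for a single diagonal vector.
# offset > 0 places the entries above the main diagonal, offset < 0 below it.
makediag_sketch(d::AbstractVector, offset::Integer=0) = diagm(offset => d)

# Hypothetical batch version: one square matrix per column of D, stacked along dim 3.
batch_makediag(D::AbstractMatrix, offset::Integer=0) =
    cat((makediag_sketch(view(D, :, k), offset) for k in axes(D, 2))...; dims=3)
```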

Defined in src/operator/tensor/la_op.cc:L546

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of diagonal entries
  • offset::int, optional, default='0': Offset of the diagonal versus the main diagonal. 0 corresponds to the main diagonal, a negative/positive value to diagonals below/above the main diagonal.

source

# MXNet.mx._linalg_maketrianMethod.

_linalg_maketrian(A, offset, lower)

Constructs a square matrix with the input representing a specific triangular sub-matrix. This is the inverse of linalg.extracttrian. Input is a tensor A of dimension n >= 1.

If n=1, then A represents the entries of a triangular matrix. That matrix is lower triangular if offset<0, or if offset=0 and lower=true, and upper triangular otherwise. The result is built in two steps. First, the square matrix is constructed with the entries outside the triangle set to zero. Then it is padded with |offset| additional zero diagonals, so that the triangle starts |offset| diagonals away from the main diagonal.

If n>1, then A represents a batch of triangular sub-matrices. The batch of corresponding square matrices is returned as an n+1-dimensional tensor.

Note: The operator supports float32 and float64 data types only.

Examples:

Single matrix construction
A = [1.0, 2.0, 3.0]

maketrian(A)              = [[1.0, 0.0],
                             [2.0, 3.0]]

maketrian(A, lower=false) = [[1.0, 2.0],
                             [0.0, 3.0]]

maketrian(A, offset=1)    = [[0.0, 1.0, 2.0],
                             [0.0, 0.0, 3.0],
                             [0.0, 0.0, 0.0]]

maketrian(A, offset=-1)   = [[0.0, 0.0, 0.0],
                             [1.0, 0.0, 0.0],
                             [2.0, 3.0, 0.0]]

Batch matrix construction
A = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]

maketrian(A)           = [[[1.0, 0.0],
                           [2.0, 3.0]],
                          [[4.0, 0.0],
                           [5.0, 6.0]]]

maketrian(A, offset=1) = [[[0.0, 1.0, 2.0],
                           [0.0, 0.0, 3.0],
                           [0.0, 0.0, 0.0]],
                          [[0.0, 4.0, 5.0],
                           [0.0, 0.0, 6.0],
                           [0.0, 0.0, 0.0]]]
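The construction rule for a single input vector can be sketched in plain Julia: fill the triangle row by row, then pad by |offset| rows and columns. `maketrian_sketch` is a hypothetical helper, not the MXNet.jl API. It assumes the row-by-row fill order shown in the examples above.

```julia
# Hypothetical helper mirroring maketrian(A, offset, lower) for a 1-D input.
function maketrian_sketch(v::AbstractVector{T}; offset::Integer=0, lower::Bool=true) where {T}
    k = length(v)
    m = (isqrt(8k + 1) - 1) ÷ 2                      # triangle side: m(m+1)/2 == k
    m * (m + 1) ÷ 2 == k || throw(ArgumentError("length(v) must be a triangular number"))
    use_lower = offset < 0 || (offset == 0 && lower)
    tri = zeros(T, m, m)
    idx = 1
    for i in 1:m                                     # fill row by row
        for j in (use_lower ? (1:i) : (i:m))
            tri[i, j] = v[idx]
            idx += 1
        end
    end
    n = m + abs(offset)                              # pad by |offset| rows and columns
    out = zeros(T, n, n)
    if offset >= 0
        out[1:m, (1 + offset):n] = tri               # triangle shifted to the right
    else
        out[(1 - offset):n, 1:m] = tri               # triangle shifted downwards
    end
    return out
end
```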

Defined in src/operator/tensor/la_op.cc:L672

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of triangular matrices stored as vectors
  • offset::int, optional, default='0': Offset of the diagonal versus the main diagonal. 0 corresponds to the main diagonal, a negative/positive value to diagonals below/above the main diagonal.
  • lower::boolean, optional, default=1: Refer to the lower triangular matrix if lower=true, and to the upper one otherwise. Only relevant when offset=0.

source

# MXNet.mx._linalg_potrfMethod.

_linalg_potrf(A)

Performs Cholesky factorization of a symmetric positive-definite matrix. Input is a tensor A of dimension n >= 2.

If n=2, the Cholesky factor B of the symmetric, positive definite matrix A is computed. B is triangular (entries of upper or lower triangle are all zero), has positive diagonal entries, and:

A = B * B^T if lower = true
A = B^T * B if lower = false

If n>2, potrf is performed separately on the trailing two dimensions for all inputs (batch mode).

Note: The operator supports float32 and float64 data types only.

Examples:

Single matrix factorization
A = [[4.0, 1.0],
     [1.0, 4.25]]

potrf(A) = [[2.0, 0],
            [0.5, 2.0]]

Batch matrix factorization
A = [[[4.0, 1.0],
      [1.0, 4.25]],
     [[16.0, 4.0],
      [4.0, 17.0]]]

potrf(A) = [[[2.0, 0],
             [0.5, 2.0]],
            [[4.0, 0],
             [1.0, 4.0]]]
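Julia's LinearAlgebra standard library computes the same lower Cholesky factor, so the examples can be checked without MXNet. `potrf_sketch` is a hypothetical helper corresponding to the lower = true case. For lower = false, `cholesky(...).U` would give the factor with A = B^T * B.

```julia
using LinearAlgebra

# Hypothetical helper: lower Cholesky factor B with A = B * B'.
# Throws PosDefException if A is not positive definite.
potrf_sketch(A::AbstractMatrix) = Matrix(cholesky(Symmetric(A, :L)).L)

B = potrf_sketch([4.0 1.0; 1.0 4.25])
```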

Defined in src/operator/tensor/la_op.cc:L213

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of input matrices to be decomposed

source

# MXNet.mx._linalg_potriMethod.

_linalg_potri(A)

Performs matrix inversion from a Cholesky factorization. Input is a tensor A of dimension n >= 2.

If n=2, A is a triangular matrix (entries of upper or lower triangle are all zero) with positive diagonal. We compute:

out = A^{-T} * A^{-1} if lower = true
out = A^{-1} * A^{-T} if lower = false

In other words, if A is the Cholesky factor of a symmetric positive definite matrix B (obtained by potrf), then

out = B^{-1}

If n>2, potri is performed separately on the trailing two dimensions for all inputs (batch mode).

Note: The operator supports float32 and float64 data types only.

Note: Use this operator only if you are certain you need the inverse of B and cannot use the Cholesky factor A (potrf) together with back-substitution (trsm). The latter is numerically much safer, and also cheaper.

Examples:

Single matrix inverse
A = [[2.0, 0],
     [0.5, 2.0]]

potri(A) = [[0.26563, -0.0625],
            [-0.0625, 0.25]]

Batch matrix inverse
A = [[[2.0, 0],
      [0.5, 2.0]],
     [[4.0, 0],
      [1.0, 4.0]]]

potri(A) = [[[0.26563, -0.0625],
             [-0.0625, 0.25]],
            [[0.06641, -0.01562],
             [-0.01562, 0.0625]]]
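The identity B^{-1} = A^{-T} * A^{-1} can be sketched in plain Julia. The sketch also shows the cheaper alternative recommended in the note: two triangular solves instead of an explicit inverse. `potri_sketch` and `solve_with_factor` are hypothetical helpers for the lower = true case, not MXNet.jl functions.

```julia
using LinearAlgebra

# Hypothetical helper: from a lower Cholesky factor L (B = L * L'),
# return B^{-1} = L^{-T} * L^{-1}.
function potri_sketch(L::AbstractMatrix)
    Linv = inv(LowerTriangular(L))
    return Matrix(Linv' * Linv)
end

# Hypothetical helper: compute B^{-1} * x by forward and back substitution only.
solve_with_factor(L::AbstractMatrix, x) = LowerTriangular(L)' \ (LowerTriangular(L) \ x)
```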

Defined in src/operator/tensor/la_op.cc:L274

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of lower triangular matrices

source

# MXNet.mx._linalg_slogdetMethod.

_linalg_slogdet(A)

Compute the sign and log of the determinant of a matrix. Input is a tensor A of dimension n >= 2.

If n=2, A is a square matrix. We compute:

sign = sign(det(A))
logabsdet = log(abs(det(A)))

If n>2, slogdet is performed separately on the trailing two dimensions for all inputs (batch mode).

Note: The operator supports float32 and float64 data types only.

Note: The gradient with respect to sign is not well defined, so no gradient is propagated through it.

Note: No gradient is propagated when A is not invertible. See the documentation of the det operator for details.

Examples:

Single matrix signed log determinant
A = [[2., 3.],
     [1., 4.]]

sign, logabsdet = slogdet(A)
sign = [1.]
logabsdet = [1.609438]

Batch matrix signed log determinant
A = [[[2., 3.],
      [1., 4.]],
     [[1., 2.],
      [2., 4.]],
     [[1., 2.],
      [4., 3.]]]

sign, logabsdet = slogdet(A)
sign = [1., 0., -1.]
logabsdet = [1.609438, -inf, 1.609438]
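Julia's standard library has a close counterpart, `LinearAlgebra.logabsdet`. It returns the pair in the opposite order, `(log|det|, sign)`. The hypothetical `slogdet_sketch` below reorders it to match slogdet's `(sign, logabsdet)` output. It is illustrated only on the invertible example matrices, since its behaviour on singular input is not claimed to match MXNet's.

```julia
using LinearAlgebra

# Hypothetical helper returning (sign, logabsdet) in slogdet's order.
# LinearAlgebra.logabsdet returns the pair the other way round.
function slogdet_sketch(A::AbstractMatrix{<:Union{Float32,Float64}})
    logabs, s = logabsdet(A)
    return s, logabs
end

s1, l1 = slogdet_sketch([2.0 3.0; 1.0 4.0])   # det = 5
s2, l2 = slogdet_sketch([1.0 2.0; 4.0 3.0])   # det = -5
```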

Defined in src/operator/tensor/la_op.cc:L1033

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrix

source

# MXNet.mx._linalg_sumlogdiagMethod.

_linalg_sumlogdiag(A)

Computes the sum of the logarithms of the diagonal elements of a square matrix. Input is a tensor A of dimension n >= 2.

If n=2, A must be square with positive diagonal entries. The natural logarithms of the diagonal elements are summed, and the result has shape (1,).

If n>2, sumlogdiag is performed separately on the trailing two dimensions for all inputs (batch mode).

Note: The operator supports float32 and float64 data types only.

Examples:

Single matrix reduction
A = [[1.0, 1.0],
     [1.0, 7.0]]

sumlogdiag(A) = [1.9459]

Batch matrix reduction
A = [[[1.0, 1.0],
      [1.0, 7.0]],
     [[3.0, 0],
      [0, 17.0]]]

sumlogdiag(A) = [1.9459, 3.9318]
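The reduction is a one-liner in plain Julia. `sumlogdiag_sketch` is a hypothetical helper. The last check illustrates a common use that is not stated in the operator docs: for a Cholesky factor L of B, twice the sum of the log-diagonal equals log(det(B)).

```julia
using LinearAlgebra

# Hypothetical helper: sum of natural logs of the diagonal.
# A non-positive diagonal entry makes log throw a DomainError.
sumlogdiag_sketch(A::AbstractMatrix) = sum(log, diag(A))
```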

Defined in src/operator/tensor/la_op.cc:L444

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrices

source

# MXNet.mx._linalg_syevdMethod.

_linalg_syevd(A)

Eigendecomposition of a symmetric matrix. Input is a tensor A of dimension n >= 2.

If n=2, A must be symmetric, of shape (x, x). We compute the eigendecomposition, resulting in the orthonormal matrix U of eigenvectors, shape (x, x), and the vector L of eigenvalues, shape (x,), so that:

U * A = diag(L) * U

Here:

U * U^T = U^T * U = I

where I is the identity matrix. Also, L(0) <= L(1) <= L(2) <= ... (ascending order).

If n>2, syevd is performed separately on the trailing two dimensions of A (batch mode). In this case, U has n dimensions like A, and L has n-1 dimensions.

Note: The operator supports float32 and float64 data types only.

Note: Derivatives for this operator are defined only if A is such that all its eigenvalues are distinct, and the eigengaps are not too small. If you need gradients, do not apply this operator to matrices with multiple eigenvalues.

Examples:

Single symmetric eigendecomposition
A = [[1., 2.],
     [2., 4.]]

U, L = syevd(A)
U = [[0.89442719, -0.4472136],
     [0.4472136, 0.89442719]]
L = [0., 5.]

Batch symmetric eigendecomposition
A = [[[1., 2.],
      [2., 4.]],
     [[1., 2.],
      [2., 5.]]]

U, L = syevd(A)
U = [[[0.89442719, -0.4472136],
      [0.4472136, 0.89442719]],
     [[0.92387953, -0.38268343],
      [0.38268343, 0.92387953]]]
L = [[0., 5.],
     [0.17157288, 5.82842712]]
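The convention here stores eigenvectors as the rows of U, so that U * A = diag(L) * U. Julia's `eigen` stores them as columns, so a plain-Julia sketch has to transpose. `syevd_sketch` is a hypothetical helper. The checks test the defining identities rather than exact signs, because eigenvectors are only determined up to sign.

```julia
using LinearAlgebra

# Hypothetical helper returning (U, L) with eigenvectors as the ROWS of U.
function syevd_sketch(A::AbstractMatrix{<:Union{Float32,Float64}})
    F = eigen(Symmetric(A))          # real symmetric: eigenvalues in ascending order
    return Matrix(F.vectors'), F.values
end

A = [1.0 2.0; 2.0 4.0]
U, L = syevd_sketch(A)
```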

Defined in src/operator/tensor/la_op.cc:L867

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of input matrices to be factorized

source

# MXNet.mx._linalg_syrkMethod.

_linalg_syrk(A, transpose, alpha)

Multiplies a matrix by its transpose. Input is a tensor A of dimension n >= 2.

If n=2, the operator performs the BLAS3 function syrk:

out = alpha * A * A^T

if transpose=False, or

out = alpha * A^T * A

if transpose=True.

If n>2, syrk is performed separately on the trailing two dimensions for all inputs (batch mode).

Note: The operator supports float32 and float64 data types only.

Examples:

Single matrix multiply
A = [[1., 2., 3.],
     [4., 5., 6.]]

syrk(A, alpha=1., transpose=False) = [[14., 32.],
                                      [32., 77.]]

syrk(A, alpha=1., transpose=True) = [[17., 22., 27.],
                                     [22., 29., 36.],
                                     [27., 36., 45.]]

Batch matrix multiply
A = [[[1., 1.]],
     [[0.1, 0.1]]]

syrk(A, alpha=2., transpose=False) = [[[4.]],
                                      [[0.04]]]
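Both variants of the formula can be written directly in plain Julia. `syrk_sketch` is a hypothetical helper that mirrors the `alpha` and `transpose` arguments. It is not the BLAS call itself, although Julia may dispatch `A * A'` to a syrk-like kernel internally.

```julia
# Hypothetical helper: alpha * A * A' (transpose=false) or alpha * A' * A (transpose=true).
syrk_sketch(A::AbstractMatrix; alpha=1.0, transpose::Bool=false) =
    transpose ? alpha * (A' * A) : alpha * (A * A')

A = [1.0 2.0 3.0; 4.0 5.0 6.0]
```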

Defined in src/operator/tensor/la_op.cc:L729

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of input matrices
  • transpose::boolean, optional, default=0: Use transpose of input matrix.
  • alpha::double, optional, default=1: Scalar factor to be applied to the result.

source

# MXNet.mx._linalg_trmmMethod.

_linalg_trmm(A, B, transpose, rightside, lower, alpha)

Performs multiplication with a lower triangular matrix. Inputs are tensors A, B, each of dimension n >= 2 and having the same shape on the leading n-2 dimensions.

If n=2, A must be triangular. The operator performs the BLAS3 function trmm:

out = alpha * op(A) * B

if rightside=False, or

out = alpha * B * op(A)

if rightside=True. Here, alpha is a scalar parameter, and op() is either the identity or the matrix transposition (depending on transpose).

If n>2, trmm is performed separately on the trailing two dimensions for all inputs (batch mode).

Note: The operator supports float32 and float64 data types only.

Examples:

Single triangular matrix multiply
A = [[1.0, 0],
     [1.0, 1.0]]
B = [[1.0, 1.0, 1.0],
     [1.0, 1.0, 1.0]]

trmm(A, B, alpha=2.0) = [[2.0, 2.0, 2.0],
                         [4.0, 4.0, 4.0]]

Batch triangular matrix multiply
A = [[[1.0, 0],
      [1.0, 1.0]],
     [[1.0, 0],
      [1.0, 1.0]]]
B = [[[1.0, 1.0, 1.0],
      [1.0, 1.0, 1.0]],
     [[0.5, 0.5, 0.5],
      [0.5, 0.5, 0.5]]]

trmm(A, B, alpha=2.0) = [[[2.0, 2.0, 2.0],
                          [4.0, 4.0, 4.0]],
                         [[1.0, 1.0, 1.0],
                          [2.0, 2.0, 2.0]]]
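The role of each flag can be sketched in plain Julia with the `LowerTriangular`/`UpperTriangular` wrappers, which ignore the entries in the other triangle. `trmm_sketch` is a hypothetical helper, and its keyword names simply mirror the arguments listed here. For real inputs, the adjoint `'` equals the transpose.

```julia
using LinearAlgebra

# Hypothetical helper: alpha * op(A) * B, or alpha * B * op(A) when rightside=true.
function trmm_sketch(A, B; alpha=1.0, transpose::Bool=false, rightside::Bool=false, lower::Bool=true)
    tri = lower ? LowerTriangular(A) : UpperTriangular(A)   # other triangle is ignored
    opA = transpose ? tri' : tri                              # op() = identity or transpose
    return alpha * (rightside ? B * opA : opA * B)
end

A = [1.0 0.0; 1.0 1.0]
```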

Defined in src/operator/tensor/la_op.cc:L332

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of lower triangular matrices
  • B::NDArray-or-SymbolicNode: Tensor of matrices
  • transpose::boolean, optional, default=0: Use the transpose of the triangular matrix.
  • rightside::boolean, optional, default=0: If true, compute alpha * B * op(A) instead of alpha * op(A) * B.
  • lower::boolean, optional, default=1: True if the triangular matrix is lower triangular, false if it is upper triangular.
  • alpha::double, optional, default=1: Scalar factor to be applied to the result.

source

# MXNet.mx._linalg_trsmMethod.

_linalg_trsm(A, B, transpose, rightside, lower, alpha)

Solves a matrix equation involving a lower triangular matrix. Inputs are tensors A, B, each of dimension n >= 2 and having the same shape on the leading n-2 dimensions.

If n=2, A must be triangular. The operator performs the BLAS3 function trsm, solving for out in:

op(A) * out = alpha * B

if rightside=False, or

out * op(A) = alpha * B

if rightside=True. Here, alpha is a scalar parameter, and op() is either the identity or the matrix transposition (depending on transpose).

If n>2, trsm is performed separately on the trailing two dimensions for all inputs (batch mode).

Note: The operator supports float32 and float64 data types only.

Examples:

Single matrix solve
A = [[1.0, 0],
     [1.0, 1.0]]
B = [[2.0, 2.0, 2.0],
     [4.0, 4.0, 4.0]]

trsm(A, B, alpha=0.5) = [[1.0, 1.0, 1.0],
                         [1.0, 1.0, 1.0]]

Batch matrix solve
A = [[[1.0, 0],
      [1.0, 1.0]],
     [[1.0, 0],
      [1.0, 1.0]]]
B = [[[2.0, 2.0, 2.0],
      [4.0, 4.0, 4.0]],
     [[4.0, 4.0, 4.0],
      [8.0, 8.0, 8.0]]]

trsm(A, B, alpha=0.5) = [[[1.0, 1.0, 1.0],
                          [1.0, 1.0, 1.0]],
                         [[2.0, 2.0, 2.0],
                          [2.0, 2.0, 2.0]]]
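In plain Julia, the two equations above correspond to left division (`\`) and right division (`/`) by a triangular-wrapped matrix. Those operations are solved by substitution, without forming an inverse. `trsm_sketch` is a hypothetical helper whose keywords mirror the arguments listed here.

```julia
using LinearAlgebra

# Hypothetical helper: solve op(A) * out = alpha * B, or out * op(A) = alpha * B if rightside.
function trsm_sketch(A, B; alpha=1.0, transpose::Bool=false, rightside::Bool=false, lower::Bool=true)
    tri = lower ? LowerTriangular(A) : UpperTriangular(A)
    opA = transpose ? tri' : tri
    return rightside ? (alpha * B) / opA : opA \ (alpha * B)
end

A = [1.0 0.0; 1.0 1.0]
C = [1.0 2.0; 3.0 4.0]
```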

Defined in src/operator/tensor/la_op.cc:L395

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of lower triangular matrices
  • B::NDArray-or-SymbolicNode: Tensor of matrices
  • transpose::boolean, optional, default=0: Use the transpose of the triangular matrix.
  • rightside::boolean, optional, default=0: If true, solve out * op(A) = alpha * B instead of op(A) * out = alpha * B.
  • lower::boolean, optional, default=1: True if the triangular matrix is lower triangular, false if it is upper triangular.
  • alpha::double, optional, default=1: Scalar factor to be applied to the result.

source

# MXNet.mx._logical_andMethod.

_logical_and(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._logical_and_scalarMethod.

_logical_and_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._logical_orMethod.

_logical_or(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._logical_or_scalarMethod.

_logical_or_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._logical_xorMethod.

_logical_xor(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._logical_xor_scalarMethod.

_logical_xor_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._maximum_scalarMethod.

_maximum_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._minimum_scalarMethod.

_minimum_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._minus_scalarMethod.

_minus_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._mp_adamw_updateMethod.

_mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, beta1, beta2, epsilon, wd, eta, clip_gradient)

Update function for multi-precision AdamW optimizer.

AdamW can be seen as a modification of Adam that decouples the weight decay from the optimization steps taken with respect to the loss function.

The Adam update consists of the following steps, where g represents the gradient and m, v are the 1st and 2nd order moment estimates (mean and variance).

g_t = \nabla J(W_{t-1})\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \eta_t \left(\alpha \frac{m_t}{\sqrt{v_t} + \epsilon} + wd\, W_{t-1}\right)

It updates the weights using:

    m = beta1*m + (1-beta1)*grad
    v = beta2*v + (1-beta2)*(grad**2)
    w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)

Note that the gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, the update is skipped.
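The update rule above, including gradient rescaling, the skip condition, and clipping, can be sketched as a single in-place step on plain Julia vectors. `adamw_step!` is a hypothetical helper, not an MXNet.jl function. It does not model the multi-precision part of this operator (the separate weight32 master copy). Rescaling before clipping is an assumption about the order of operations.

```julia
# Hypothetical helper: one AdamW step following the update rule above.
# Mutates w, m and v in place; grad itself is left untouched.
function adamw_step!(w::Vector{Float64}, grad::Vector{Float64},
                     m::Vector{Float64}, v::Vector{Float64};
                     lr::Float64, eta::Float64, beta1::Float64=0.9,
                     beta2::Float64=0.999, epsilon::Float64=1e-8, wd::Float64=0.0,
                     rescale_grad::Float64=1.0, clip_gradient::Float64=-1.0)
    # The whole update is skipped when the rescale factor is NaN, Inf, or 0.
    if !isfinite(rescale_grad) || rescale_grad == 0
        return w
    end
    g = rescale_grad .* grad
    if clip_gradient > 0                       # clipping is off when clip_gradient <= 0
        g = clamp.(g, -clip_gradient, clip_gradient)
    end
    @. m = beta1 * m + (1 - beta1) * g         # 1st moment estimate
    @. v = beta2 * v + (1 - beta2) * g^2       # 2nd moment estimate
    @. w -= eta * (lr * m / (sqrt(v) + epsilon) + w * wd)   # decoupled weight decay
    return w
end

w, m, v = [1.0], [0.0], [0.0]
adamw_step!(w, [0.5], m, v; lr=0.1, eta=1.0)
```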

Defined in src/operator/contrib/adamw.cc:L57

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • mean::NDArray-or-SymbolicNode: Moving mean
  • var::NDArray-or-SymbolicNode: Moving variance
  • weight32::NDArray-or-SymbolicNode: Weight32
  • rescale_grad::NDArray-or-SymbolicNode: Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped.
  • lr::float, required: Learning rate
  • beta1::float, optional, default=0.899999976: The decay rate for the 1st moment estimates.
  • beta2::float, optional, default=0.999000013: The decay rate for the 2nd moment estimates.
  • epsilon::float, optional, default=9.99999994e-09: A small constant for numerical stability.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • eta::float, required: Learning rate schedule multiplier
  • clip_gradient::float, optional, default=-1: Clip gradient to the range [-clip_gradient, clip_gradient], i.e. grad = max(min(grad, clip_gradient), -clip_gradient). If clip_gradient <= 0, gradient clipping is turned off.

source

# MXNet.mx._mulMethod.

_mul(lhs, rhs)

_mul is an alias of elemwise_mul.

Multiplies arguments element-wise.

The storage type of the elemwise_mul output depends on the storage types of the inputs:

  • elemwise_mul(default, default) = default
  • elemwise_mul(row_sparse, row_sparse) = row_sparse
  • elemwise_mul(default, row_sparse) = row_sparse
  • elemwise_mul(row_sparse, default) = row_sparse
  • elemwise_mul(csr, csr) = csr
  • otherwise, elemwise_mul generates output with default storage

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._mul_scalarMethod.

_mul_scalar(data, scalar, is_int)

Multiply an array with a scalar.

If the input is sparse, _mul_scalar operates only on the data array of the input.

For example, suppose an input of shape (100, 100) has only 2 non-zero elements, i.e. input.data = [5, 6], and scalar = nan. The result is output.data = [nan, nan] instead of 10000 nans.
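Julia's SparseArrays standard library can reproduce the example: scaling only the stored values leaves the implicit zeros untouched, whereas a dense multiplication by NaN poisons every entry. `mul_scalar_sparse_sketch` is a hypothetical helper that illustrates the described semantics. It is not MXNet code.

```julia
using SparseArrays

# Hypothetical helper: scale only the stored (non-zero) values of a sparse matrix.
function mul_scalar_sparse_sketch(S::SparseMatrixCSC, scalar)
    out = copy(S)
    nz = nonzeros(out)      # stored values; shares memory with `out`
    nz .*= scalar
    return out
end

S = sparse([1, 2], [1, 2], [5.0, 6.0], 100, 100)
out = mul_scalar_sparse_sketch(S, NaN)
```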

Defined in src/operator/tensor/elemwise_binary_scalar_op_basic.cc:L152

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._multi_adamw_updateMethod.

_multi_adamw_update(data, lrs, beta1, beta2, epsilon, wds, etas, clip_gradient, num_weights)

Update function for AdamW optimizer.

AdamW can be seen as a modification of Adam that decouples the weight decay from the optimization steps taken with respect to the loss function.

The Adam update consists of the following steps, where g represents the gradient and m, v are the 1st and 2nd order moment estimates (mean and variance).

g_t = \nabla J(W_{t-1})\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \eta_t \left(\alpha \frac{m_t}{\sqrt{v_t} + \epsilon} + wd\, W_{t-1}\right)

It updates the weights using:

    m = beta1*m + (1-beta1)*grad
    v = beta2*v + (1-beta2)*(grad**2)
    w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)

Note that the gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, the update is skipped.

Defined in src/operator/contrib/adamw.cc:L166

Arguments

  • data::NDArray-or-SymbolicNode[]: data
  • lrs::tuple of <float>, required: Learning rates
  • beta1::float, optional, default=0.899999976: The decay rate for the 1st moment estimates.
  • beta2::float, optional, default=0.999000013: The decay rate for the 2nd moment estimates.
  • epsilon::float, optional, default=9.99999994e-09: A small constant for numerical stability.
  • wds::tuple of <float>, required: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • etas::tuple of <float>, required: Learning rate schedule multipliers
  • clip_gradient::float, optional, default=-1: Clip gradient to the range [-clip_gradient, clip_gradient], i.e. grad = max(min(grad, clip_gradient), -clip_gradient). If clip_gradient <= 0, gradient clipping is turned off.
  • num_weights::int, optional, default='1': Number of updated weights.

source

# MXNet.mx._multi_lamb_updateMethod.

_multi_lamb_update(data, learning_rates, beta1, beta2, epsilon, wds, rescale_grad, lower_bound, upper_bound, clip_gradient, bias_correction, step_count, num_tensors)

Computes the LAMB coefficients of multiple weights and gradients.

Defined in src/operator/contrib/multi_lamb.cc:L175

Arguments

  • data::NDArray-or-SymbolicNode[]: data
  • learning_rates::tuple of <float>, required: List of learning rates
  • beta1::float, optional, default=0.899999976: Exponential decay rate for the first moment estimates.
  • beta2::float, optional, default=0.999000013: Exponential decay rate for the second moment estimates.
  • epsilon::float, optional, default=9.99999997e-07: Small value to avoid division by 0.
  • wds::tuple of <float>, required: List of weight decays. Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Gradient rescaling factor
  • lower_bound::float, optional, default=-1: Lower limit of the weight norm. If lower_bound <= 0, no lower limit is set.
  • upper_bound::float, optional, default=-1: Upper limit of the weight norm. If upper_bound <= 0, no upper limit is set.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range [-clip_gradient, clip_gradient], i.e. grad = max(min(grad, clip_gradient), -clip_gradient). If clip_gradient <= 0, gradient clipping is turned off.
  • bias_correction::boolean, optional, default=1: Whether to use bias correction.
  • step_count::Shape(tuple), required: Step count for each tensor
  • num_tensors::int, optional, default='1': Number of tensors

source

# MXNet.mx._multi_mp_adamw_updateMethod.

_multi_mp_adamw_update(data, lrs, beta1, beta2, epsilon, wds, etas, clip_gradient, num_weights)

Update function for multi-precision AdamW optimizer.

AdamW can be seen as a modification of Adam that decouples the weight decay from the optimization steps taken with respect to the loss function.

The Adam update consists of the following steps, where g represents the gradient and m, v are the 1st and 2nd order moment estimates (mean and variance).

g_t = \nabla J(W_{t-1})\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \eta_t \left(\alpha \frac{m_t}{\sqrt{v_t} + \epsilon} + wd\, W_{t-1}\right)

It updates the weights using:

    m = beta1*m + (1-beta1)*grad
    v = beta2*v + (1-beta2)*(grad**2)
    w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)

Note that the gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, the update is skipped.

Defined in src/operator/contrib/adamw.cc:L222

Arguments

  • data::NDArray-or-SymbolicNode[]: data
  • lrs::tuple of <float>, required: Learning rates
  • beta1::float, optional, default=0.899999976: The decay rate for the 1st moment estimates.
  • beta2::float, optional, default=0.999000013: The decay rate for the 2nd moment estimates.
  • epsilon::float, optional, default=9.99999994e-09: A small constant for numerical stability.
  • wds::tuple of <float>, required: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • etas::tuple of <float>, required: Learning rate schedule multipliers
  • clip_gradient::float, optional, default=-1: Clip gradient to the range [-clip_gradient, clip_gradient], i.e. grad = max(min(grad, clip_gradient), -clip_gradient). If clip_gradient <= 0, gradient clipping is turned off.
  • num_weights::int, optional, default='1': Number of updated weights.

source

# MXNet.mx._multi_mp_lamb_updateMethod.

_multi_mp_lamb_update(data, learning_rates, beta1, beta2, epsilon, wds, rescale_grad, lower_bound, upper_bound, clip_gradient, bias_correction, step_count, num_tensors)

Computes the LAMB coefficients of multiple weights and gradients with mixed precision.

Defined in src/operator/contrib/multi_lamb.cc:L213

Arguments

  • data::NDArray-or-SymbolicNode[]: data
  • learning_rates::tuple of <float>, required: List of learning rates
  • beta1::float, optional, default=0.899999976: Exponential decay rate for the first moment estimates.
  • beta2::float, optional, default=0.999000013: Exponential decay rate for the second moment estimates.
  • epsilon::float, optional, default=9.99999997e-07: Small value to avoid division by 0.
  • wds::tuple of <float>, required: List of weight decays. Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Gradient rescaling factor
  • lower_bound::float, optional, default=-1: Lower limit of the weight norm. If lower_bound <= 0, no lower limit is set.
  • upper_bound::float, optional, default=-1: Upper limit of the weight norm. If upper_bound <= 0, no upper limit is set.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range [-clip_gradient, clip_gradient], i.e. grad = max(min(grad, clip_gradient), -clip_gradient). If clip_gradient <= 0, gradient clipping is turned off.
  • bias_correction::boolean, optional, default=1: Whether to use bias correction.
  • step_count::Shape(tuple), required: Step count for each tensor
  • num_tensors::int, optional, default='1': Number of tensors

source

# MXNet.mx._not_equalMethod.

_not_equal(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._not_equal_scalarMethod.

_not_equal_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._np_allMethod.

_np_all(data, axis, keepdims)

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • axis::Shape or None, optional, default=None: Axis or axes along which a logical AND reduction is performed. The default, axis=None, reduces over all of the elements of the input array. If axis is negative it counts from the last to the first axis.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimension with size one.

source

# MXNet.mx._np_amaxMethod.

_np_amax(a, axis, keepdims, initial)

_np_amax is an alias of _np_max.

Defined in src/operator/numpy/np_broadcast_reduce_op_value.cc:L169

Arguments

  • a::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: Axis or axes along which the maximum is computed. The default, axis=None, reduces over all of the elements of the input array. If axis is negative it counts from the last to the first axis.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimension with size one.
  • initial::double or None, optional, default=None: Starting value for the reduction.

source

# MXNet.mx._np_aminMethod.

_np_amin(a, axis, keepdims, initial)

_np_amin is an alias of _np_min.

Defined in src/operator/numpy/np_broadcast_reduce_op_value.cc:L198

Arguments

  • a::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: Axis or axes along which the minimum is computed. The default, axis=None, reduces over all of the elements of the input array. If axis is negative it counts from the last to the first axis.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimension with size one.
  • initial::double or None, optional, default=None: Starting value for the reduction.

source

# MXNet.mx._np_anyMethod.

_np_any(data, axis, keepdims)

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • axis::Shape or None, optional, default=None: Axis or axes along which a logical OR reduction is performed. The default, axis=None, reduces over all of the elements of the input array. If axis is negative it counts from the last to the first axis.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimension with size one.

source

# MXNet.mx._np_atleast_1dMethod.

_np_atleast_1d(arys, num_args)

Note: _np_atleast_1d takes a variable number of positional inputs. Instead of calling it as _np_atleast_1d([x, y, z], num_args=3), call it as _np_atleast_1d(x, y, z), and num_args will be determined automatically.

Arguments

  • arys::NDArray-or-SymbolicNode[]: List of input arrays
  • num_args::int, required: Number of input arrays.

source

# MXNet.mx._np_atleast_2dMethod.

_np_atleast_2d(arys, num_args)

Note: _np_atleast_2d takes a variable number of positional inputs. Instead of calling it as _np_atleast_2d([x, y, z], num_args=3), call it as _np_atleast_2d(x, y, z), and num_args will be determined automatically.

Arguments

  • arys::NDArray-or-SymbolicNode[]: List of input arrays
  • num_args::int, required: Number of input arrays.

source

# MXNet.mx._np_atleast_3dMethod.

_np_atleast_3d(arys, num_args)

Note: _np_atleast_3d takes a variable number of positional inputs. Instead of calling it as _np_atleast_3d([x, y, z], num_args=3), call it as _np_atleast_3d(x, y, z), and num_args will be determined automatically.

Arguments

  • arys::NDArray-or-SymbolicNode[]: List of input arrays
  • num_args::int, required: Number of input arrays.

source

# MXNet.mx._np_copyMethod.

_np_copy(a)

Return an array copy of the given object.

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L47

Arguments

  • a::NDArray-or-SymbolicNode: The input

source

# MXNet.mx._np_cumsumMethod.

_np_cumsum(a, axis, dtype)

Return the cumulative sum of the elements along a given axis.

Defined in src/operator/numpy/np_cumsum.cc:L70

Arguments

  • a::NDArray-or-SymbolicNode: Input ndarray
  • axis::int or None, optional, default='None': Axis along which the cumulative sum is computed. The default (None) is to compute the cumsum over the flattened array.
  • dtype::{None, 'float16', 'float32', 'float64', 'int32', 'int64', 'int8'}, optional, default='None': Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used.

source

# MXNet.mx._np_diagMethod.

_np_diag(data, k)

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • k::int, optional, default='0': Diagonal in question. The default is 0. Use k>0 for diagonals above the main diagonal, and k<0 for diagonals below the main diagonal.
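Assuming this operator follows NumPy's `numpy.diag`, it has two modes: a 1-D input builds a matrix with the input on diagonal k, and a 2-D input extracts diagonal k. Both modes have direct plain-Julia counterparts in LinearAlgebra's `diagm` and `diag`. `np_diag_sketch` is a hypothetical helper and does not call MXNet.jl.

```julia
using LinearAlgebra

# Hypothetical helper mirroring NumPy-style diag semantics (assumed, see above).
np_diag_sketch(v::AbstractVector, k::Integer=0) = diagm(k => v)   # construct
np_diag_sketch(A::AbstractMatrix, k::Integer=0) = diag(A, k)      # extract
```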

source

# MXNet.mx._np_diagflatMethod.

_np_diagflat(data, k)

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • k::int, optional, default='0': Diagonal in question. The default is 0. Use k>0 for diagonals above the main diagonal, and k<0 for diagonals below the main diagonal.

source

# MXNet.mx._np_diagonalMethod.

_np_diagonal(data, offset, axis1, axis2)

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • offset::int, optional, default='0': Diagonal in question. The default is 0. Use offset>0 for diagonals above the main diagonal, and offset<0 for diagonals below the main diagonal. If the input has shape (S0, S1), offset must be between -S0 and S1.
  • axis1::int, optional, default='0': The first axis of the sub-arrays of interest. Ignored when the input is a 1-D array.
  • axis2::int, optional, default='1': The second axis of the sub-arrays of interest. Ignored when the input is a 1-D array.

source

# MXNet.mx._np_dotMethod.

_np_dot(a, b)

Dot product of two arrays. Specifically,

  • If both a and b are 1-D arrays, it is the inner product of vectors.
  • If both a and b are 2-D arrays, it is matrix multiplication.
  • If either a or b is 0-D (scalar), it is equivalent to multiplication, and using numpy.multiply(a, b) or a * b is preferred.
  • If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b.
  • If a is an N-D array and b is an M-D array (where M>=2), it is a sum product over the last axis of a and the second-to-last axis of b:

    Example::

    dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])

Defined in src/operator/numpy/np_dot.cc:L121

Arguments

  • a::NDArray-or-SymbolicNode: First input
  • b::NDArray-or-SymbolicNode: Second input
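The N-D rule dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]) can be written out as explicit loops. This is a hypothetical plain-Julia sketch (nd_dot is not an MXNet function) for the 3-D by 3-D case only; Julia indices are 1-based, but the logical indexing is the same as in the formula.

```julia
# Hypothetical sketch of the N-D dot rule for a 3-D `a` and a 3-D `b`:
# contract the last axis of a with the second-to-last axis of b.
function nd_dot(a::AbstractArray{T,3}, b::AbstractArray{T,3}) where {T<:Number}
    I, J, N = size(a)
    K, Nb, M = size(b)
    N == Nb || throw(DimensionMismatch("last axis of a must match second-to-last axis of b"))
    out = zeros(T, I, J, K, M)
    for i in 1:I, j in 1:J, k in 1:K, m in 1:M
        out[i, j, k, m] = sum(a[i, j, :] .* b[k, :, m])
    end
    return out
end

a = reshape(collect(1.0:8.0), 2, 2, 2)
b = reshape(collect(1.0:12.0), 2, 2, 3)
c = nd_dot(a, b)
size(c)   # (2, 2, 2, 3): a's leading axes followed by b's remaining axes
```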

source

# MXNet.mx._np_maxMethod.

_np_max(a, axis, keepdims, initial)

Defined in src/operator/numpy/np_broadcast_reduce_op_value.cc:L169

Arguments

  • a::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: Axis or axes along which the maximum is computed. The default, axis=None, computes the maximum over all of the elements of the input array. If axis is negative it counts from the last to the first axis.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimensions with size one.
  • initial::double or None, optional, default=None: Starting value for the reduction.
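A minimal plain-Julia sketch of the axis and keepdims behaviour, assuming NumPy axis k maps to Julia dims = k + 1. The helper np_max is hypothetical; Julia's maximum(a; dims) always keeps the reduced dimension with size one, so dropdims is used to emulate keepdims = false. _np_min behaves the same way with minimum in place of maximum.

```julia
# Hypothetical sketch of _np_max's axis/keepdims behaviour in plain Julia.
function np_max(a::AbstractArray; axis::Union{Int,Nothing} = nothing, keepdims::Bool = false)
    axis === nothing && return maximum(a)      # reduce over every element (keepdims ignored here)
    r = maximum(a; dims = axis + 1)            # Julia keeps the reduced dim with size 1
    return keepdims ? r : dropdims(r; dims = axis + 1)
end

a = [1 5 2;
     7 3 4]
np_max(a)                              # 7
np_max(a; axis = 0)                    # [7, 5, 4]
np_max(a; axis = 1, keepdims = true)   # 2×1 Matrix with entries 5 and 7
```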

source

# MXNet.mx._np_minMethod.

_np_min(a, axis, keepdims, initial)

Defined in src/operator/numpy/np_broadcast_reduce_op_value.cc:L198

Arguments

  • a::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: Axis or axes along which the minimum is computed. The default, axis=None, computes the minimum over all of the elements of the input array. If axis is negative it counts from the last to the first axis.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimensions with size one.
  • initial::double or None, optional, default=None: Starting value for the reduction.

source

# MXNet.mx._np_moveaxisMethod.

_np_moveaxis(a, source, destination)

Move axes of an array to new positions. Other axes remain in their original order.

Defined in src/operator/numpy/np_matrix_op.cc:L1263

Arguments

  • a::NDArray-or-SymbolicNode: Source input
  • source::Shape(tuple), required: Original positions of the axes to move. These must be unique.
  • destination::Shape(tuple), required: Destination positions for each of the original axes. These must also be unique.
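Moving one axis while keeping the others in order is a permutation of the dimensions. The plain-Julia sketch below (np_moveaxis is a hypothetical helper, not MXNet) handles a single source/destination pair with NumPy-style 0-based, possibly negative, positions; the operator itself accepts tuples.

```julia
# Hypothetical sketch: move one axis of `a` to a new position, keeping the
# other axes in their original order (0-based; negative counts from the end).
function np_moveaxis(a::AbstractArray, source::Int, destination::Int)
    n = ndims(a)
    src = mod(source, n) + 1               # 0-based (possibly negative) -> 1-based
    dst = mod(destination, n) + 1
    order = [d for d in 1:n if d != src]   # remaining axes, original order
    insert!(order, dst, src)               # put the moved axis at its destination
    return permutedims(a, order)
end

x = zeros(3, 4, 5)
size(np_moveaxis(x, 0, -1))   # (4, 5, 3)
size(np_moveaxis(x, -1, 0))   # (5, 3, 4)
```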

source

# MXNet.mx._np_prodMethod.

_np_prod(axis, dtype, keepdims, initial, a)

Arguments

  • axis::Shape or None, optional, default=None: Axis or axes along which a product is performed. The default, axis=None, computes the product of all of the elements of the input array. If axis is negative it counts from the last to the first axis.
  • dtype::{None, 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8'}, optional, default='None': The type of the returned array and of the accumulator in which the elements are multiplied. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimensions with size one.
  • initial::double or None, optional, default=None: Starting value for the product.
  • a::NDArray-or-SymbolicNode: The input

source

# MXNet.mx._np_productMethod.

_np_product(axis, dtype, keepdims, initial, a)

_np_product is an alias of _np_prod.

Arguments

  • axis::Shape or None, optional, default=None: Axis or axes along which a product is performed. The default, axis=None, computes the product of all of the elements of the input array. If axis is negative it counts from the last to the first axis.
  • dtype::{None, 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8'}, optional, default='None': The type of the returned array and of the accumulator in which the elements are multiplied. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimensions with size one.
  • initial::double or None, optional, default=None: Starting value for the product.
  • a::NDArray-or-SymbolicNode: The input

source

# MXNet.mx._np_repeatMethod.

_np_repeat(data, repeats, axis)

_np_repeat is an alias of repeat.

Repeats elements of an array. By default, repeat flattens the input array into 1-D and then repeats the elements::

x = [[ 1, 2], [ 3, 4]]
repeat(x, repeats=2) = [ 1., 1., 2., 2., 3., 3., 4., 4.]

The parameter axis specifies the axis along which to perform repeat::

repeat(x, repeats=2, axis=1) = [[ 1., 1., 2., 2.], [ 3., 3., 4., 4.]]
repeat(x, repeats=2, axis=0) = [[ 1., 2.], [ 1., 2.], [ 3., 4.], [ 3., 4.]]
repeat(x, repeats=2, axis=-1) = [[ 1., 1., 2., 2.], [ 3., 3., 4., 4.]]

Defined in src/operator/tensor/matrix_op.cc:L743

Arguments

  • data::NDArray-or-SymbolicNode: Input data array
  • repeats::int, required: The number of repetitions for each element.
  • axis::int or None, optional, default='None': The axis along which to repeat values. Negative numbers are interpreted as counting from the last axis. By default, use the flattened input array, and return a flat output array.
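The repeat examples above can be reproduced in plain Julia with Base.repeat and its inner keyword, which repeats each element consecutively along a dimension. The helper np_repeat is hypothetical and assumes NumPy axis k maps to Julia dim k + 1 and that the flattened case uses NumPy's row-major order.

```julia
# Hypothetical sketch of _np_repeat in plain Julia (not an MXNet call).
function np_repeat(a::AbstractArray, repeats::Int; axis::Union{Int,Nothing} = nothing)
    if axis === nothing
        flat = vec(permutedims(a, ndims(a):-1:1))   # flatten in C (row-major) order
        return repeat(flat; inner = repeats)
    end
    d = axis < 0 ? ndims(a) + axis + 1 : axis + 1   # NumPy axis -> Julia dim
    inner = ntuple(i -> i == d ? repeats : 1, ndims(a))
    return repeat(a; inner = inner)
end

x = [1 2;
     3 4]
np_repeat(x, 2)             # [1, 1, 2, 2, 3, 3, 4, 4]
np_repeat(x, 2; axis = 1)   # [1 1 2 2; 3 3 4 4]
np_repeat(x, 2; axis = 0)   # [1 2; 1 2; 3 4; 3 4]
np_repeat(x, 2; axis = -1)  # same as axis = 1
```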

source

# MXNet.mx._np_reshapeMethod.

_np_reshape(a, newshape, order)

Defined in src/operator/numpy/np_matrix_op.cc:L356

Arguments

  • a::NDArray-or-SymbolicNode: Array to be reshaped.
  • newshape::Shape(tuple), required: The new shape should be compatible with the original shape. If an integer, then the result will be a 1-D array of that length. One shape dimension can be -1. In this case, the value is inferred from the length of the array and remaining dimensions.
  • order::string, optional, default='C': Read the elements of a using this index order, and place the elements into the reshaped array using this index order. 'C' means to read/write the elements using C-like index order, with the last axis index changing fastest, back to the first axis index changing slowest. Note that currently only C-like order is supported.

source

# MXNet.mx._np_rollMethod.

_np_roll(data, shift, axis)

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • shift::Shape or None, optional, default=None: The number of places by which elements are shifted. If a tuple, then axis must be a tuple of the same size, and each of the given axes is shifted by the corresponding number. If an int while axis is a tuple of ints, then the same value is used for all given axes.
  • axis::Shape or None, optional, default=None: Axis or axes along which elements are shifted. By default, the array is flattened before shifting, after which the original shape is restored.
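A plain-Julia sketch of the two roll modes described above, using circshift. The helper np_roll is hypothetical, handles a single integer shift and a single non-negative axis only, and assumes the flattened mode uses NumPy's row-major order before the original shape is restored.

```julia
# Hypothetical sketch of _np_roll in plain Julia (not an MXNet call).
function np_roll(a::AbstractArray, shift::Int; axis::Union{Int,Nothing} = nothing)
    if axis === nothing
        rev = ndims(a):-1:1
        flat = vec(permutedims(a, rev))          # C-order flatten
        rolled = circshift(flat, shift)
        return permutedims(reshape(rolled, reverse(size(a))), rev)   # restore shape
    end
    shifts = ntuple(d -> d == axis + 1 ? shift : 0, ndims(a))   # shift one dim only
    return circshift(a, shifts)
end

x = [0 1 2 3 4;
     5 6 7 8 9]
np_roll(x, 1)            # [9 0 1 2 3; 4 5 6 7 8]
np_roll(x, 1; axis = 0)  # [5 6 7 8 9; 0 1 2 3 4]
np_roll(x, 1; axis = 1)  # [4 0 1 2 3; 9 5 6 7 8]
```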

source

# MXNet.mx._np_sometrueMethod.

_np_sometrue(data, axis, keepdims)

_np_sometrue is an alias of _np_any.

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • axis::Shape or None, optional, default=None: Axis or axes along which a logical OR reduction is performed. The default, axis=None, performs the reduction over all of the elements of the input array. If axis is negative it counts from the last to the first axis.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimensions with size one.

source

# MXNet.mx._np_squeezeMethod.

_np_squeeze(a, axis)

Arguments

  • a::NDArray-or-SymbolicNode: data to squeeze
  • axis::Shape or None, optional, default=None: Selects a subset of the single-dimensional entries in the shape. If an axis is selected with shape entry greater than one, an error is raised.

source

# MXNet.mx._np_sumMethod.

_np_sum(a, axis, dtype, keepdims, initial)

Defined in src/operator/numpy/np_broadcast_reduce_op_value.cc:L129

Arguments

  • a::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: Axis or axes along which a sum is performed. The default, axis=None, will sum all of the elements of the input array. If axis is negative it counts from the last to the first axis.
  • dtype::{None, 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8'}, optional, default='None': The type of the returned array and of the accumulator in which the elements are summed. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimensions with size one.
  • initial::double or None, optional, default=None: Starting value for the sum.

source

# MXNet.mx._np_traceMethod.

_np_trace(data, offset, axis1, axis2)

Computes the sum of the diagonal elements of a matrix. Input is a tensor A of dimension n >= 2.

If n=2, we sum the diagonal elements. The result has shape ().

If n>2, trace is performed separately on the matrix defined by axis1 and axis2 for all inputs (batch mode).

Examples::

// Single matrix reduction
A = [[1.0, 1.0], [1.0, 7.0]]
trace(A) = 8.0

// Batch matrix reduction
A = [[[1.0, 1.0], [1.0, 7.0]], [[3.0, 0], [0, 17.0]]]
trace(A) = [1.0, 18.0]

Defined in src/operator/numpy/np_trace_op.cc:L74

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • offset::int, optional, default='0': Offset of the diagonal from the main diagonal. Can be both positive and negative. Defaults to 0.
  • axis1::int, optional, default='0': Axis to be used as the first axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to 0.
  • axis2::int, optional, default='1': Axis to be used as the second axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to 1.
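The batch example yields [1.0, 18.0] rather than the per-matrix traces [8.0, 20.0] because the defaults axis1=0, axis2=1 take the diagonal across the first two axes and keep the last one as the batch. The hypothetical plain-Julia sketch below (np_trace3, 3-D input, offset 0 only, not an MXNet call) reproduces both results; the array A is built so that A[i, j, k] matches the NumPy-style nested literal with 1-based indices.

```julia
# NumPy-style nested literal turned into a 3-D Julia array with the same
# logical indexing: A[i, j, k] == nested[i][j][k].
nested = [[[1.0, 1.0], [1.0, 7.0]],
          [[3.0, 0.0], [0.0, 17.0]]]
A = [nested[i][j][k] for i in 1:2, j in 1:2, k in 1:2]

# Hypothetical sketch: sum along the diagonal of the (axis1, axis2) plane,
# once for every index of the remaining (batch) axis.
function np_trace3(A::AbstractArray{T,3}; axis1::Int = 0, axis2::Int = 1) where {T}
    d1, d2 = axis1 + 1, axis2 + 1              # 0-based -> Julia dims
    rest = only(setdiff(1:3, (d1, d2)))        # the batch axis
    P = permutedims(A, (d1, d2, rest))         # traced axes first, batch axis last
    n = min(size(P, 1), size(P, 2))
    return [sum(P[d, d, k] for d in 1:n) for k in 1:size(P, 3)]
end

np_trace3(A)                         # [1.0, 18.0]  (defaults axis1=0, axis2=1)
np_trace3(A; axis1 = 1, axis2 = 2)   # [8.0, 20.0]  (trace of each 2×2 matrix)
```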

source

# MXNet.mx._np_transposeMethod.

_np_transpose(a, axes)

Arguments

  • a::NDArray-or-SymbolicNode: Source input
  • axes::Shape(tuple), optional, default=None: By default, reverse the dimensions, otherwise permute the axes according to the values given.

source

# MXNet.mx._npi_CustomMethod.

_npi_Custom(data, op_type)

_npi_Custom is an alias of Custom.

Apply a custom operator implemented in a frontend language (like Python).

Custom operators should override required methods like forward and backward. The custom operator must be registered before it can be used. Please check the tutorial here: https://mxnet.incubator.apache.org/api/faq/new_op

Defined in src/operator/custom/custom.cc:L546

Arguments

  • data::NDArray-or-SymbolicNode[]: Input data for the custom operator.
  • op_type::string: Name of the custom operator. This is the name that is passed to mx.operator.register to register the operator.

source

# MXNet.mx._npi_absMethod.

_npi_abs(x)

_npi_abs is an alias of _npi_absolute.

Returns the element-wise absolute value of the input.

Example::

absolute([-2, 0, 3]) = [2, 0, 3]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L139

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_absoluteMethod.

_npi_absolute(x)

Returns the element-wise absolute value of the input.

Example::

absolute([-2, 0, 3]) = [2, 0, 3]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L139

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_addMethod.

_npi_add(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_add_scalarMethod.

_npi_add_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_arangeMethod.

_npi_arange(start, stop, step, repeat, infer_range, ctx, dtype)

Arguments

  • start::double, required: Start of interval. The interval includes this value. The default start value is 0.
  • stop::double or None, optional, default=None: End of interval. The interval does not include this value, except in some cases where step is not an integer and floating point round-off affects the length of out.
  • step::double, optional, default=1: Spacing between values.
  • repeat::int, optional, default='1': The number of times each element is repeated. E.g. with repeat=3, each element a is repeated three times: a, a, a.
  • infer_range::boolean, optional, default=0: When set to True, infer the stop position from the start, step, repeat, and output tensor size.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Target data type.
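A plain-Julia sketch of the half-open interval [start, stop) and the repeat parameter described above. The helper np_arange is hypothetical; it ignores infer_range, ctx and dtype, and returns the element type of its arguments rather than the float32 default.

```julia
# Hypothetical sketch: values start, start+step, ... strictly below stop,
# with each value repeated `repeat` times in a row.
function np_arange(start::Real, stop::Real, step::Real = 1; repeat::Int = 1)
    n = max(0, ceil(Int, (stop - start) / step))   # count of values in [start, stop)
    vals = [start + i * step for i in 0:n-1]
    return Base.repeat(vals; inner = repeat)       # `repeat` here is the keyword value
end

np_arange(0, 5)              # [0, 1, 2, 3, 4]
np_arange(0, 3; repeat = 3)  # [0, 0, 0, 1, 1, 1, 2, 2, 2]
np_arange(1.0, 2.0, 0.25)    # [1.0, 1.25, 1.5, 1.75]
```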

source

# MXNet.mx._npi_arccosMethod.

_npi_arccos(x)

Returns the element-wise inverse cosine of the input array. The input should be in the range [-1, 1]. The output is in the closed interval [0, \pi].

Example::

arccos([-1, -.707, 0, .707, 1]) = [\pi, 3\pi/4, \pi/2, \pi/4, 0]

The storage type of arccos output is always dense.

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L355

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_arccoshMethod.

_npi_arccosh(x)

Returns the inverse hyperbolic cosine of the input array, computed element-wise.

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L417

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_arcsinMethod.

_npi_arcsin(x)

Returns the element-wise inverse sine of the input array.

Example::

arcsin([-1, -.707, 0, .707, 1]) = [-\pi/2, -\pi/4, 0, \pi/4, \pi/2]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L344

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_arcsinhMethod.

_npi_arcsinh(x)

Returns the inverse hyperbolic sine of the input array, computed element-wise.

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L410

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_arctanMethod.

_npi_arctan(x)

Returns the element-wise inverse tangent of the input array.

Example::

arctan([-1, 0, 1]) = [-\pi/4, 0, \pi/4]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L363

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_arctan2Method.

_npi_arctan2(x1, x2)

Arguments

  • x1::NDArray-or-SymbolicNode: The input array
  • x2::NDArray-or-SymbolicNode: The input array

source

# MXNet.mx._npi_arctan2_scalarMethod.

_npi_arctan2_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_arctanhMethod.

_npi_arctanh(x)

Returns the inverse hyperbolic tangent of the input array, computed element-wise.

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L424

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_argmaxMethod.

_npi_argmax(data, axis, keepdims)

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • axis::int or None, optional, default='None': The axis along which to perform the reduction. Negative values mean indexing from right to left. Requires axis to be set as an int, because global reduction is not supported yet.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axis is left in the result as a dimension with size one.

source

# MXNet.mx._npi_argminMethod.

_npi_argmin(data, axis, keepdims)

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • axis::int or None, optional, default='None': The axis along which to perform the reduction. Negative values mean indexing from right to left. Requires axis to be set as an int, because global reduction is not supported yet.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axis is left in the result as a dimension with size one.

source

# MXNet.mx._npi_argsortMethod.

_npi_argsort(data, axis, is_ascend, dtype)

_npi_argsort is an alias of argsort.

Returns the indices that would sort an input array along the given axis.

This function sorts along the given axis and returns an array of indices, with the same shape as the input array, that index the data in sorted order.

Examples::

x = [[ 0.3, 0.2, 0.4], [ 0.1, 0.3, 0.2]]

// sort along axis -1
argsort(x) = [[ 1., 0., 2.], [ 0., 2., 1.]]

// sort along axis 0
argsort(x, axis=0) = [[ 1., 0., 1.], [ 0., 1., 0.]]

// flatten and then sort
argsort(x, axis=None) = [ 3., 1., 5., 0., 4., 2.]

Defined in src/operator/tensor/ordering_op.cc:L184

Arguments

  • data::NDArray-or-SymbolicNode: The input array
  • axis::int or None, optional, default='-1': Axis along which to sort the input tensor. If None, the flattened array is used. Default is -1.
  • is_ascend::boolean, optional, default=1: Whether to sort in ascending or descending order.
  • dtype::{'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'}, optional, default='float32': DType of the output indices. An error will be raised if the selected data type cannot precisely represent the indices.
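The argsort examples above can be checked in plain Julia with sortperm, which is stable, so ties keep their original order as in the flattened example. The helper argsort0 is hypothetical; it subtracts 1 to produce MXNet-style 0-based indices (as integers rather than the float32 default), and the flattened case reads the matrix in row-major order.

```julia
x = [0.3 0.2 0.4;
     0.1 0.3 0.2]

# sortperm returns 1-based positions; subtract 1 for 0-based indices
argsort0(v) = sortperm(v) .- 1

mapslices(argsort0, x; dims = 2)   # along axis -1 (each row):    [1 0 2; 0 2 1]
mapslices(argsort0, x; dims = 1)   # along axis 0 (each column):  [1 0 1; 0 1 0]
argsort0(vec(permutedims(x)))      # axis=None (C-order flatten): [3, 1, 5, 0, 4, 2]
```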

source

# MXNet.mx._npi_aroundMethod.

_npi_around(x, decimals)

Arguments

  • x::NDArray-or-SymbolicNode: Input ndarray
  • decimals::int, optional, default='0': Number of decimal places to round to.

source

# MXNet.mx._npi_averageMethod.

_npi_average(a, weights, axis, returned, weighted)

Arguments

  • a::NDArray-or-SymbolicNode: The input
  • weights::NDArray-or-SymbolicNode: The weights to calculate average
  • axis::Shape or None, optional, default=None: Axis or axes along which an average is performed. The default, axis=None, will average all of the elements of the input array. If axis is negative it counts from the last to the first axis.
  • returned::boolean, optional, default=0: If True, the tuple (average, sum_of_weights) is returned, otherwise only the average is returned. If weights=None, sum_of_weights is equivalent to the number of elements over which the average is taken.
  • weighted::boolean, optional, default=1: Auxiliary flag to deal with none weights.
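The weighted average is sum(a .* w) / sum(w), and returned adds the sum of weights (or the element count when no weights are given). This hypothetical plain-Julia sketch (np_average, whole-array case axis=None only, not an MXNet call) shows both outputs.

```julia
# Hypothetical sketch of the whole-array (axis=None) weighted average.
function np_average(a::AbstractArray, weights::Union{AbstractArray,Nothing} = nothing;
                    returned::Bool = false)
    if weights === nothing
        sw = length(a)                     # sum_of_weights = number of elements
        avg = sum(a) / sw
    else
        size(weights) == size(a) || throw(DimensionMismatch("weights must match a"))
        sw = sum(weights)
        avg = sum(a .* weights) / sw       # Σ a·w / Σ w
    end
    return returned ? (avg, sw) : avg
end

np_average([1.0, 2.0, 3.0, 4.0])                                         # 2.5
np_average([1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0])                   # 2.0
np_average([1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]; returned = true)  # (2.0, 10.0)
```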

source

# MXNet.mx._npi_backward_ediff1dMethod.

_npi_backward_ediff1d()

Arguments

source

# MXNet.mx._npi_backward_nan_to_numMethod.

_npi_backward_nan_to_num()

Arguments

source

# MXNet.mx._npi_backward_polyvalMethod.

_npi_backward_polyval()

Arguments

source

# MXNet.mx._npi_bernoulliMethod.

_npi_bernoulli(input1, prob, logit, size, ctx, dtype, is_logit)

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • prob::float or None, required:
  • logit::float or None, required:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bool', 'float16', 'float32', 'float64', 'int32', 'uint8'}, optional, default='float32': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
  • is_logit::boolean, required:

source

# MXNet.mx._npi_bincountMethod.

_npi_bincount(data, weights, minlength, has_weights)

Arguments

  • data::NDArray-or-SymbolicNode: Data
  • weights::NDArray-or-SymbolicNode: Weights
  • minlength::int, optional, default='0': A minimum number of bins for the output array. If minlength is specified, there will be at least this number of bins in the output array.
  • has_weights::boolean, optional, default=0: Determine whether Bincount has weights.
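Bin i counts how often the value i occurs (or sums the matching weights), and minlength pads the result with empty bins. This is a hypothetical plain-Julia sketch (np_bincount is not an MXNet function) assuming NumPy's bincount semantics for non-negative integer input.

```julia
# Hypothetical sketch of bincount with optional weights and minlength.
function np_bincount(x::AbstractVector{<:Integer}; weights = nothing, minlength::Int = 0)
    any(<(0), x) && throw(ArgumentError("input must be non-negative"))
    n = max(isempty(x) ? 0 : maximum(x) + 1, minlength)   # number of bins
    out = zeros(weights === nothing ? Int : eltype(weights), n)
    for (i, v) in enumerate(x)
        out[v + 1] += weights === nothing ? 1 : weights[i]   # value v -> bin v (1-based slot v+1)
    end
    return out
end

np_bincount([0, 1, 1, 3])                                  # [1, 2, 0, 1]
np_bincount([0, 1, 1, 3]; minlength = 6)                   # [1, 2, 0, 1, 0, 0]
np_bincount([0, 1, 1, 3]; weights = [0.5, 1.0, 1.0, 2.0])  # [0.5, 2.0, 0.0, 2.0]
```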

source

# MXNet.mx._npi_bitwise_andMethod.

_npi_bitwise_and(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_bitwise_and_scalarMethod.

_npi_bitwise_and_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_bitwise_notMethod.

_npi_bitwise_not(x)

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_bitwise_orMethod.

_npi_bitwise_or(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_bitwise_or_scalarMethod.

_npi_bitwise_or_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_bitwise_xorMethod.

_npi_bitwise_xor(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_bitwise_xor_scalarMethod.

_npi_bitwise_xor_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_blackmanMethod.

_npi_blackman(M, ctx, dtype)

Return the Blackman window. The Blackman window is a taper formed by using a weighted cosine.

Arguments

  • M::, optional, default=None: Number of points in the output window. If zero or less, an empty array is returned.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Data-type of the returned array.
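The text above only says "weighted cosine". A sketch, assuming _npi_blackman matches NumPy's documented definition w(n) = 0.42 - 0.5 cos(2πn/(M-1)) + 0.08 cos(4πn/(M-1)) for n = 0, ..., M-1, in plain Julia (blackman here is a hypothetical helper, not an MXNet call):

```julia
# Hypothetical sketch of the Blackman window, assuming NumPy's coefficients.
function blackman(M::Integer)
    M < 1 && return Float64[]     # zero or fewer points: empty window
    M == 1 && return [1.0]        # avoid dividing by M - 1 == 0
    n = 0:M-1
    return @. 0.42 - 0.5 * cos(2π * n / (M - 1)) + 0.08 * cos(4π * n / (M - 1))
end

blackman(5)   # ≈ [0.0, 0.34, 1.0, 0.34, 0.0]: zero at the edges, peak 1.0 in the centre
```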

source

# MXNet.mx._npi_boolean_maskMethod.

_npi_boolean_mask(data, index, axis)

_npi_boolean_mask is an alias of _contrib_boolean_mask.

Given an n-d NDArray data and a 1-d NDArray index, the operator produces an n-d NDArray out whose shape cannot be determined in advance: it contains the rows of data for which the corresponding element in index is non-zero.

Example::

data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
index = mx.nd.array([0, 1, 0])
out = mx.nd.contrib.boolean_mask(data, index)
out

[[4. 5. 6.]]

Defined in src/operator/contrib/boolean_mask.cc:L195

Arguments

  • data::NDArray-or-SymbolicNode: Data
  • index::NDArray-or-SymbolicNode: Mask
  • axis::int, optional, default='0': An integer that represents the axis in NDArray to mask from.
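The example above corresponds to logical indexing along the first axis. This plain-Julia sketch (not an MXNet call) shows why the output shape cannot be predetermined: the number of rows kept depends on the values in the mask.

```julia
# Keep the rows of `data` whose mask entry is non-zero (masking along axis 0).
data = [1 2 3;
        4 5 6;
        7 8 9]
index = [0, 1, 0]
out = data[index .!= 0, :]   # 1×3 Matrix: [4 5 6]
size(out)                    # (1, 3); a mask with two non-zero entries would give (2, 3)
```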

source

# MXNet.mx._npi_boolean_mask_assign_scalarMethod.

_npi_boolean_mask_assign_scalar(data, mask, value, start_axis)

Scalar version of boolean mask assignment.

Defined in src/operator/numpy/np_boolean_mask_assign.cc:L284

Arguments

  • data::NDArray-or-SymbolicNode: input
  • mask::NDArray-or-SymbolicNode: mask
  • value::float: value to be assigned to masked positions
  • start_axis::int: starting axis of boolean mask

source

# MXNet.mx._npi_boolean_mask_assign_tensorMethod.

_npi_boolean_mask_assign_tensor(data, mask, value, start_axis)

Tensor version of boolean mask assignment.

Defined in src/operator/numpy/np_boolean_mask_assign.cc:L309

Arguments

  • data::NDArray-or-SymbolicNode: input
  • mask::NDArray-or-SymbolicNode: mask
  • value::NDArray-or-SymbolicNode: assignment
  • start_axis::int: starting axis of boolean mask

source

# MXNet.mx._npi_broadcast_toMethod.

_npi_broadcast_to(array, shape)

Arguments

  • array::NDArray-or-SymbolicNode: The input
  • shape::Shape(tuple), optional, default=[]: The shape of the desired array. A dimension can be set to zero to keep it the same as in the original. E.g. A = broadcast_to(B, shape=(10, 0, 0)) has the same meaning as A = broadcast_axis(B, axis=0, size=10).

source

# MXNet.mx._npi_castMethod.

_npi_cast(data, dtype)

_npi_cast is an alias of Cast.

Casts all elements of the input to a new type.

Note: Cast is deprecated. Use cast instead.

Example::

cast([0.9, 1.3], dtype='int32') = [0, 1]
cast([1e20, 11.1], dtype='float16') = [inf, 11.09375]
cast([300, 11.1, 10.9, -1, -3], dtype='uint8') = [44, 11, 10, 255, 253]

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L664

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • dtype::{'bfloat16', 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, required: Output data type.
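A plain-Julia sketch reproducing the int32 and uint8 rows of the example: truncation toward zero, then wrap-around modulo 256 for uint8. Julia's UInt8(300.0) would throw an InexactError, so the hypothetical helpers below truncate first and wrap explicitly with % UInt8. Out-of-range float-to-integer conversion is not well defined in C++, so treat the wrap-around as what the example reports rather than a guaranteed behaviour of the operator.

```julia
# Hypothetical helpers mirroring the example rows (not MXNet calls).
to_int32(x) = [trunc(Int32, v) for v in x]                 # truncate toward zero
to_uint8_wrapping(x) = [trunc(Int, v) % UInt8 for v in x]  # truncate, then wrap mod 256

to_int32([0.9, 1.3])                           # Int32[0, 1]
to_uint8_wrapping([300, 11.1, 10.9, -1, -3])   # UInt8[44, 11, 10, 255, 253]
```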

source

# MXNet.mx._npi_cbrtMethod.

_npi_cbrt(x)

Return the cube-root of an array, element-wise.

Example::

cbrt([1, 8, -125]) = [1, 2, -5]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L232

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_ceilMethod.

_npi_ceil(x)

Return the ceiling of the input, element-wise. The ceil of the scalar x is the smallest integer i, such that i >= x.

Example::

ceil([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) = [-1., -1., -0., 1., 2., 2., 2.]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L165

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_choiceMethod.

_npi_choice(input1, input2, a, size, ctx, replace, weighted)

Random choice.

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • input2::NDArray-or-SymbolicNode: Source input
  • a::, required:
  • size::, required:
  • ctx::string, optional, default='cpu':
  • replace::boolean, optional, default=1:
  • weighted::boolean, optional, default=0:

source

# MXNet.mx._npi_choleskyMethod.

_npi_cholesky(A)

Defined in src/operator/numpy/linalg/np_potrf.cc:L46

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of input matrices to be decomposed

source

# MXNet.mx._npi_clipMethod.

_npi_clip(data, a_min, a_max)

_npi_clip is an alias of clip.

Clips (limits) the values in an array. Given an interval, values outside the interval are clipped to the interval edges. Clipping x between a_min and a_max would be::

clip(x, a_min, a_max) = \max(\min(x, a_max), a_min)

Example::

x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
clip(x, 1, 8) = [ 1., 1., 2., 3., 4., 5., 6., 7., 8., 8.]

The storage type of clip output depends on the storage types of the inputs and on the a_min, a_max parameter values:

  • clip(default) = default
  • clip(row_sparse, a_min <= 0, a_max >= 0) = row_sparse
  • clip(csr, a_min <= 0, a_max >= 0) = csr
  • clip(row_sparse, a_min < 0, a_max < 0) = default
  • clip(row_sparse, a_min > 0, a_max > 0) = default
  • clip(csr, a_min < 0, a_max < 0) = csr
  • clip(csr, a_min > 0, a_max > 0) = csr

Defined in src/operator/tensor/matrix_op.cc:L676

Arguments

  • data::NDArray-or-SymbolicNode: Input array.
  • a_min::float, required: Minimum value
  • a_max::float, required: Maximum value
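In plain Julia the same element-wise clipping is clamp broadcast over the array. This sketch (not an MXNet call) reproduces the dense example above and checks it against the max/min formula.

```julia
x = collect(0.0:9.0)
clamp.(x, 1, 8)                          # [1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 8.0]
max.(min.(x, 8), 1) == clamp.(x, 1, 8)   # true: the clip formula, applied element-wise
```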

source

# MXNet.mx._npi_column_stackMethod.

_npi_column_stack(data, num_args)

Note: _npi_column_stack takes a variable number of positional inputs. So instead of calling it as _npi_column_stack([x, y, z], num_args=3), one should call it as _npi_column_stack(x, y, z), and num_args will be determined automatically.

Defined in src/operator/numpy/np_matrix_op.cc:L865

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to column_stack
  • num_args::int, required: Number of inputs to be column stacked

source

# MXNet.mx._npi_concatenateMethod.

_npi_concatenate(data, num_args, dim)

Note: _npi_concatenate takes a variable number of positional inputs. So instead of calling it as _npi_concatenate([x, y, z], num_args=3), one should call it as _npi_concatenate(x, y, z), and num_args will be determined automatically.

Joins a sequence of arrays along an existing axis.

Defined in src/operator/numpy/np_matrix_op.cc:L677

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to concatenate
  • num_args::int, required: Number of inputs to be concatenated.
  • dim::int, optional, default='1': The dimension along which to concatenate.

source

# MXNet.mx._npi_copysignMethod.

_npi_copysign(lhs, rhs)

Defined in src/operator/numpy/np_elemwise_broadcast_op_extended.cc:L47

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_copysign_scalarMethod.

_npi_copysign_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_copytoMethod.

_npi_copyto(data)

_npi_copyto is an alias of _copyto.

Arguments

  • data::NDArray: input data

source

# MXNet.mx._npi_cosMethod.

_npi_cos(x)

Computes the element-wise cosine of the input array.

Example::

cos([0, \pi/4, \pi/2]) = [1, 0.707, 0]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L328

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_coshMethod.

_npi_cosh(x)

Returns the hyperbolic cosine of the input array, computed element-wise:

cosh(x) = 0.5\times(exp(x) + exp(-x))

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L395

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_cvimdecodeMethod.

_npi_cvimdecode(buf, flag, to_rgb)

_npi_cvimdecode is an alias of _cvimdecode.

Decodes an image with OpenCV. Note: the image is returned in RGB by default, instead of OpenCV's default BGR.

Arguments

  • buf::NDArray: Buffer containing binary encoded image
  • flag::int, optional, default='1': Convert decoded image to grayscale (0) or color (1).
  • to_rgb::boolean, optional, default=1: Whether to convert decoded image to mxnet's default RGB format (instead of opencv's default BGR).

source

# MXNet.mx._npi_cvimreadMethod.

_npi_cvimread(filename, flag, to_rgb)

_npi_cvimread is an alias of _cvimread.

Reads and decodes an image with OpenCV. Note: the image is returned in RGB by default, instead of OpenCV's default BGR.

Arguments

  • filename::string, required: Name of the image file to be loaded.
  • flag::int, optional, default='1': Convert decoded image to grayscale (0) or color (1).
  • to_rgb::boolean, optional, default=1: Whether to convert decoded image to mxnet's default RGB format (instead of opencv's default BGR).

source

# MXNet.mx._npi_cvimresizeMethod.

_npi_cvimresize(src, w, h, interp)

_npi_cvimresize is an alias of _cvimresize.

Resize image with OpenCV.

Arguments

  • src::NDArray: source image
  • w::int, required: Width of resized image.
  • h::int, required: Height of resized image.
  • interp::int, optional, default='1': Interpolation method (default=cv2.INTER_LINEAR).

source

# MXNet.mx._npi_degreesMethod.

_npi_degrees(x)

Converts each element of the input array from radians to degrees.

Example::

degrees([0, \pi/2, \pi, 3\pi/2, 2\pi]) = [0, 90, 180, 270, 360]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L371

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_deleteMethod.

_npi_delete(arr, obj, start, stop, step, int_ind, axis)

Delete values along the given axis before the given indices.

Defined in src/operator/numpy/np_delete_op.cc:L71

Arguments

  • arr::NDArray-or-SymbolicNode: Input ndarray
  • obj::NDArray-or-SymbolicNode: Input ndarray
  • start::int or None, optional, default='None': If 'obj' is a slice, 'start' is one of its arguments.
  • stop::int or None, optional, default='None': If 'obj' is a slice, 'stop' is one of its arguments.
  • step::int or None, optional, default='None': If 'obj' is a slice, 'step' is one of its arguments.
  • int_ind::int or None, optional, default='None': If 'obj' is an int, 'int_ind' is the index of the element to delete.
  • axis::int or None, optional, default='None': Axis along which to delete values.

source

# MXNet.mx._npi_detMethod.

_npi_det(A)

_npi_det is an alias of _linalg_det.

Compute the determinant of a matrix. Input is a tensor A of dimension n >= 2.

If n=2, A is a square matrix. We compute:

out = det(A)

If n>2, det is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

.. note:: No gradient is propagated when A is non-invertible (equivalently, det(A) = 0), because exact zeros are rarely hit in floating-point computation and Jacobi's formula for the determinant gradient is not computationally efficient when A is non-invertible.

Examples::

Single matrix determinant A = [[1., 4.], [2., 3.]] det(A) = [-5.]

Batch matrix determinant A = [[[1., 4.], [2., 3.]], [[2., 3.], [1., 4.]]] det(A) = [-5., 5.]
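
The two examples can be reproduced in plain Julia with the `LinearAlgebra` standard library (this sketch does not call MXNet). For batch mode it assumes the matrices are stacked along the last Julia dimension, which is where MXNet's leading row-major batch axis would land in Julia's column-major layout; treat that mapping as an assumption.

```julia
using LinearAlgebra

# Single matrix, matching the first example.
A = [1.0 4.0; 2.0 3.0]
det(A)                                              # ≈ -5.0

# Batch mode: one determinant per 2×2 slice along the third Julia dimension.
B = cat([1.0 4.0; 2.0 3.0], [2.0 3.0; 1.0 4.0]; dims = 3)
batch_det = [det(B[:, :, k]) for k in axes(B, 3)]   # ≈ [-5.0, 5.0]
```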

Defined in src/operator/tensor/la_op.cc:L974

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrix

source

# MXNet.mx._npi_diag_indices_fromMethod.

_npi_diag_indices_from(data)

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray

source

# MXNet.mx._npi_diffMethod.

_npi_diff(a, n, axis)

Arguments

  • a::NDArray-or-SymbolicNode: Input ndarray
  • n::int, optional, default='1': The number of times values are differenced. If zero, the input is returned as-is.
  • axis::int, optional, default='-1': Axis along which the difference is taken. Defaults to the last axis (-1).
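
For a 1-D input, the n-th order difference is the first difference `out[i] = a[i+1] - a[i]` applied n times. A plain-Julia sketch (not an MXNet call; `diff_sketch` is hypothetical, and Base's `diff` computes the first-order case):

```julia
# Hypothetical helper: n-th order discrete difference of a vector.
# n = 0 returns a copy of the input, as described for the `n` argument.
function diff_sketch(a::AbstractVector, n::Integer = 1)
    n >= 0 || throw(ArgumentError("n must be non-negative"))
    out = collect(a)
    for _ in 1:n
        out = out[2:end] .- out[1:end-1]
    end
    return out
end

diff_sketch([1, 2, 4, 7, 0])      # → [1, 2, 3, -7]
diff_sketch([1, 2, 4, 7, 0], 2)   # → [1, 1, -10]
```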

source

# MXNet.mx._npi_dsplitMethod.

_npi_dsplit(data, indices, axis, squeeze_axis, sections)

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • indices::Shape(tuple), required: Indices of splits. The elements denote the boundaries at which the split is performed along the axis.
  • axis::int, optional, default='1': Axis along which to split.
  • squeeze_axis::boolean, optional, default=0: If true, removes the axis with length 1 from the shapes of the output arrays. Note that setting squeeze_axis to $true$ removes the length-1 axis only along the axis on which the array is split. Also, squeeze_axis can be set to $true$ only if $input.shape[axis] == num_outputs$.
  • sections::int, optional, default='0': Number of equal sections to split into. Defaults to 0, which means the array is split by indices.

source

# MXNet.mx._npi_dstackMethod.

_npi_dstack(data, num_args, dim)

Note: _npi_dstack takes a variable number of positional inputs. So instead of calling it as _npi_dstack([x, y, z], num_args=3), one should call it as _npi_dstack(x, y, z), and num_args will be determined automatically.

Stack tensors in sequence depthwise (along the third dimension).

Defined in src/operator/numpy/np_matrix_op.cc:L1080

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to concatenate
  • num_args::int, required: Number of inputs to be concatenated.
  • dim::int, optional, default='1': The dimension along which to concatenate.

source

# MXNet.mx._npi_ediff1dMethod.

_npi_ediff1d(input1, input2, input3, to_begin_arr_given, to_end_arr_given, to_begin_scalar, to_end_scalar)

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • input2::NDArray-or-SymbolicNode: Source input
  • input3::NDArray-or-SymbolicNode: Source input
  • to_begin_arr_given::boolean, optional, default=0: To determine whether the to_begin parameter is an array.
  • to_end_arr_given::boolean, optional, default=0: To determine whether the to_end parameter is an array.
  • to_begin_scalar::double or None, optional, default=None: If to_begin is a scalar, this parameter holds its value.
  • to_end_scalar::double or None, optional, default=None: If to_end is a scalar, this parameter holds its value.
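
The parameter names suggest NumPy's `ediff1d` semantics: differences between consecutive elements, with optional values prepended (`to_begin`) and appended (`to_end`). How `input2`/`input3` and the `*_scalar` parameters map onto those two arguments is an assumption here. A plain-Julia sketch for a 1-D input (the helper `ediff1d_sketch` is hypothetical):

```julia
# Hypothetical helper mirroring numpy.ediff1d for a vector input.
function ediff1d_sketch(x::AbstractVector; to_begin = eltype(x)[], to_end = eltype(x)[])
    d = x[2:end] .- x[1:end-1]          # consecutive differences
    return vcat(to_begin, d, to_end)    # scalars or vectors are both accepted by vcat
end

ediff1d_sketch([1, 2, 4, 7, 0])                                      # → [1, 2, 3, -7]
ediff1d_sketch([1, 2, 4, 7, 0]; to_begin = -99, to_end = [88, 99])   # → [-99, 1, 2, 3, -7, 88, 99]
```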

source

# MXNet.mx._npi_eigMethod.

_npi_eig(A)

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrix

source

# MXNet.mx._npi_eighMethod.

_npi_eigh(A, UPLO)

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of real matrices
  • UPLO::, optional, default=L: Specifies whether the calculation is done with the lower or upper triangular part.

source

# MXNet.mx._npi_eigvalsMethod.

_npi_eigvals(A)

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrix

source

# MXNet.mx._npi_eigvalshMethod.

_npi_eigvalsh(A, UPLO)

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrix
  • UPLO::, optional, default=L: Specifies whether the calculation is done with the lower or upper triangular part.

source

# MXNet.mx._npi_einsumMethod.

_npi_einsum(data, num_args, subscripts, optimize)

Note: _npi_einsum takes a variable number of positional inputs. So instead of calling it as _npi_einsum([x, y, z], num_args=3), one should call it as _npi_einsum(x, y, z), and num_args will be determined automatically.

Defined in src/operator/numpy/np_einsum_op.cc:L314

Arguments

  • data::NDArray-or-SymbolicNode[]: List of einsum operands
  • num_args::int, required: Number of input arrays.
  • subscripts::string, optional, default='': Specifies the subscripts for summation as comma separated list of subscript labels. An implicit (classical Einstein summation) calculation is performed unless the explicit indicator '->' is included as well as subscript labels of the precise output form.
  • optimize::int, optional, default='0':

source

# MXNet.mx._npi_equalMethod.

_npi_equal(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_equal_scalarMethod.

_npi_equal_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: First input to the function
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_expMethod.

_npi_exp(x)

Calculate the exponential of all elements in the input array. Example:: exp([0, 1, 2]) = [1., 2.71828175, 7.38905621]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L240

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_expand_dimsMethod.

_npi_expand_dims(data, axis)

_npi_expand_dims is an alias of expand_dims.

Inserts a new axis of size 1 into the array shape. For example, given $x$ with shape $(2,3,4)$, $expand_dims(x, axis=1)$ returns a new array with shape $(2,1,3,4)$.
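
The shape arithmetic can be sketched in plain Julia with `reshape` (not an MXNet call; `expand_dims_sketch` is hypothetical). It works on the shape tuple exactly as written in the description, takes a 0-based `axis`, and lets negative values count from the end of the result; inserting a length-1 axis never reorders elements.

```julia
# Hypothetical helper: insert a length-1 axis at 0-based position `axis`.
function expand_dims_sketch(x::AbstractArray, axis::Integer)
    nd = ndims(x)
    -nd <= axis <= nd || throw(ArgumentError("axis must lie in [-ndim, ndim]"))
    pos = axis < 0 ? axis + nd + 1 : axis   # number of existing axes before the new one
    s = size(x)
    return reshape(x, (s[1:pos]..., 1, s[pos+1:end]...))
end

size(expand_dims_sketch(zeros(2, 3, 4), 1))    # → (2, 1, 3, 4)
size(expand_dims_sketch(zeros(2, 3, 4), -1))   # → (2, 3, 4, 1)
```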

Defined in src/operator/tensor/matrix_op.cc:L394

Arguments

  • data::NDArray-or-SymbolicNode: Source input
  • axis::int, required: Position where the new axis is to be inserted. If the input NDArray has ndim dimensions, the valid range for the inserted axis is [-ndim, ndim].

source

# MXNet.mx._npi_expm1Method.

_npi_expm1(x)

Calculate $exp(x) - 1$ for all elements in the array.
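
The reason for a dedicated operator is numerical: for tiny x, computing exp(x) and then subtracting 1 cancels most significant digits. This plain-Julia comparison (using Base's `expm1`, not MXNet) shows the effect.

```julia
x = 1e-10
naive = exp(x) - 1      # suffers from cancellation
accurate = expm1(x)     # ≈ 1.00000000005e-10, full relative precision

rel_err(v) = abs(v - accurate) / accurate
rel_err(naive)          # many orders of magnitude above machine epsilon
```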

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L287

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_exponentialMethod.

_npi_exponential(input1, scale, size, ctx)

Draw samples from an exponential distribution (NumPy-compatible behavior).
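
For comparison, the same distribution can be sampled in plain Julia (not through MXNet): `Random.randexp` draws from the standard exponential distribution, and multiplying by `scale` gives mean `scale`. The seed and shape below are arbitrary choices for illustration.

```julia
using Random

rng = MersenneTwister(42)
scale = 2.0
samples = scale .* randexp(rng, 3, 4)   # 3×4 array of draws, analogous to size=(3, 4)
```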

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • scale::float or None, optional, default=1:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._npi_eyeMethod.

_npi_eye(N, M, k, ctx, dtype)

Return a 2-D array with ones on the diagonal and zeros elsewhere.
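
A plain-Julia sketch of the documented `N`, `M`, and `k` semantics (not an MXNet call; `eye_sketch` is hypothetical, and `T` stands in for `dtype`, defaulting to `Float32` as documented). Element `(i, i + k)` is set to one, so positive `k` selects an upper diagonal.

```julia
# Hypothetical helper: N×M matrix with ones on diagonal k, zeros elsewhere.
function eye_sketch(N::Integer, M::Integer = N; k::Integer = 0, T::Type{<:Number} = Float32)
    out = zeros(T, N, M)
    for i in 1:N
        j = i + k                          # column of the k-th diagonal in row i
        1 <= j <= M && (out[i, j] = one(T))
    end
    return out
end

eye_sketch(3; k = 1)   # → Float32[0 1 0; 0 0 1; 0 0 0]
```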

Arguments

  • N::long, required: Number of rows in the output.
  • M::, optional, default=None: Number of columns in the output. If None, defaults to N.
  • k::long, optional, default=0: Index of the diagonal. 0 (the default) refers to the main diagonal, a positive value refers to an upper diagonal, and a negative value to a lower diagonal.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Data-type of the returned array.

source

# MXNet.mx._npi_fixMethod.

_npi_fix(x)

Round each element to the nearest integer towards zero. The rounded values are returned as floats. Example:: fix([-2.1, -1.9, 1.9, 2.1]) = [-2., -1., 1., 2.]
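
In plain Julia the same round-toward-zero behaviour is Base's `trunc`; contrasting it with `floor` on the documented input shows that the two differ only for negative non-integers. This uses Base functions, not MXNet.

```julia
x = [-2.1, -1.9, 1.9, 2.1]
trunc.(x)   # rounds toward zero → [-2.0, -1.0, 1.0, 2.0]
floor.(x)   # rounds toward -Inf → [-3.0, -2.0, 1.0, 2.0]
```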

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L208

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_flipMethod.

_npi_flip(data, axis)

Arguments

  • data::NDArray-or-SymbolicNode: Input data array
  • axis::Shape(tuple), required: The axis or axes along which to flip elements.

source

# MXNet.mx._npi_floorMethod.

_npi_floor(x)

Return the floor of the input, element-wise. The floor of the scalar x is the largest integer i, such that i <= x. Example:: floor([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) = [-2., -2., -1., 0., 1., 1., 2.]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L174

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_fullMethod.

_npi_full(shape, ctx, dtype, value)

_npi_full is an alias of _full.

Fill target with a scalar value.

Arguments

  • shape::Shape(tuple), optional, default=None: The shape of the output
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Target data type.
  • value::double, required: Value with which to fill newly created tensor

source

# MXNet.mx._npi_full_likeMethod.

_npi_full_like(a, fill_value, ctx, dtype)

Arguments

  • a::NDArray-or-SymbolicNode: The shape and data-type of a define these same attributes of the returned array.
  • fill_value::double, required: Value with which to fill newly created tensor
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{None, 'bfloat16', 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='None': Target data type.

source

# MXNet.mx._npi_gammaMethod.

_npi_gamma(input1, input2, shape, scale, size, ctx, dtype)

Draw samples from a gamma distribution (NumPy-compatible behavior).

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • input2::NDArray-or-SymbolicNode: Source input
  • shape::float or None, required:
  • scale::float or None, required:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'float16', 'float32', 'float64'}, optional, default='float32': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx._npi_gather_ndMethod.

_npi_gather_nd(data, indices)

_npi_gather_nd is an alias of gather_nd.

Gather elements or slices from data and store to a tensor whose shape is defined by indices.

Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with shape (M, Y_0, ..., Y_{K-1}), the output will have shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N, output shape will simply be (Y_0, ..., Y_{K-1}).

The elements in the output are defined as follows::

output[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] = data[indices[0, y_0, ..., y_{K-1}], ..., indices[M-1, y_0, ..., y_{K-1}], x_M, ..., x_{N-1}]

Examples::

data = [[0, 1], [2, 3]] indices = [[1, 1, 0], [0, 1, 0]] gather_nd(data, indices) = [2, 3, 0]

data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] indices = [[0, 1], [1, 0]] gather_nd(data, indices) = [[3, 4], [5, 6]]
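
The definition above can be sketched in plain Julia (not an MXNet call). The helper `gather_nd_sketch` is hypothetical: it takes an M×P matrix of 0-based indices, which flattens the Y axes of `indices` into a single column index P, and returns one element or slice per column. The arrays below are built so that `data[i, j, ...]` matches the nested-list notation of the examples.

```julia
# Hypothetical helper: column p of `indices` selects
# data[indices[1, p], ..., indices[M, p], :, ..., :] (indices are 0-based).
function gather_nd_sketch(data::AbstractArray, indices::AbstractMatrix{<:Integer})
    M = size(indices, 1)
    M <= ndims(data) || throw(ArgumentError("need M <= ndims(data)"))
    trailing = ntuple(_ -> Colon(), ndims(data) - M)   # keep the remaining axes whole
    return [data[(indices[:, p] .+ 1)..., trailing...] for p in axes(indices, 2)]
end

gather_nd_sketch([0 1; 2 3], [1 1 0; 0 1 0])   # → [2, 3, 0]

data3 = zeros(Int, 2, 2, 2)
data3[1, :, :] = [1 2; 3 4]
data3[2, :, :] = [5 6; 7 8]
gather_nd_sketch(data3, [0 1; 1 0])            # → [[3, 4], [5, 6]]
```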

Arguments

  • data::NDArray-or-SymbolicNode: data
  • indices::NDArray-or-SymbolicNode: indices

source

# MXNet.mx._npi_greaterMethod.

_npi_greater(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_greater_equalMethod.

_npi_greater_equal(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_greater_equal_scalarMethod.

_npi_greater_equal_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: First input to the function
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_greater_scalarMethod.

_npi_greater_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: First input to the function
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_gumbelMethod.

_npi_gumbel(input1, input2, loc, scale, size, ctx)

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • input2::NDArray-or-SymbolicNode: Source input
  • loc::float or None, required:
  • scale::float or None, required:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._npi_hammingMethod.

_npi_hamming(M, ctx, dtype)

Return the Hamming window. The Hamming window is a taper formed by using a weighted cosine.

Arguments

  • M::, optional, default=None: Number of points in the output window. If zero or less, an empty array is returned.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Data-type of the returned array.

source

# MXNet.mx._npi_hanningMethod.

_npi_hanning(M, ctx, dtype)

Return the Hanning window. The Hanning window is a taper formed by using a weighted cosine.
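
Assuming these operators follow NumPy's definitions, both windows are raised cosines w[n] = a0 - a1·cos(2πn/(M-1)) for n = 0, ..., M-1, with (a0, a1) = (0.54, 0.46) for Hamming and (0.5, 0.5) for Hanning. A plain-Julia sketch (the helper names are hypothetical, not MXNet functions):

```julia
# Hypothetical helper: generic two-term cosine window of length M.
function cosine_window(M::Integer, a0::Real, a1::Real)
    M <= 0 && return Float64[]   # empty window, as documented for M <= 0
    M == 1 && return [1.0]       # single-point window
    return [a0 - a1 * cos(2π * n / (M - 1)) for n in 0:M-1]
end

hamming_sketch(M) = cosine_window(M, 0.54, 0.46)
hanning_sketch(M) = cosine_window(M, 0.5, 0.5)

hamming_sketch(5)   # ≈ [0.08, 0.54, 1.0, 0.54, 0.08]
hanning_sketch(5)   # ≈ [0.0, 0.5, 1.0, 0.5, 0.0]
```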

Arguments

  • M::, optional, default=None: Number of points in the output window. If zero or less, an empty array is returned.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Data-type of the returned array.

source

# MXNet.mx._npi_histogramMethod.

_npi_histogram(data, bins, bin_cnt, range)

_npi_histogram is an alias of _histogram.

This operator implements the histogram function.

Example::

x = [[0, 1], [2, 2], [3, 4]]
histo, bin_edges = histogram(data=x, bin_bounds=[], bin_cnt=5, range=(0,5))
histo = [1, 1, 2, 1, 1]
bin_edges = [0., 1., 2., 3., 4., 5.]
histo, bin_edges = histogram(data=x, bin_bounds=[0., 2.1, 3.])
histo = [4, 1]
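
Both examples can be reproduced with a plain-Julia counting sketch (not an MXNet call; `histogram_sketch` is hypothetical). It assumes NumPy's binning rule: bins are half-open except the last, which includes its right edge, and values outside the edges are ignored. The uniform case corresponds to edges `range(lo, hi; length = bin_cnt + 1)`.

```julia
# Hypothetical helper: count values of `x` into bins delimited by sorted `edges`.
function histogram_sketch(x::AbstractArray{<:Real}, edges::AbstractVector{<:Real})
    counts = zeros(Int, length(edges) - 1)
    for v in x
        (v < first(edges) || v > last(edges)) && continue   # outside the range
        i = searchsortedlast(edges, v)                       # largest i with edges[i] <= v
        counts[min(i, length(counts))] += 1                  # v == last(edges) joins the last bin
    end
    return counts
end

x = [0 1; 2 2; 3 4]
histogram_sketch(x, 0:5)               # uniform case, 5 bins over (0, 5) → [1, 1, 2, 1, 1]
histogram_sketch(x, [0.0, 2.1, 3.0])   # explicit bounds → [4, 1]
```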

Defined in src/operator/tensor/histogram.cc:L137

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • bins::NDArray-or-SymbolicNode: Input ndarray
  • bin_cnt::int or None, optional, default='None': Number of bins for uniform case
  • range::, optional, default=None: The lower and upper range of the bins. If not provided, the range is simply (a.min(), a.max()). Values outside the range are ignored. The first element of the range must be less than or equal to the second. The range affects the automatic bin computation as well: while the bin width is computed to be optimal based on the actual data within the range, the bin count will fill the entire range, including portions containing no data.

source

# MXNet.mx._npi_hsplitMethod.

_npi_hsplit(data, indices, axis, squeeze_axis, sections)

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • indices::Shape(tuple), required: Indices of splits. The elements denote the boundaries at which the split is performed along the axis.
  • axis::int, optional, default='1': Axis along which to split.
  • squeeze_axis::boolean, optional, default=0: If true, removes the axis with length 1 from the shapes of the output arrays. Note that setting squeeze_axis to $true$ removes the length-1 axis only along the axis on which the array is split. Also, squeeze_axis can be set to $true$ only if $input.shape[axis] == num_outputs$.
  • sections::int, optional, default='0': Number of equal sections to split into. Defaults to 0, which means the array is split by indices.

source

# MXNet.mx._npi_hsplit_backwardMethod.

_npi_hsplit_backward()

Arguments

source

# MXNet.mx._npi_hstackMethod.

_npi_hstack(data, num_args, dim)

Note: _npi_hstack takes a variable number of positional inputs. So instead of calling it as _npi_hstack([x, y, z], num_args=3), one should call it as _npi_hstack(x, y, z), and num_args will be determined automatically.

Stack tensors horizontally (along the second dimension).

Defined in src/operator/numpy/np_matrix_op.cc:L1042

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to concatenate
  • num_args::int, required: Number of inputs to be concatenated.
  • dim::int, optional, default='1': The dimension along which to concatenate.

source

# MXNet.mx._npi_hypotMethod.

_npi_hypot(x1, x2)

Arguments

  • x1::NDArray-or-SymbolicNode: The input array
  • x2::NDArray-or-SymbolicNode: The input array

source

# MXNet.mx._npi_hypot_scalarMethod.

_npi_hypot_scalar(data, scalar, is_int)

_npi_hypot_scalar is an alias of _hypot_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_identityMethod.

_npi_identity(shape, ctx, dtype)

Return a new identity array of given shape, type, and context.

Arguments

  • shape::Shape(tuple), optional, default=[]: The shape of the output
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Target data type.

source

# MXNet.mx._npi_indicesMethod.

_npi_indices(dimensions, dtype, ctx)

Return an array representing the indices of a grid.

Arguments

  • dimensions::Shape(tuple), required: The shape of the grid.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='int32': Target data type.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._npi_insert_scalarMethod.

_npi_insert_scalar(arr, values, val, start, stop, step, int_ind, axis)

Insert values along the given axis before the given indices.

Defined in src/operator/numpy/np_insert_op_scalar.cc:L105

Arguments

  • arr::NDArray-or-SymbolicNode: Input ndarray
  • values::NDArray-or-SymbolicNode: Input ndarray
  • val::double or None, optional, default=None: A scalar to be inserted into 'array'
  • start::int or None, optional, default='None': If 'obj' is a slice, 'start' is one of its arguments.
  • stop::int or None, optional, default='None': If 'obj' is a slice, 'stop' is one of its arguments.
  • step::int or None, optional, default='None': If 'obj' is a slice, 'step' is one of its arguments.
  • int_ind::int or None, optional, default='None': If 'obj' is an int, 'int_ind' is the index before which 'values' is inserted.
  • axis::int or None, optional, default='None': Axis along which to insert 'values'.

source

# MXNet.mx._npi_insert_sliceMethod.

_npi_insert_slice(arr, values, val, start, stop, step, int_ind, axis)

Insert values along the given axis before the given indices.

Defined in src/operator/numpy/np_insert_op_slice.cc:L131

Arguments

  • arr::NDArray-or-SymbolicNode: Input ndarray
  • values::NDArray-or-SymbolicNode: Input ndarray
  • val::double or None, optional, default=None: A scalar to be inserted into 'array'
  • start::int or None, optional, default='None': If 'obj' is a slice, 'start' is one of its arguments.
  • stop::int or None, optional, default='None': If 'obj' is a slice, 'stop' is one of its arguments.
  • step::int or None, optional, default='None': If 'obj' is a slice, 'step' is one of its arguments.
  • int_ind::int or None, optional, default='None': If 'obj' is an int, 'int_ind' is the index before which 'values' is inserted.
  • axis::int or None, optional, default='None': Axis along which to insert 'values'.

source

# MXNet.mx._npi_insert_tensorMethod.

_npi_insert_tensor(arr, values, obj, val, start, stop, step, int_ind, axis)

Insert values along the given axis before the given indices. Indices is tensor and ndim > 0.

Defined in src/operator/numpy/np_insert_op_tensor.cc:L121

Arguments

  • arr::NDArray-or-SymbolicNode: Input ndarray
  • values::NDArray-or-SymbolicNode: Input ndarray
  • obj::NDArray-or-SymbolicNode: Input ndarray
  • val::double or None, optional, default=None: A scalar to be inserted into 'array'
  • start::int or None, optional, default='None': If 'obj' is a slice, 'start' is one of its arguments.
  • stop::int or None, optional, default='None': If 'obj' is a slice, 'stop' is one of its arguments.
  • step::int or None, optional, default='None': If 'obj' is a slice, 'step' is one of its arguments.
  • int_ind::int or None, optional, default='None': If 'obj' is an int, 'int_ind' is the index before which 'values' is inserted.
  • axis::int or None, optional, default='None': Axis along which to insert 'values'.

source

# MXNet.mx._npi_invMethod.

_npi_inv(A)

_npi_inv is an alias of _linalg_inverse.

Compute the inverse of a matrix. Input is a tensor A of dimension n >= 2.

If n=2, A is a square matrix. We compute:

out = A^{-1}

If n>2, inverse is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix inverse A = [[1., 4.], [2., 3.]] inverse(A) = [[-0.6, 0.8], [0.4, -0.2]]

Batch matrix inverse A = [[[1., 4.], [2., 3.]], [[1., 3.], [2., 4.]]] inverse(A) = [[[-0.6, 0.8], [0.4, -0.2]], [[-2., 1.5], [1., -0.5]]]
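
Both examples can be checked in plain Julia with `LinearAlgebra.inv` (not an MXNet call). As with the determinant, the batch sketch assumes the matrices are stacked along the last Julia dimension; that layout choice is an assumption for illustration.

```julia
using LinearAlgebra

A = [1.0 4.0; 2.0 3.0]
inv(A)                                   # ≈ [-0.6 0.8; 0.4 -0.2]

# Batch mode: invert each 2×2 slice and restack along dimension 3.
B = cat([1.0 4.0; 2.0 3.0], [1.0 3.0; 2.0 4.0]; dims = 3)
batch_inv = cat([inv(B[:, :, k]) for k in axes(B, 3)]...; dims = 3)
batch_inv[:, :, 2]                       # ≈ [-2.0 1.5; 1.0 -0.5]
```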

Defined in src/operator/tensor/la_op.cc:L919

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrix

source

# MXNet.mx._npi_isfiniteMethod.

_npi_isfinite(x)

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_isinfMethod.

_npi_isinf(x)

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_isnanMethod.

_npi_isnan(x)

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_isneginfMethod.

_npi_isneginf(x)

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_isposinfMethod.

_npi_isposinf(x)

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_lcmMethod.

_npi_lcm(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_lcm_scalarMethod.

_npi_lcm_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_ldexpMethod.

_npi_ldexp(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_ldexp_scalarMethod.

_npi_ldexp_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_lessMethod.

_npi_less(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_less_equalMethod.

_npi_less_equal(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_less_equal_scalarMethod.

_npi_less_equal_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: First input to the function
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_less_scalarMethod.

_npi_less_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: First input to the function
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_linspaceMethod.

_npi_linspace(start, stop, step, repeat, infer_range, ctx, dtype)

_npi_linspace is an alias of _linspace.

Return evenly spaced numbers over a specified interval, similar to NumPy.

Arguments

  • start::double, required: Start of interval. The interval includes this value. The default start value is 0.
  • stop::double or None, optional, default=None: End of interval. The interval does not include this value, except in some cases where step is not an integer and floating point round-off affects the length of out.
  • step::double, optional, default=1: Spacing between values.
  • repeat::int, optional, default='1': The number of times each element is repeated. E.g., with repeat=3, the element a is repeated three times: a, a, a.
  • infer_range::boolean, optional, default=0: When set to True, infer the stop position from the start, step, repeat, and output tensor size.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Target data type.

source

# MXNet.mx._npi_logMethod.

_npi_log(x)

Returns the element-wise natural logarithm of the input. The natural logarithm is the logarithm in base e, so that $log(exp(x)) = x$.

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L247

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_log10Method.

_npi_log10(x)

Returns the element-wise base-10 logarithm of the input, so that $10**log10(x) = x$.

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L268

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_log1pMethod.

_npi_log1p(x)

Return the natural logarithm of one plus the input array, element-wise. Calculates $log(1 + x)$.

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L282

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_log2Method.

_npi_log2(x)

Returns the element-wise base-2 logarithm of the input, so that $2**log2(x) = x$.

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L275

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_logical_notMethod.

_npi_logical_not(x)

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_logisticMethod.

_npi_logistic(input1, input2, loc, scale, size, ctx)

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • input2::NDArray-or-SymbolicNode: Source input
  • loc::float or None, required:
  • scale::float or None, required:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._npi_logspaceMethod.

_npi_logspace(start, stop, num, endpoint, base, ctx, dtype)

Return numbers spaced evenly on a log scale.

Arguments

  • start::double, required: The starting value of the sequence.
  • stop::double, required: The ending value of the sequence
  • num::int, required: Number of samples to generate. Must be non-negative.
  • endpoint::boolean, optional, default=1: If True, stop is the last sample. Otherwise, it is not included.
  • base::double, optional, default=10: The base of the log space. The step size between the elements in ln(samples) / ln(base) (or log_base(samples)) is uniform.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, optional, default='float32': Target data type.
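
The `start`, `stop`, `num`, `endpoint`, and `base` parameters can be illustrated with a plain-Julia sketch (not an MXNet call; `logspace_sketch` is hypothetical): the exponents are evenly spaced and the output is `base` raised to them. With `endpoint = false` the exponents are taken from `num + 1` evenly spaced points with the last one dropped, as in NumPy.

```julia
# Hypothetical helper: `num` values whose base-`base` logarithms are evenly spaced.
function logspace_sketch(start::Real, stop::Real, num::Integer; endpoint::Bool = true, base::Real = 10.0)
    num >= 0 || throw(ArgumentError("num must be non-negative"))
    exps = endpoint ? range(start, stop; length = num) :
                      range(start, stop; length = num + 1)[1:num]
    return base .^ exps
end

logspace_sketch(0, 3, 4)                     # ≈ [1, 10, 100, 1000]
logspace_sketch(0, 3, 3; endpoint = false)   # ≈ [1, 10, 100]
```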

source

# MXNet.mx._npi_matmulMethod.

_npi_matmul(a, b)

Defined in src/operator/numpy/np_matmul_op.cc:L140

Arguments

  • a::NDArray-or-SymbolicNode: First input
  • b::NDArray-or-SymbolicNode: Second input

source

# MXNet.mx._npi_maximumMethod.

_npi_maximum(lhs, rhs)

_npi_maximum is an alias of broadcast_maximum.

Returns element-wise maximum of the input arrays with broadcasting.

This function compares two input arrays and returns a new array having the element-wise maxima.

Example::

x = [[ 1., 1., 1.], [ 1., 1., 1.]]

y = [[ 0.], [ 1.]]

broadcast_maximum(x, y) = [[ 1., 1., 1.], [ 1., 1., 1.]]

Defined in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc:L80

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_maximum_scalarMethod.

_npi_maximum_scalar(data, scalar, is_int)

_npi_maximum_scalar is an alias of _maximum_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_meanMethod.

_npi_mean(a, axis, dtype, keepdims, initial)

Arguments

  • a::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: Axis or axes along which the mean is computed. The default, axis=None, computes the mean of all the elements of the input array. If axis is negative, it counts from the last to the first axis.
  • dtype::{None, 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8'}, optional, default='None': The type of the returned array and of the accumulator in which the elements are summed. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimension with size one.
  • initial::double or None, optional, default=None: Starting value for the sum.

source

# MXNet.mx._npi_minimumMethod.

_npi_minimum(lhs, rhs)

_npi_minimum is an alias of broadcast_minimum.

Returns element-wise minimum of the input arrays with broadcasting.

This function compares two input arrays and returns a new array having the element-wise minima.

Example::

x = [[ 1., 1., 1.], [ 1., 1., 1.]]

y = [[ 0.], [ 1.]]

broadcast_minimum(x, y) = [[ 0., 0., 0.], [ 1., 1., 1.]]
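
The same broadcasting behaviour can be reproduced with Julia's dot syntax on Base `max` and `min` (plain Julia, not MXNet). A length-2 vector broadcasts across the columns of a 2×3 matrix, playing the role of the (2, 1) array `y` in the examples for `_npi_maximum` and `_npi_minimum`.

```julia
x = ones(2, 3)
y = [0.0, 1.0]   # broadcasts across columns like a (2, 1) array
max.(x, y)       # → [1.0 1.0 1.0; 1.0 1.0 1.0]
min.(x, y)       # → [0.0 0.0 0.0; 1.0 1.0 1.0]
```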

Defined in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc:L116

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_minimum_scalarMethod.

_npi_minimum_scalar(data, scalar, is_int)

_npi_minimum_scalar is an alias of _minimum_scalar.

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_modMethod.

_npi_mod(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_mod_scalarMethod.

_npi_mod_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_multinomialMethod.

_npi_multinomial(a, n, pvals, size)

Draw samples from a multinomial distribution. The multinomial distribution is a multivariate generalisation of the binomial distribution. Take an experiment with one of p possible outcomes. An example of such an experiment is throwing a die, where the outcome can be 1 through 6. Each sample drawn from the distribution represents n such experiments. Its values, X_i = [X_0, X_1, ..., X_p], represent the number of times the outcome was i.

Arguments

  • a::NDArray-or-SymbolicNode: Source input
  • n::int, required: Number of experiments.
  • pvals::, optional, default=None: Probabilities of each of the p different outcomes. These should sum to 1 (however, the last element is always assumed to account for the remaining probability, as long as sum(pvals[:-1]) <= 1). Note that this parameter is for internal use only: the operator takes either an input mx.ndarray or this list of pvals, not both.
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
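A minimal pure-Python sketch of drawing one multinomial sample, following the pvals description above (the last outcome absorbs whatever probability the others leave). multinomial_sketch is a hypothetical helper; it uses random.choices rather than MXNet's sampler, so its counts will not match MXNet's output for the same seed.

```python
import random
from collections import Counter


def multinomial_sketch(n, pvals, rng=random):
    """Draw one multinomial sample: counts of each outcome over n experiments."""
    # As described for pvals, the last outcome absorbs the remaining probability.
    head = list(pvals[:-1])
    if sum(head) > 1:
        raise ValueError("sum(pvals[:-1]) must be <= 1")
    weights = head + [1.0 - sum(head)]
    draws = rng.choices(range(len(weights)), weights=weights, k=n)
    counts = Counter(draws)
    return [counts[i] for i in range(len(weights))]


rng = random.Random(0)
print(multinomial_sketch(20, [1 / 6] * 6, rng))  # six counts that sum to 20
```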

source

# MXNet.mx._npi_multiplyMethod.

_npi_multiply(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_multiply_scalarMethod.

_npi_multiply_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_nan_to_numMethod.

_npi_nan_to_num(data, copy, nan, posinf, neginf)

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L464

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • copy::boolean, optional, default=1: Whether to create a copy of x (True) or to replace values in-place (False). The in-place operation only occurs if casting to an array does not require a copy. Default is True.
  • nan::double, optional, default=0: Value to be used to fill NaN values. If no value is passed then NaN values will be replaced with 0.0.
  • posinf::double or None, optional, default=None: Value to be used to fill positive infinity values. If no value is passed then positive infinity values will be replaced with a very large number.
  • neginf::double or None, optional, default=None: Value to be used to fill negative infinity values. If no value is passed then negative infinity values will be replaced with a very small (or negative) number.

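The replacement rules above can be sketched in pure Python for a flat list of floats. This sketch assumes that the "very large number" is the largest finite float64 (sys.float_info.max), as NumPy does for float64 input; MXNet may use the maximum of the array's own dtype instead. nan_to_num_sketch is a hypothetical helper.

```python
import math
import sys


def nan_to_num_sketch(values, nan=0.0, posinf=None, neginf=None):
    """Replace NaN and +/-inf in a flat list, using the defaults described above."""
    # Assumption: "very large number" means the largest finite float64.
    big = sys.float_info.max
    posinf = big if posinf is None else posinf
    neginf = -big if neginf is None else neginf
    out = []
    for v in values:
        if math.isnan(v):
            out.append(nan)
        elif math.isinf(v):
            out.append(posinf if v > 0 else neginf)
        else:
            out.append(v)
    return out  # always a new list, i.e. the copy=True behaviour


print(nan_to_num_sketch([float("nan"), float("inf"), -float("inf"), 1.5]))
# [0.0, 1.7976931348623157e+308, -1.7976931348623157e+308, 1.5]
```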
source

# MXNet.mx._npi_negativeMethod.

_npi_negative(x)

Numerical negative, element-wise. Example:: negative([1., -1.]) = [-1., 1.]

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_nonzeroMethod.

_npi_nonzero(x)

_npi_nonzero is an alias of _npx_nonzero.

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_normMethod.

_npi_norm(data)

Defined in src/operator/numpy/linalg/np_norm_forward.cc:L31

Arguments

  • data::NDArray-or-SymbolicNode: The input

source

# MXNet.mx._npi_normalMethod.

_npi_normal(input1, input2, loc, scale, size, ctx, dtype)

NumPy-style sampling from a normal (Gaussian) distribution.

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • input2::NDArray-or-SymbolicNode: Source input
  • loc::float or None, required: Mean of the distribution.
  • scale::float or None, required: Standard deviation of the distribution.
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'float16', 'float32', 'float64'},optional, default='float32': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx._npi_normal_nMethod.

_npi_normal_n(input1, input2, loc, scale, size, ctx, dtype)

NDArray-style sampling from a normal (Gaussian) distribution.

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • input2::NDArray-or-SymbolicNode: Source input
  • loc::float or None, required: Mean of the distribution.
  • scale::float or None, required: Standard deviation of the distribution.
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'float16', 'float32', 'float64'},optional, default='float32': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx._npi_not_equalMethod.

_npi_not_equal(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_not_equal_scalarMethod.

_npi_not_equal_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: First input to the function
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_onesMethod.

_npi_ones(shape, ctx, dtype)

Return a new array of given shape, type, and context, filled with ones.

Arguments

  • shape::Shape(tuple), optional, default=[]: The shape of the output
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'},optional, default='float32': Target data type.

source

# MXNet.mx._npi_padMethod.

_npi_pad(data, pad_width, mode, constant_value, reflect_type)

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • pad_width::tuple of <Shape(tuple)>, required: Number of values padded to the edges of each axis. ((before1, after1), … (beforeN, afterN)) specifies unique pad widths for each axis. ((before, after),) yields the same before and after pad for each axis. (pad,) or int is a shortcut for before = after = pad width for all axes.
  • mode::{'constant', 'edge', 'maximum', 'minimum', 'reflect', 'symmetric'},optional, default='constant': Padding type to use. "constant" pads with constant_value. "edge" pads using the edge values of the input array. "reflect" pads with the reflection of the vector mirrored on the first and last values of the vector along each axis. "symmetric" pads with the reflection of the vector mirrored along the edge of the array. "maximum" pads with the maximum value of all or part of the vector along each axis. "minimum" pads with the minimum value of all or part of the vector along each axis.
  • constant_value::double, optional, default=0: Used in ‘constant’. The values to set the padded values for each axis. ((before1, after1), ... (beforeN, afterN)) specifies unique pad constants for each axis. ((before, after),) yields the same before and after constants for each axis. (constant,) or constant is a shortcut for before = after = constant for all axes. Default is 0.
  • reflect_type::string, optional, default='even': Used in ‘reflect’ and ‘symmetric’. The ‘even’ style is the default, with an unaltered reflection around the edge value. For the ‘odd’ style, the extended part of the array is created by subtracting the reflected values from two times the edge value.

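The constant, edge, reflect, and symmetric modes can be sketched for a 1-D list as below. This is a hypothetical helper (pad_1d), assuming NumPy's np.pad conventions and the default reflect_type='even'; the maximum and minimum modes and the 'odd' reflect type are not covered.

```python
def pad_1d(values, before, after, mode="constant", constant_value=0):
    """Pad a 1-D list in the style of the constant/edge/reflect/symmetric modes."""
    n = len(values)
    if mode == "constant":
        return [constant_value] * before + list(values) + [constant_value] * after

    def source_index(i):
        # Map a (possibly out-of-range) position to an index of the original list.
        if mode == "edge":
            return min(max(i, 0), n - 1)
        if mode == "reflect":  # mirror on the first/last value: edge not repeated
            if n == 1:
                return 0
            period = 2 * (n - 1)
            i %= period
            return i if i < n else period - i
        if mode == "symmetric":  # mirror along the edge: edge value repeated
            period = 2 * n
            i %= period
            return i if i < n else period - 1 - i
        raise ValueError(f"mode {mode!r} is not covered by this sketch")

    return [values[source_index(i)] for i in range(-before, n + after)]


print(pad_1d([1, 2, 3], 2, 2, "reflect"))    # [3, 2, 1, 2, 3, 2, 1]
print(pad_1d([1, 2, 3], 2, 2, "symmetric"))  # [2, 1, 1, 2, 3, 3, 2]
```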
source

# MXNet.mx._npi_paretoMethod.

_npi_pareto(input1, a, size, ctx)

NumPy-style sampling from a Pareto distribution.

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • a::float or None, optional, default=None: Shape parameter of the distribution.
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._npi_percentileMethod.

_npi_percentile(a, q, axis, interpolation, keepdims, q_scalar)

Arguments

  • a::NDArray-or-SymbolicNode: Input data
  • q::NDArray-or-SymbolicNode: Input percentile
  • axis::Shape or None, optional, default=None: Axis or axes along which the percentiles are computed. The default, axis=None, computes them over all the elements of the input array. If axis is negative, it counts from the last to the first axis.
  • interpolation::{'higher', 'linear', 'lower', 'midpoint', 'nearest'},optional, default='linear': This optional parameter specifies the interpolation method to use when the desired percentile lies between two data points i < j.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimensions with size one.
  • q_scalar::double or None, optional, default=None: The percentile q, given as a scalar instead of the q input.

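A minimal pure-Python sketch of how the interpolation options pick a value when the percentile falls between two sorted data points i <= j, assuming NumPy's convention of placing the q-th percentile at fractional index (N - 1) * q / 100. percentile_sketch is a hypothetical helper; Python's round() breaks exact .5 ties to the even index, which may differ from MXNet's 'nearest' tie handling.

```python
import math


def percentile_sketch(data, q, interpolation="linear"):
    """Percentile of a flat list for the five `interpolation` options."""
    xs = sorted(data)
    pos = (len(xs) - 1) * q / 100.0          # fractional index into the sorted data
    i, j = math.floor(pos), math.ceil(pos)   # the neighbouring data points i <= j
    frac = pos - i
    if interpolation == "linear":
        return xs[i] + (xs[j] - xs[i]) * frac
    if interpolation == "lower":
        return xs[i]
    if interpolation == "higher":
        return xs[j]
    if interpolation == "midpoint":
        return (xs[i] + xs[j]) / 2
    if interpolation == "nearest":
        return xs[round(pos)]                # exact .5 ties go to the even index
    raise ValueError(interpolation)


print(percentile_sketch([1, 2, 3, 4], 50))  # 2.5
```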
source

# MXNet.mx._npi_pinvMethod.

_npi_pinv(A, rcond, hermitian)

Defined in src/operator/numpy/linalg/np_pinv.cc:L98

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of matrix
  • rcond::NDArray-or-SymbolicNode: Cutoff for small singular values.
  • hermitian::boolean, optional, default=0: If True, A is assumed to be Hermitian (symmetric if real-valued).

source

# MXNet.mx._npi_pinv_scalar_rcondMethod.

_npi_pinv_scalar_rcond(A, rcond, hermitian)

Defined in src/operator/numpy/linalg/np_pinv.cc:L176

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of matrix
  • rcond::double, optional, default=1.0000000000000001e-15: Cutoff for small singular values.
  • hermitian::boolean, optional, default=0: If True, A is assumed to be Hermitian (symmetric if real-valued).

source

# MXNet.mx._npi_polyvalMethod.

_npi_polyval(p, x)

Arguments

  • p::NDArray-or-SymbolicNode: polynomial coefficients
  • x::NDArray-or-SymbolicNode: variables
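Assuming _npi_polyval follows numpy.polyval, the coefficients p are ordered from the highest power down, and the polynomial can be evaluated with Horner's scheme. polyval_sketch is a hypothetical pure-Python helper.

```python
def polyval_sketch(p, x):
    """Evaluate p[0]*x**(n-1) + ... + p[n-1] using Horner's scheme."""
    result = 0
    for coeff in p:
        # Multiply the running value by x, then add the next (lower-order) coefficient.
        result = result * x + coeff
    return result


print(polyval_sketch([3, 0, 1], 5))  # 3*5**2 + 0*5 + 1 = 76
```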

source

# MXNet.mx._npi_powerMethod.

_npi_power(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_power_scalarMethod.

_npi_power_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_powerdMethod.

_npi_powerd(input1, a, size)

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • a::float or None, optional, default=None:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.

source

# MXNet.mx._npi_radiansMethod.

_npi_radians(x)

Converts each element of the input array from degrees to radians.

Example::

radians([0, 90, 180, 270, 360]) = [0, π/2, π, 3π/2, 2π]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L379

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_random_randintMethod.

_npi_random_randint(low, high, shape, ctx, dtype)

_npi_random_randint is an alias of _random_randint.

Draw random samples from a discrete uniform distribution.

Samples are uniformly distributed over the half-open interval [low, high) (includes low, but excludes high).

Example::

randint(low=0, high=5, shape=(2,2)) = [[ 0, 2], [ 3, 1]]

Defined in src/operator/random/sample_op.cc:L193

Arguments

  • low::long, required: Lower bound of the distribution.
  • high::long, required: Upper bound of the distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'int32', 'int64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to int32 if not defined (dtype=None).

source

# MXNet.mx._npi_rarctan2_scalarMethod.

_npi_rarctan2_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_rayleighMethod.

_npi_rayleigh(input1, scale, size, ctx)

NumPy-style sampling from a Rayleigh distribution.

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • scale::float or None, optional, default=1: Scale of the distribution (also equal to its mode).
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._npi_rcopysign_scalarMethod.

_npi_rcopysign_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_reciprocalMethod.

_npi_reciprocal(x)

Return the reciprocal of the argument, element-wise. Example:: reciprocal([-2, 1, 3, 1.6, 0.2]) = [-0.5, 1.0, 0.33333334, 0.625, 5.0]

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_reshapeMethod.

_npi_reshape(a, newshape, reverse, order)

_npi_reshape is an alias of _npx_reshape.

Defined in src/operator/numpy/np_matrix_op.cc:L381

Arguments

  • a::NDArray-or-SymbolicNode: Array to be reshaped.
  • newshape::Shape(tuple), required: The new shape should be compatible with the original shape. If an integer, then the result will be a 1-D array of that length. One shape dimension can be -1; its value is then inferred from the length of the array and the remaining dimensions. The values -2 to -6 are used for data manipulation: -2 copies this dimension from the input to the output shape; -3 skips the current dimension if and only if its size is one; -4 copies all remaining input dimensions to the output shape; -5 uses the product of two consecutive input dimensions as the output dimension; -6 splits one input dimension into the two dimensions that follow -6 in the new shape.
  • reverse::boolean, optional, default=0: If true, the special values are inferred from right to left.
  • order::string, optional, default='C': Read the elements of a using this index order, and place the elements into the reshaped array using this index order. 'C' means to read/write the elements using C-like index order, with the last axis index changing fastest, back to the first axis index changing slowest. Note that currently only C-like order is supported.

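The -1 rule for newshape (infer one dimension from the total number of elements) can be sketched in pure Python. infer_shape is a hypothetical helper that handles only -1 and positive sizes; the special values -2 to -6 are not modelled.

```python
from math import prod


def infer_shape(size, newshape):
    """Resolve a single -1 entry in `newshape` so the element count equals `size`."""
    if newshape.count(-1) > 1:
        raise ValueError("only one dimension can be -1")
    known = prod(d for d in newshape if d != -1)
    if -1 in newshape:
        if known == 0 or size % known:
            raise ValueError("cannot infer the -1 dimension")
        return tuple(size // known if d == -1 else d for d in newshape)
    if known != size:
        raise ValueError("total size must be unchanged")
    return tuple(newshape)


print(infer_shape(2 * 3 * 4, (4, -1)))  # (4, 6)
```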
source

# MXNet.mx._npi_rintMethod.

_npi_rint(x)

Round elements of the array to the nearest integer. Example:: rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) = [-2., -2., -0., 0., 2., 2., 2.]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L156

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_rldexp_scalarMethod.

_npi_rldexp_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_rmod_scalarMethod.

_npi_rmod_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_rnn_param_concatMethod.

_npi_rnn_param_concat(data, num_args, dim)

_npi_rnn_param_concat is an alias of _rnn_param_concat.

Note: _npi_rnn_param_concat takes a variable number of positional inputs. So instead of calling it as _npi_rnn_param_concat([x, y, z], num_args=3), call it as _npi_rnn_param_concat(x, y, z); num_args is then determined automatically.

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to concatenate
  • num_args::int, required: Number of inputs to be concatenated.
  • dim::int, optional, default='1': The dimension along which to concatenate.

source

# MXNet.mx._npi_rot90Method.

_npi_rot90(data, k, axes)

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • k::int, optional, default='1': Number of times the array is rotated by 90 degrees.
  • axes::Shape or None, optional, default=None: The array is rotated in the plane defined by the axes. Axes must be different.
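A pure-Python sketch of rot90 for a 2-D nested list, assuming NumPy's convention that a positive k rotates counter-clockwise in the plane of axes (0, 1). rot90_sketch is a hypothetical helper and does not take an axes argument.

```python
def rot90_sketch(m, k=1):
    """Rotate a 2-D nested list by 90 degrees k times, counter-clockwise."""
    for _ in range(k % 4):  # negative k becomes the equivalent number of CCW turns
        # One counter-clockwise turn: transpose, then reverse the row order.
        m = [list(row) for row in zip(*m)][::-1]
    return m


print(rot90_sketch([[1, 2], [3, 4]]))  # [[2, 4], [1, 3]]
```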

source

# MXNet.mx._npi_rpower_scalarMethod.

_npi_rpower_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_rsubtract_scalarMethod.

_npi_rsubtract_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_rtrue_divide_scalarMethod.

_npi_rtrue_divide_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_scatter_set_ndMethod.

_npi_scatter_set_nd(lhs, rhs, indices, shape)

_npi_scatter_set_nd is an alias of _scatter_set_nd.

This operator has the same functionality as scatter_nd, except that it does not reset the elements of the input data NDArray that are not indexed by the input index NDArray. The output must be given explicitly and must be the same array as lhs.

Note: This operator is for internal use only.

Examples::

data = [2, 3, 0]
indices = [[1, 1, 0], [0, 1, 0]]
out = [[1, 1], [1, 1]]
_scatter_set_nd(lhs=out, rhs=data, indices=indices, out=out)
out = [[0, 1], [2, 3]]

Arguments

  • lhs::NDArray-or-SymbolicNode: source input
  • rhs::NDArray-or-SymbolicNode: value to assign
  • indices::NDArray-or-SymbolicNode: indices
  • shape::Shape(tuple), required: Shape of output.
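The 2-D example above can be reproduced in pure Python: each column of indices gives the (row, column) position for the matching element of data, and positions that are not indexed keep their existing values. scatter_set_nd_2d is a hypothetical helper restricted to 2-D outputs.

```python
def scatter_set_nd_2d(out, data, indices):
    """Write data[j] to out[indices[0][j]][indices[1][j]]; other entries are untouched."""
    rows, cols = indices
    for value, r, c in zip(data, rows, cols):
        out[r][c] = value  # in place, so unindexed elements keep their old values
    return out


out = [[1, 1], [1, 1]]
scatter_set_nd_2d(out, [2, 3, 0], [[1, 1, 0], [0, 1, 0]])
print(out)  # [[0, 1], [2, 3]]
```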

source

# MXNet.mx._npi_share_memoryMethod.

_npi_share_memory(a, b)

Arguments

  • a::NDArray-or-SymbolicNode: First input
  • b::NDArray-or-SymbolicNode: Second input

source

# MXNet.mx._npi_shuffleMethod.

_npi_shuffle(data)

_npi_shuffle is an alias of _shuffle.

Randomly shuffle the elements.

This shuffles the array along the first axis. The order of the elements in each subarray does not change. For example, if a 2D array is given, the order of the rows randomly changes, but the order of the elements in each row does not change.

Arguments

  • data::NDArray-or-SymbolicNode: Data to be shuffled.

source

# MXNet.mx._npi_signMethod.

_npi_sign(x)

Returns an element-wise indication of the sign of a number. The sign function returns -1 if x < 0, 0 if x==0, 1 if x > 0. Example:: sign([-2, 0, 3]) = [-1, 0, 1]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L148

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_sinMethod.

_npi_sin(x)

Trigonometric sine, element-wise.

Example::

sin([0, π/4, π/2]) = [0, 0.707, 1]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L320

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_sinhMethod.

_npi_sinh(x)

Returns the hyperbolic sine of the input array, computed element-wise:

sinh(x) = 0.5 × (exp(x) - exp(-x))

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L387

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_sliceMethod.

_npi_slice(data, begin, end, step)

_npi_slice is an alias of slice.

Slices a region of the array.

Note: crop is deprecated. Use slice instead.

This function returns a sliced array between the indices given by begin and end with the corresponding step. For an input array of shape=(d_0, d_1, ..., d_{n-1}), a slice operation with begin=(b_0, b_1, ..., b_{m-1}), end=(e_0, e_1, ..., e_{m-1}), and step=(s_0, s_1, ..., s_{m-1}), where m <= n, results in an array with the shape

$(\lceil |e_0-b_0|/|s_0| \rceil, ..., \lceil |e_{m-1}-b_{m-1}|/|s_{m-1}| \rceil, d_m, ..., d_{n-1})$

The k-th dimension of the result contains elements from the k-th dimension of the input array, starting from index b_k (inclusive) with step s_k until reaching e_k (exclusive).

If the k-th elements of begin, end, or step are None, the following defaults are used: if s_k is None, set s_k=1. If s_k > 0, set b_k=0 and e_k=d_k; otherwise, set b_k=d_k-1 and e_k=-1 (here -1 means the slice runs past index 0, not that it stops at the last element).

The storage type of slice output depends on storage types of inputs:

  • slice(csr) = csr
  • otherwise, slice generates output with default storage

Note: When the input data storage type is csr, only step=(), step=(None,), or step=(1,) generate a csr output. For other step values, it falls back to slicing a dense tensor.

Example::

x = [[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 9., 10., 11., 12.]]

slice(x, begin=(0,1), end=(2,4)) = [[ 2., 3., 4.], [ 6., 7., 8.]]

slice(x, begin=(None, 0), end=(None, 3), step=(-1, 2)) = [[9., 11.], [5., 7.], [1., 3.]]

Defined in src/operator/tensor/matrix_op.cc:L481

Arguments

  • data::NDArray-or-SymbolicNode: Source input
  • begin::Shape(tuple), required: starting indices for the slice operation, supports negative indices.
  • end::Shape(tuple), required: ending indices for the slice operation, supports negative indices.
  • step::Shape(tuple), optional, default=[]: step for the slice operation, supports negative values.

source

# MXNet.mx._npi_slice_assignMethod.

_npi_slice_assign(lhs, rhs, begin, end, step)

_npi_slice_assign is an alias of _slice_assign.

Assign the rhs to a cropped subset of lhs.

Requirements

  • The output must be given explicitly and must be the same array as lhs.
  • lhs and rhs are of the same data type, and on the same device.

From: src/operator/tensor/matrix_op.cc:514

Arguments

  • lhs::NDArray-or-SymbolicNode: Source input
  • rhs::NDArray-or-SymbolicNode: value to assign
  • begin::Shape(tuple), required: starting indices for the slice operation, supports negative indices.
  • end::Shape(tuple), required: ending indices for the slice operation, supports negative indices.
  • step::Shape(tuple), optional, default=[]: step for the slice operation, supports negative values.

source

# MXNet.mx._npi_slice_assign_scalarMethod.

_npi_slice_assign_scalar(data, scalar, begin, end, step)

_npi_slice_assign_scalar is an alias of _slice_assign_scalar.

Assign the scalar to a cropped subset of the input.

Requirements

  • The output must be given explicitly and must be the same array as the input.

From: src/operator/tensor/matrix_op.cc:540

Arguments

  • data::NDArray-or-SymbolicNode: Source input
  • scalar::double, optional, default=0: The scalar value for assignment.
  • begin::Shape(tuple), required: starting indices for the slice operation, supports negative indices.
  • end::Shape(tuple), required: ending indices for the slice operation, supports negative indices.
  • step::Shape(tuple), optional, default=[]: step for the slice operation, supports negative values.

source

# MXNet.mx._npi_slogdetMethod.

_npi_slogdet(A)

_npi_slogdet is an alias of _linalg_slogdet.

Compute the sign and log of the determinant of a matrix. Input is a tensor A of dimension n >= 2.

If n=2, A is a square matrix. We compute:

sign = sign(det(A))
logabsdet = log(abs(det(A)))

If n>2, slogdet is performed separately on the trailing two dimensions for all inputs (batch mode).

Note: The operator supports float32 and float64 data types only.
Note: The gradient is not defined for sign, so no gradient is propagated through it.
Note: No gradient is propagated when A is non-invertible. See the documentation of the det operator for details.

Examples::

Single matrix signed log determinant:
A = [[2., 3.], [1., 4.]]
sign, logabsdet = slogdet(A)
sign = [1.]
logabsdet = [1.609438]

Batch matrix signed log determinant:
A = [[[2., 3.], [1., 4.]], [[1., 2.], [2., 4.]], [[1., 2.], [4., 3.]]]
sign, logabsdet = slogdet(A)
sign = [1., 0., -1.]
logabsdet = [1.609438, -inf, 1.609438]

Defined in src/operator/tensor/la_op.cc:L1033

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrix
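A minimal pure-Python sketch of slogdet for a single square matrix, using Gaussian elimination with partial pivoting: each row swap flips the sign, and log|det(A)| is the sum of log|pivot|. slogdet_sketch is hypothetical; batch mode and gradients are not covered.

```python
import math


def slogdet_sketch(a):
    """Return (sign, log|det|) of a square matrix given as a nested list."""
    m = [list(map(float, row)) for row in a]  # work on a float copy
    n = len(m)
    sign, logabsdet = 1.0, 0.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if m[pivot][col] == 0.0:
            return 0.0, -math.inf             # singular matrix
        if pivot != col:
            m[col], m[pivot] = m[pivot], m[col]
            sign = -sign                      # each row swap flips the determinant's sign
        p = m[col][col]
        if p < 0:
            sign = -sign
        logabsdet += math.log(abs(p))
        for r in range(col + 1, n):           # eliminate entries below the pivot
            factor = m[r][col] / p
            for c in range(col, n):
                m[r][c] -= factor * m[col][c]
    return sign, logabsdet


print(slogdet_sketch([[2., 3.], [1., 4.]]))  # sign 1.0, logabsdet approximately 1.609438
```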

source

# MXNet.mx._npi_solveMethod.

_npi_solve(A, B)

Defined in src/operator/numpy/linalg/np_solve.cc:L88

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrix
  • B::NDArray-or-SymbolicNode: Tensor of right side vector

source

# MXNet.mx._npi_sortMethod.

_npi_sort(data, axis, is_ascend)

_npi_sort is an alias of sort.

Returns a sorted copy of an input array along the given axis.

Examples::

x = [[ 1, 4], [ 3, 1]]

// sorts along the last axis
sort(x) = [[ 1., 4.], [ 1., 3.]]

// flattens and then sorts
sort(x, axis=None) = [ 1., 1., 3., 4.]

// sorts along the first axis
sort(x, axis=0) = [[ 1., 1.], [ 3., 4.]]

// in descending order
sort(x, is_ascend=0) = [[ 4., 1.], [ 3., 1.]]

Defined in src/operator/tensor/ordering_op.cc:L132

Arguments

  • data::NDArray-or-SymbolicNode: The input array
  • axis::int or None, optional, default='-1': Axis along which to sort the input tensor. If not given, the flattened array is used. Default is -1.
  • is_ascend::boolean, optional, default=1: Whether to sort in ascending or descending order.
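The four sort examples above can be reproduced with Python's sorted() on a 2-D nested list. sort_sketch is a hypothetical helper that only handles 2-D input.

```python
def sort_sketch(x, axis=-1, is_ascend=True):
    """Sort a 2-D nested list along the last axis, the first axis, or flattened (None)."""
    rev = not is_ascend
    if axis is None:                       # flatten, then sort
        return sorted((v for row in x for v in row), reverse=rev)
    if axis in (-1, 1):                    # sort each row independently
        return [sorted(row, reverse=rev) for row in x]
    if axis in (0, -2):                    # sort each column independently
        cols = [sorted(col, reverse=rev) for col in zip(*x)]
        return [list(row) for row in zip(*cols)]
    raise ValueError("axis out of range for a 2-D input")


x = [[1, 4], [3, 1]]
print(sort_sketch(x, axis=0))  # [[1, 1], [3, 4]]
```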

source

# MXNet.mx._npi_splitMethod.

_npi_split(data, indices, axis, squeeze_axis, sections)

_npi_split is an alias of _split_v2.

Splits an array along a particular axis into multiple sub-arrays.

Example::

x = [[[ 1.] [ 2.]] [[ 3.] [ 4.]] [[ 5.] [ 6.]]]
x.shape = (3, 2, 1)

// a list of 2 arrays with shape (3, 1, 1)
y = split_v2(x, axis=1, indices_or_sections=2)
y = [[[ 1.]] [[ 3.]] [[ 5.]]] [[[ 2.]] [[ 4.]] [[ 6.]]]
y[0].shape = (3, 1, 1)

// a list of 3 arrays with shape (1, 2, 1)
z = split_v2(x, axis=0, indices_or_sections=3)
z = [[[ 1.] [ 2.]]] [[[ 3.] [ 4.]]] [[[ 5.] [ 6.]]]
z[0].shape = (1, 2, 1)

// a list of 2 arrays with shapes [(1, 2, 1), (2, 2, 1)]
w = split_v2(x, axis=0, indices_or_sections=(1,))
w = [[[ 1.] [ 2.]]] [[[3.] [4.]] [[5.] [6.]]]
w[0].shape = (1, 2, 1)
w[1].shape = (2, 2, 1)

squeeze_axis=True removes the axis with length 1 from the shapes of the output arrays. Note that setting squeeze_axis to 1 removes the axis with length 1 only along the axis on which the array is split. Also, squeeze_axis can be set to true only if input.shape[axis] == indices_or_sections.

Example::

// a list of 3 arrays with shape (2, 1)
z = split_v2(x, axis=0, indices_or_sections=3, squeeze_axis=1)
z = [[ 1.] [ 2.]] [[ 3.] [ 4.]] [[ 5.] [ 6.]]
z[0].shape = (2, 1)

Defined in src/operator/tensor/matrix_op.cc:L1087

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • indices::Shape(tuple), required: Indices of splits. The elements should denote the boundaries of at which split is performed along the axis.
  • axis::int, optional, default='1': Axis along which to split.
  • squeeze_axis::boolean, optional, default=0: If true, removes the axis with length 1 from the shapes of the output arrays. Note that setting squeeze_axis to true removes the axis with length 1 only along the axis on which the array is split. Also, squeeze_axis can be set to true only if input.shape[axis] == num_outputs.
  • sections::int, optional, default='0': Number of equal sections to split into. Default is 0, which means split by indices.

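Splitting along the first axis by a section count or by boundary indices can be sketched in pure Python. split_sketch is a hypothetical helper that always splits along axis 0, assumes an integer section count must divide the axis length evenly, and does not implement squeeze_axis.

```python
def split_sketch(x, indices_or_sections):
    """Split a list along its first axis by section count (int) or boundary indices."""
    n = len(x)
    if isinstance(indices_or_sections, int):
        k = indices_or_sections
        if n % k:
            raise ValueError("axis length must be divisible by the number of sections")
        bounds = [i * (n // k) for i in range(1, k)]
    else:
        bounds = list(indices_or_sections)
    edges = [0] + bounds + [n]
    # Consecutive edges delimit each sub-array.
    return [x[a:b] for a, b in zip(edges, edges[1:])]


x = [[[1.], [2.]], [[3.], [4.]], [[5.], [6.]]]  # shape (3, 2, 1)
print([len(part) for part in split_sketch(x, (1,))])  # [1, 2]
```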
source

# MXNet.mx._npi_sqrtMethod.

_npi_sqrt(x)

Return the non-negative square-root of an array, element-wise. Example:: sqrt([4, 9, 16]) = [2, 3, 4]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L224

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_squareMethod.

_npi_square(x)

Return the element-wise square of the input. Example:: square([2, 3, 4]) = [4, 9, 16]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L216

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_stackMethod.

_npi_stack(data, axis, num_args)

Note: _npi_stack takes a variable number of positional inputs. So instead of calling it as _npi_stack([x, y, z], num_args=3), call it as _npi_stack(x, y, z); num_args is then determined automatically.

Join a sequence of arrays along a new axis.

The axis parameter specifies the index of the new axis in the dimensions of the result. For example, if axis=0 it will be the first dimension and if axis=-1 it will be the last dimension.

Examples::

x = [1, 2]
y = [3, 4]

stack(x, y) = [[1, 2], [3, 4]]
stack(x, y, axis=1) = [[1, 3], [2, 4]]

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to stack
  • axis::int, optional, default='0': The axis in the result array along which the input arrays are stacked.
  • num_args::int, required: Number of inputs to be stacked.
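For 1-D inputs, stacking along axis 0 keeps each input as a row, while axis 1 (or -1) pairs up elements position by position. stack_sketch is a hypothetical pure-Python helper limited to 1-D inputs; like the operator, it takes the arrays as positional arguments.

```python
def stack_sketch(*arrays, axis=0):
    """Join equal-length 1-D lists along a new axis (0, or 1/-1 for the last)."""
    if len({len(a) for a in arrays}) != 1:
        raise ValueError("all inputs must have the same shape")
    if axis == 0:
        return [list(a) for a in arrays]          # each input becomes a row
    if axis in (1, -1):
        return [list(t) for t in zip(*arrays)]    # element i of every input forms row i
    raise ValueError("axis out of range for 1-D inputs")


print(stack_sketch([1, 2], [3, 4], axis=1))  # [[1, 3], [2, 4]]
```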

source

# MXNet.mx._npi_stdMethod.

_npi_std(a, axis, dtype, ddof, keepdims)

Arguments

  • a::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: Axis or axes along which the standard deviation is computed. The default, axis=None, computes it over all the elements of the input array. If axis is negative, it counts from the last to the first axis.
  • dtype::{None, 'float16', 'float32', 'float64', 'int32', 'int64', 'int8'},optional, default='None': The type of the returned array and of the accumulator in which the elements are summed. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.
  • ddof::int, optional, default='0': Delta degrees of freedom. The divisor used in the calculation is N - ddof, where N is the number of elements.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimension with size one.
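The role of ddof can be seen in a pure-Python sketch: the sum of squared deviations is divided by N - ddof, so ddof=0 gives the population standard deviation and ddof=1 the sample estimate. std_sketch is a hypothetical helper; raising an error when N - ddof <= 0 is an assumption of the sketch, not documented operator behaviour.

```python
import math


def std_sketch(values, ddof=0):
    """Standard deviation of a flat list with divisor N - ddof."""
    n = len(values)
    if n - ddof <= 0:
        raise ValueError("ddof must be smaller than the number of elements")
    mean = sum(values) / n
    return math.sqrt(sum((v - mean) ** 2 for v in values) / (n - ddof))


print(std_sketch([1, 2, 3, 4], ddof=1))  # approximately 1.29099
```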

source

# MXNet.mx._npi_subtractMethod.

_npi_subtract(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._npi_subtract_scalarMethod.

_npi_subtract_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._npi_svdMethod.

_npi_svd(A)

Defined in src/operator/numpy/linalg/np_gesvd.cc:L92

Arguments

  • A::NDArray-or-SymbolicNode: Input matrices to be factorized

source

# MXNet.mx._npi_swapaxesMethod.

_npi_swapaxes(data, dim1, dim2)

_npi_swapaxes is an alias of SwapAxis.

Interchanges two axes of an array.

Examples::

x = [[1, 2, 3]]
swapaxes(x, 0, 1) = [[ 1], [ 2], [ 3]]

x = [[[ 0, 1], [ 2, 3]], [[ 4, 5], [ 6, 7]]] // (2,2,2) array

swapaxes(x, 0, 2) = [[[ 0, 4], [ 2, 6]], [[ 1, 5], [ 3, 7]]]

Defined in src/operator/swapaxis.cc:L69

Arguments

  • data::NDArray-or-SymbolicNode: Input array.
  • dim1::int, optional, default='0': the first axis to be swapped.
  • dim2::int, optional, default='0': the second axis to be swapped.
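Swapping two axes amounts to reading output element [..., i, ..., j, ...] from input position [..., j, ..., i, ...]. A pure-Python sketch over regular nested lists, with hypothetical helpers shape_of and swapaxes_sketch, reproduces both examples above.

```python
def shape_of(x):
    """Shape of a regular (non-ragged) nested list."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return shape


def swapaxes_sketch(x, dim1, dim2):
    """Return a new nested list whose axes dim1 and dim2 are interchanged."""
    shape = shape_of(x)
    out_shape = list(shape)
    out_shape[dim1], out_shape[dim2] = shape[dim2], shape[dim1]

    def build(level, prefix):
        if level == len(out_shape):
            # Map the output index to the input index by swapping the two positions.
            src = list(prefix)
            src[dim1], src[dim2] = prefix[dim2], prefix[dim1]
            value = x
            for i in src:
                value = value[i]
            return value
        return [build(level + 1, prefix + (i,)) for i in range(out_shape[level])]

    return build(0, ())


x = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
print(swapaxes_sketch(x, 0, 2))  # [[[0, 4], [2, 6]], [[1, 5], [3, 7]]]
```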

source

# MXNet.mx._npi_takeMethod.

_npi_take(a, indices, axis, mode)

_npi_take is an alias of take.

Takes elements from an input array along the given axis.

This function slices the input array along a particular axis with the provided indices.

Given a data tensor of rank r >= 1 and an indices tensor of rank q, this operator gathers entries along the axis dimension of data (by default the outer-most one, axis=0) indexed by indices, and concatenates them in an output tensor of rank q + (r - 1).

Examples:

x = [4. 5. 6.]

// Trivial case, take the second element along the first axis.

take(x, [1]) = [ 5. ]

// The other trivial case: axis=-1 is the only axis of a 1-D array. Index 3 is out of range and is clipped to the last element

take(x, [3], axis=-1, mode='clip') = [ 6. ]

x = [[ 1., 2.], [ 3., 4.], [ 5., 6.]]

// In this case we will get rows 0 and 1, then rows 1 and 2, along axis 0

take(x, [[0,1],[1,2]]) = [[[ 1., 2.],
                           [ 3., 4.]],

                          [[ 3., 4.],
                           [ 5., 6.]]]

// In this case we will get columns 0 and 1, then columns 1 and 0 (indices 3, -1 and -2 are wrapped around), along axis 1

take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap') = [[[ 1., 2.],
                                                     [ 2., 1.]],

                                                    [[ 3., 4.],
                                                     [ 4., 3.]],

                                                    [[ 5., 6.],
                                                     [ 6., 5.]]]

The storage type of `take` output depends upon the input storage type:

  • take(default, default) = default
  • take(csr, default, axis=0) = csr

Defined in src/operator/tensor/indexing_op.cc:L776

Arguments

  • a::NDArray-or-SymbolicNode: The input array.
  • indices::NDArray-or-SymbolicNode: The indices of the values to be extracted.
  • axis::int, optional, default='0': The axis of the input array to be taken. For an input tensor of rank r, it can be in the range [-r, r-1].
  • mode::{'clip', 'raise', 'wrap'}, optional, default='clip': Specifies how out-of-bound indices behave. Default is "clip". "clip" means clip to the range, so indices that are too large are replaced by the index that addresses the last element along the axis. "wrap" means to wrap around. "raise" means to raise an error when an index is out of range.
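Here is a pure-Python sketch of the gather rule on nested lists, including the `clip` and `wrap` out-of-bound modes. It shows how the index dimensions replace the selected axis, which is where the rank q + (r - 1) comes from. This is an illustrative reimplementation, not MXNet's kernel. The helper names are hypothetical, and the `raise` mode is left out.

```python
def ndim(x):
    """Number of dimensions of a rectangular nested list."""
    d = 0
    while isinstance(x, list):
        x = x[0]
        d += 1
    return d


def _resolve(i, n, mode):
    """Map index i into [0, n) following the 'clip' or 'wrap' mode."""
    if mode == "wrap":
        return i % n  # Python's % already maps negative indices into range
    if mode == "clip":
        return min(max(i, 0), n - 1)
    raise ValueError("this sketch only supports mode='clip' or 'wrap'")


def take(x, indices, axis=0, mode="clip"):
    """Gather slices of x along axis; indices may be a scalar or a nested list."""
    if axis < 0:
        axis += ndim(x)
    if axis > 0:
        # Recurse until the requested axis becomes the outermost one.
        return [take(sub, indices, axis - 1, mode) for sub in x]

    def pick(idx):
        # Mirror the structure of indices; each leaf selects one sub-array.
        if isinstance(idx, list):
            return [pick(j) for j in idx]
        return x[_resolve(idx, len(x), mode)]

    return pick(indices)
```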

source

# MXNet.mx._npi_tanMethod.

_npi_tan(x)

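Computes the element-wise tangent of the input array, e.g. $\tan([0, \pi/4]) = [0, 1]$. The tangent is undefined at $\pi/2$. Because $\pi/2$ is not exactly representable in floating point, the value computed there is finite but very large in magnitude rather than an infinity.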

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L336

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_tanhMethod.

_npi_tanh(x)

Returns the hyperbolic tangent of the input array, computed element-wise: $\tanh(x) = \sinh(x) / \cosh(x)$.

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L403

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npi_tensordotMethod.

_npi_tensordot(a, b, a_axes_summed, b_axes_summed)

Arguments

  • a::NDArray-or-SymbolicNode: First input
  • b::NDArray-or-SymbolicNode: Second input
  • a_axes_summed::Shape(tuple), required:
  • b_axes_summed::Shape(tuple), required:

source

# MXNet.mx._npi_tensordot_int_axesMethod.

_npi_tensordot_int_axes(a, b, axes)

Arguments

  • a::NDArray-or-SymbolicNode: First input
  • b::NDArray-or-SymbolicNode: Second input
  • axes::int, required:

source

# MXNet.mx._npi_tensorinvMethod.

_npi_tensorinv(a, ind)

Defined in src/operator/numpy/linalg/np_tensorinv.cc:L101

Arguments

  • a::NDArray-or-SymbolicNode: First input
  • ind::int, optional, default='2': Number of first indices that are involved in the inverse sum.

source

# MXNet.mx._npi_tensorsolveMethod.

_npi_tensorsolve(a, b, a_axes)

Arguments

  • a::NDArray-or-SymbolicNode: First input
  • b::NDArray-or-SymbolicNode: Second input
  • a_axes::Shape(tuple), optional, default=[]: Tuple of ints, optional. Axes in a to reorder to the right, before inversion.

source

# MXNet.mx._npi_tileMethod.

_npi_tile(data, reps)

_npi_tile is an alias of tile.

Repeats the whole array multiple times. If `reps` has length d and the input array has n dimensions, there are three cases:

  • n=d. Repeat the i-th dimension of the input `reps[i]` times:

        x = [[1, 2], [3, 4]]
        tile(x, reps=(2,3)) = [[ 1., 2., 1., 2., 1., 2.],
                               [ 3., 4., 3., 4., 3., 4.],
                               [ 1., 2., 1., 2., 1., 2.],
                               [ 3., 4., 3., 4., 3., 4.]]

  • n>d. `reps` is promoted to length n by prepending 1's to it. Thus for an input of shape (2,2), reps=(2,) is treated as (1,2):

        tile(x, reps=(2,)) = [[ 1., 2., 1., 2.],
                              [ 3., 4., 3., 4.]]

  • n<d. The input is promoted to be d-dimensional by prepending new axes. So a shape (2,2) array is promoted to (1,2,2) for 3-D replication:

        tile(x, reps=(2,2,3)) = [[[ 1., 2., 1., 2., 1., 2.],
                                  [ 3., 4., 3., 4., 3., 4.],
                                  [ 1., 2., 1., 2., 1., 2.],
                                  [ 3., 4., 3., 4., 3., 4.]],

                                 [[ 1., 2., 1., 2., 1., 2.],
                                  [ 3., 4., 3., 4., 3., 4.],
                                  [ 1., 2., 1., 2., 1., 2.],
                                  [ 3., 4., 3., 4., 3., 4.]]]

Defined in src/operator/tensor/matrix_op.cc:L795

Arguments

  • data::NDArray-or-SymbolicNode: Input data array
  • reps::Shape(tuple), required: The number of times for repeating the tensor a. Each dim size of reps must be a positive integer. If reps has length d, the result will have dimension of max(d, a.ndim); If a.ndim < d, a is promoted to be d-dimensional by prepending new axes. If a.ndim > d, reps is promoted to a.ndim by pre-pending 1's to it.
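The three promotion cases reduce to one rule: pad the shorter of `reps` and the input shape with leading 1's, then repeat each axis. The pure-Python sketch below applies that rule to nested lists. It is illustrative only, and the helper names are hypothetical.

```python
def ndim(x):
    """Number of dimensions of a rectangular nested list."""
    d = 0
    while isinstance(x, list):
        x = x[0]
        d += 1
    return d


def _tile(x, reps):
    """Repeat nested list x; reps has exactly one entry per dimension of x."""
    if not reps:
        return x  # reached a scalar element
    inner = [_tile(e, reps[1:]) for e in x]
    return inner * reps[0]  # repeat the whole (already tiled) axis


def tile(x, reps):
    reps = tuple(reps)
    n, d = ndim(x), len(reps)
    if d < n:
        reps = (1,) * (n - d) + reps  # case n > d: prepend 1's to reps
    for _ in range(d - n):
        x = [x]  # case n < d: prepend a new axis of length 1
    return _tile(x, reps)
```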

source

# MXNet.mx._npi_trilMethod.

_npi_tril(data, k)

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • k::int, optional, default='0': Diagonal in question. The default is 0. Use k>0 for diagonals above the main diagonal, and k<0 for diagonals below the main diagonal. If the input has shape (S0, S1), k must be between -S0 and S1.
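The entry above documents only `k`. The sketch below assumes `_npi_tril` follows `numpy.tril`: elements on and below the k-th diagonal (column - row <= k) are kept and the rest are zeroed. It covers 2-D nested lists only.

```python
def tril(m, k=0):
    """Zero every element above the k-th diagonal of the 2-D nested list m."""
    return [
        [v if j - i <= k else 0 for j, v in enumerate(row)]
        for i, row in enumerate(m)
    ]
```

With k=0 only the main diagonal and everything below it survive. A positive k keeps extra diagonals above it, and a negative k removes diagonals below it.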

source

# MXNet.mx._npi_true_divideMethod.

_npi_true_divide(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: Dividend array
  • rhs::NDArray-or-SymbolicNode: Divisor array

source

# MXNet.mx._npi_true_divide_scalarMethod.

_npi_true_divide_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicates whether the scalar input is of int type

source

# MXNet.mx._npi_truncMethod.

_npi_trunc(x)

Returns the truncated value of the input, element-wise. The truncated value of the scalar x is the nearest integer i that is closer to zero than x is (or x itself if x is already an integer). In short, the fractional part of the signed number x is discarded.

Example:

    trunc([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) = [-1., -1., -0., 0., 1., 1., 2.]

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L198

Arguments

  • x::NDArray-or-SymbolicNode: The input array.
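A minimal Python sketch of truncation for a flat list of floats follows. `math.copysign` reapplies the original sign, so -0.2 becomes -0.0 as in the example. This only illustrates the rule and is not the operator itself.

```python
import math


def trunc(values):
    """Drop the fractional part of each value, keeping the sign (so -0.2 -> -0.0)."""
    return [math.copysign(float(math.trunc(v)), v) for v in values]
```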

source

# MXNet.mx._npi_uniformMethod.

_npi_uniform(input1, input2, low, high, size, ctx, dtype)

Draws samples from a uniform distribution, following NumPy behavior.

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • input2::NDArray-or-SymbolicNode: Source input
  • low::float or None, required:
  • high::float or None, required:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'float16', 'float32', 'float64'},optional, default='float32': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx._npi_uniform_nMethod.

_npi_uniform_n(input1, input2, low, high, size, ctx, dtype)

Draws samples from a uniform distribution, following NumPy behavior.

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • input2::NDArray-or-SymbolicNode: Source input
  • low::float or None, required:
  • high::float or None, required:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'float16', 'float32', 'float64'},optional, default='float32': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx._npi_uniqueMethod.

_npi_unique(data, return_index, return_inverse, return_counts, axis)

Arguments

  • data::NDArray-or-SymbolicNode: The input array
  • return_index::boolean, optional, default=0: If true, also return the indices of the input array that give the unique values.
  • return_inverse::boolean, optional, default=0: If true, also return the indices of the unique array that can be used to reconstruct the input.
  • return_counts::boolean, optional, default=0: If true, also return the number of times each unique item appears in the input.
  • axis::int or None, optional, default='None': An integer that represents the axis to operate on.
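To show how the three optional outputs relate, here is a sketch for a flat list (axis=None). It assumes `_npi_unique` follows `numpy.unique`: the values come back sorted, `return_index` gives the first occurrence of each value, and `return_inverse` maps each input element to its position in the unique array.

```python
def unique(data, return_index=False, return_inverse=False, return_counts=False):
    """Sorted unique values of a flat list, plus the optional extra outputs."""
    values = sorted(set(data))
    position = {v: i for i, v in enumerate(values)}
    outputs = [values]
    if return_index:
        first = {}
        for i, v in enumerate(data):
            first.setdefault(v, i)  # remember the first occurrence only
        outputs.append([first[v] for v in values])
    if return_inverse:
        # values[inverse[i]] == data[i] for every i
        outputs.append([position[v] for v in data])
    if return_counts:
        counts = [0] * len(values)
        for v in data:
            counts[position[v]] += 1
        outputs.append(counts)
    return outputs[0] if len(outputs) == 1 else tuple(outputs)
```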

source

# MXNet.mx._npi_varMethod.

_npi_var(a, axis, dtype, ddof, keepdims)

Arguments

  • a::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: Axis or axes along which the variance is computed. The default, axis=None, computes the variance of all of the elements of the input array. If axis is negative it counts from the last to the first axis.
  • dtype::{None, 'float16', 'float32', 'float64', 'int32', 'int64', 'int8'}, optional, default='None': The type of the returned array and of the accumulator in which the elements are summed. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.
  • ddof::int, optional, default='0': Delta degrees of freedom. The divisor used in the calculation is N - ddof, where N is the number of elements.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimensions with size one.
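The effect of `ddof` is easiest to see on a flat list (axis=None). The sketch below assumes `numpy.var` semantics. With ddof=0 it returns the population variance, and with ddof=1 it returns the unbiased sample variance.

```python
def var(a, ddof=0):
    """Variance of a flat list with delta degrees of freedom ddof."""
    n = len(a)
    if n - ddof <= 0:
        raise ValueError("ddof must be smaller than the number of elements")
    mean = sum(a) / n
    return sum((x - mean) ** 2 for x in a) / (n - ddof)
```

For `[1, 2, 3, 4]` the squared deviations sum to 5, so ddof=0 gives 5/4 and ddof=1 gives 5/3.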

source

# MXNet.mx._npi_vstackMethod.

_npi_vstack(data, num_args)

Note: _npi_vstack takes a variable number of positional inputs. So instead of calling it as _npi_vstack([x, y, z], num_args=3), call it as _npi_vstack(x, y, z); num_args will be determined automatically.

Defined in src/operator/numpy/np_matrix_op.cc:L1007

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to vstack
  • num_args::int, required: Number of inputs to be vstacked.

source

# MXNet.mx._npi_weibullMethod.

_npi_weibull(input1, a, size, ctx)

Draws samples from a Weibull distribution, following NumPy behavior.

Arguments

  • input1::NDArray-or-SymbolicNode: Source input
  • a::float or None, optional, default=None:
  • size::Shape or None, optional, default=None: Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
  • ctx::string, optional, default='cpu': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.

source

# MXNet.mx._npi_whereMethod.

_npi_where(condition, x, y)

Arguments

  • condition::NDArray-or-SymbolicNode: condition array
  • x::NDArray-or-SymbolicNode: input x
  • y::NDArray-or-SymbolicNode: input y

source

# MXNet.mx._npi_where_lscalarMethod.

_npi_where_lscalar(condition, x, scalar)

Arguments

  • condition::NDArray-or-SymbolicNode: condition array
  • x::NDArray-or-SymbolicNode: input x
  • scalar::double, optional, default=0: The scalar value of x/y.

source

# MXNet.mx._npi_where_rscalarMethod.

_npi_where_rscalar(condition, y, scalar)

Arguments

  • condition::NDArray-or-SymbolicNode: condition array
  • y::NDArray-or-SymbolicNode: input y
  • scalar::double, optional, default=0: The scalar value of x/y.

source

# MXNet.mx._npi_where_scalar2Method.

_npi_where_scalar2(condition, x, y)

Arguments

  • condition::NDArray-or-SymbolicNode: condition array
  • x::double, optional, default=0: The scalar value of x.
  • y::double, optional, default=0: The scalar value of y.

source

# MXNet.mx._npi_zerosMethod.

_npi_zeros(shape, ctx, dtype)

Arguments

  • shape::Shape(tuple), optional, default=[]: The shape of the output
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'bfloat16', 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'},optional, default='float32': Target data type.

source

# MXNet.mx._npx__image_adjust_lightingMethod.

_npx__image_adjust_lighting(data, alpha)

_npx__image_adjust_lighting is an alias of _image_adjust_lighting.

Adjusts the lighting level of the input, following the AlexNet style.

Defined in src/operator/image/image_random.cc:L254

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • alpha::tuple of <float>, required: The lighting alphas for the R, G, B channels.

source

# MXNet.mx._npx__image_cropMethod.

_npx__image_crop(data, x, y, width, height)

_npx__image_crop is an alias of _image_crop.

Crop an image NDArray of shape (H x W x C) or (N x H x W x C) to the given size.

Example:

    image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8)
    mx.nd.image.crop(image, 1, 1, 2, 2)
        [[[144  34   4]
          [ 82 157  38]]

         [[156 111 230]
          [177  25  15]]]
        <NDArray 2x2x3 @cpu(0)>

    image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8)
    mx.nd.image.crop(image, 1, 1, 2, 2)
        [[[[ 35 198  50]
           [242  94 168]]

          [[223 119 129]
           [249  14 154]]]


         [[[137 215 106]
           [ 79 174 133]]

          [[116 142 109]
           [ 35 239  50]]]]
        <NDArray 2x2x2x3 @cpu(0)>

Defined in src/operator/image/crop.cc:L65

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • x::int, required: Left boundary of the cropping area.
  • y::int, required: Top boundary of the cropping area.
  • width::int, required: Width of the cropping area.
  • height::int, required: Height of the cropping area.
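The argument list treats x as the left boundary (a column offset) and y as the top boundary (a row offset). The pure-Python sketch below applies that reading to H x W x C nested lists. The function names are hypothetical. Python slicing silently shortens an out-of-range crop instead of raising, so unlike a real operator the sketch performs no bounds validation.

```python
def crop(image, x, y, width, height):
    """Crop an H x W x C nested list: rows y..y+height-1, columns x..x+width-1."""
    return [row[x:x + width] for row in image[y:y + height]]


def crop_batch(images, x, y, width, height):
    """Apply crop to every image of an N x H x W x C batch."""
    return [crop(img, x, y, width, height) for img in images]
```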

source

# MXNet.mx._npx__image_flip_left_rightMethod.

_npx__image_flip_left_right(data)

_npx__image_flip_left_right is an alias of _image_flip_left_right.

Defined in src/operator/image/image_random.cc:L195

Arguments

  • data::NDArray-or-SymbolicNode: The input.

source

# MXNet.mx._npx__image_flip_top_bottomMethod.

_npx__image_flip_top_bottom(data)

_npx__image_flip_top_bottom is an alias of _image_flip_top_bottom.

Defined in src/operator/image/image_random.cc:L205

Arguments

  • data::NDArray-or-SymbolicNode: The input.

source

# MXNet.mx._npx__image_normalizeMethod.

_npx__image_normalize(data, mean, std)

_npx__image_normalize is an alias of _image_normalize.

Normalizes a tensor of shape (C x H x W) or (N x C x H x W) with mean and standard deviation.

Given mean (m_1, ..., m_n) and std (s_1, ..., s_n) for n channels, this transform normalizes each channel of the input tensor with:

$$\mathrm{output}[i] = \frac{\mathrm{input}[i] - m_i}{s_i}$$

If mean or std is a scalar, the same value is applied to all channels.

The default value for mean is 0.0 and for standard deviation is 1.0.

Example:

    image = mx.nd.random.uniform(0, 1, (3, 4, 2))
    normalize(image, mean=(0, 1, 2), std=(3, 2, 1))
        [[[ 0.18293785  0.19761486]
          [ 0.23839645  0.28142193]
          [ 0.20092112  0.28598186]
          [ 0.18162774  0.28241724]]
         [[-0.2881726  -0.18821815]
          [-0.17705294 -0.30780914]
          [-0.2812064  -0.3512327 ]
          [-0.05411351 -0.4716435 ]]
         [[-1.0363373  -1.7273437 ]
          [-1.6165586  -1.5223348 ]
          [-1.208275   -1.1878313 ]
          [-1.4711051  -1.5200229 ]]]
        <NDArray 3x4x2 @cpu(0)>

    image = mx.nd.random.uniform(0, 1, (2, 3, 4, 2))
    normalize(image, mean=(0, 1, 2), std=(3, 2, 1))
        [[[[ 0.18934818  0.13092826]
           [ 0.3085322   0.27869293]
           [ 0.02367868  0.11246539]
           [ 0.0290431   0.2160573 ]]
          [[-0.4898908  -0.31587923]
           [-0.08369008 -0.02142242]
           [-0.11092162 -0.42982462]
           [-0.06499392 -0.06495637]]
          [[-1.0213816  -1.526392  ]
           [-1.2008414  -1.1990893 ]
           [-1.5385206  -1.4795225 ]
           [-1.2194707  -1.3211205 ]]]
         [[[ 0.03942481  0.24021089]
           [ 0.21330701  0.1940066 ]
           [ 0.04778443  0.17912441]
           [ 0.31488964  0.25287187]]
          [[-0.23907584 -0.4470462 ]
           [-0.29266903 -0.2631998 ]
           [-0.3677222  -0.40683383]
           [-0.11288315 -0.13154092]]
          [[-1.5438497  -1.7834496 ]
           [-1.431566   -1.8647819 ]
           [-1.9812102  -1.675859  ]
           [-1.3823645  -1.8503251 ]]]]
        <NDArray 2x3x4x2 @cpu(0)>

Defined in src/operator/image/image_random.cc:L167

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • mean::tuple of <float>, optional, default=[0,0,0,0]: Sequence of means for each channel. Default value is 0.
  • std::tuple of <float>, optional, default=[1,1,1,1]: Sequence of standard deviations for each channel. Default value is 1.
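The per-channel formula, including broadcasting a scalar mean or std to every channel, can be sketched in pure Python for a single C x H x W nested list. This illustrates the formula only. The function name is hypothetical, and the batched (N x C x H x W) case would apply it to each image in turn.

```python
def normalize(tensor, mean=0.0, std=1.0):
    """Normalize each channel of a C x H x W nested list: (v - mean[c]) / std[c]."""
    channels = len(tensor)
    # A scalar mean/std is broadcast to every channel.
    means = [mean] * channels if isinstance(mean, (int, float)) else list(mean)
    stds = [std] * channels if isinstance(std, (int, float)) else list(std)
    return [
        [[(v - means[c]) / stds[c] for v in row] for row in plane]
        for c, plane in enumerate(tensor)
    ]
```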

source

# MXNet.mx._npx__image_random_brightnessMethod.

_npx__image_random_brightness(data, min_factor, max_factor)

_npx__image_random_brightness is an alias of _image_random_brightness.

Defined in src/operator/image/image_random.cc:L215

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • min_factor::float, required: Minimum factor.
  • max_factor::float, required: Maximum factor.

source

# MXNet.mx._npx__image_random_color_jitterMethod.

_npx__image_random_color_jitter(data, brightness, contrast, saturation, hue)

_npx__image_random_color_jitter is an alias of _image_random_color_jitter.

Defined in src/operator/image/image_random.cc:L246

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • brightness::float, required: How much to jitter brightness.
  • contrast::float, required: How much to jitter contrast.
  • saturation::float, required: How much to jitter saturation.
  • hue::float, required: How much to jitter hue.

source

# MXNet.mx._npx__image_random_contrastMethod.

_npx__image_random_contrast(data, min_factor, max_factor)

_npx__image_random_contrast is an alias of _image_random_contrast.

Defined in src/operator/image/image_random.cc:L222

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • min_factor::float, required: Minimum factor.
  • max_factor::float, required: Maximum factor.

source

# MXNet.mx._npx__image_random_flip_left_rightMethod.

_npx__image_random_flip_left_right(data)

_npx__image_random_flip_left_right is an alias of _image_random_flip_left_right.

Defined in src/operator/image/image_random.cc:L200

Arguments

  • data::NDArray-or-SymbolicNode: The input.

source

# MXNet.mx._npx__image_random_flip_top_bottomMethod.

_npx__image_random_flip_top_bottom(data)

_npx__image_random_flip_top_bottom is an alias of _image_random_flip_top_bottom.

Defined in src/operator/image/image_random.cc:L210

Arguments

  • data::NDArray-or-SymbolicNode: The input.

source

# MXNet.mx._npx__image_random_hueMethod.

_npx__image_random_hue(data, min_factor, max_factor)

_npx__image_random_hue is an alias of _image_random_hue.

Defined in src/operator/image/image_random.cc:L238

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • min_factor::float, required: Minimum factor.
  • max_factor::float, required: Maximum factor.

source

# MXNet.mx._npx__image_random_lightingMethod.

_npx__image_random_lighting(data, alpha_std)

_npx__image_random_lighting is an alias of _image_random_lighting.

Randomly adds PCA noise, following the AlexNet style.

Defined in src/operator/image/image_random.cc:L262

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • alpha_std::float, optional, default=0.0500000007: Level of the lighting noise.

source

# MXNet.mx._npx__image_random_saturationMethod.

_npx__image_random_saturation(data, min_factor, max_factor)

_npx__image_random_saturation is an alias of _image_random_saturation.

Defined in src/operator/image/image_random.cc:L230

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • min_factor::float, required: Minimum factor.
  • max_factor::float, required: Maximum factor.

source

# MXNet.mx._npx__image_resizeMethod.

_npx__image_resize(data, size, keep_ratio, interp)

_npx__image_resize is an alias of _image_resize.

Resize an image NDArray of shape (H x W x C) or (N x H x W x C) to the given size.

Example:

    image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8)
    mx.nd.image.resize(image, (3, 3))
        [[[124 111 197]
          [158  80 155]
          [193  50 112]]

         [[110 100 113]
          [134 165 148]
          [157 231 182]]

         [[202 176 134]
          [174 191 149]
          [147 207 164]]]
        <NDArray 3x3x3 @cpu(0)>

    image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8)
    mx.nd.image.resize(image, (2, 2))
        [[[[ 59 133  80]
           [187 114 153]]

          [[ 38 142  39]
           [207 131 124]]]


         [[[117 125 136]
           [191 166 150]]

          [[129  63 113]
           [182 109  48]]]]
        <NDArray 2x2x2x3 @cpu(0)>

Defined in src/operator/image/resize.cc:L70

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • size::Shape(tuple), optional, default=[]: Size of new image. Could be (width, height) or (size)
  • keep_ratio::boolean, optional, default=0: Whether to resize the short edge or both edges to size, if size is given as an integer.
  • interp::int, optional, default='1': Interpolation method for resizing. By default uses bilinear interpolation. Options are INTER_NEAREST - a nearest-neighbor interpolation; INTER_LINEAR - a bilinear interpolation; INTER_AREA - resampling using pixel area relation; INTER_CUBIC - a bicubic interpolation over 4x4 pixel neighborhood; INTER_LANCZOS4 - a Lanczos interpolation over 8x8 pixel neighborhood. Note that the GPU version only supports bilinear interpolation (1).

source

# MXNet.mx._npx__image_to_tensorMethod.

_npx__image_to_tensor(data)

_npx__image_to_tensor is an alias of _image_to_tensor.

Converts an image NDArray of shape (H x W x C) or (N x H x W x C) with values in the range [0, 255] to a tensor NDArray of shape (C x H x W) or (N x C x H x W) with values in the range [0, 1].

Example:

    image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8)
    to_tensor(image)
        [[[ 0.85490197  0.72156864]
          [ 0.09019608  0.74117649]
          [ 0.61960787  0.92941177]
          [ 0.96470588  0.1882353 ]]
         [[ 0.6156863   0.73725492]
          [ 0.46666667  0.98039216]
          [ 0.44705883  0.45490196]
          [ 0.01960784  0.8509804 ]]
         [[ 0.39607844  0.03137255]
          [ 0.72156864  0.52941179]
          [ 0.16470589  0.7647059 ]
          [ 0.05490196  0.70588237]]]

    image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8)
    to_tensor(image)
        [[[[0.11764706 0.5803922 ]
           [0.9411765  0.10588235]
           [0.2627451  0.73333335]
           [0.5647059  0.32156864]]
          [[0.7176471  0.14117648]
           [0.75686276 0.4117647 ]
           [0.18431373 0.45490196]
           [0.13333334 0.6156863 ]]
          [[0.6392157  0.5372549 ]
           [0.52156866 0.47058824]
           [0.77254903 0.21568628]
           [0.01568628 0.14901961]]]
         [[[0.6117647  0.38431373]
           [0.6784314  0.6117647 ]
           [0.69411767 0.96862745]
           [0.67058825 0.35686275]]
          [[0.21960784 0.9411765 ]
           [0.44705883 0.43529412]
           [0.09803922 0.6666667 ]
           [0.16862746 0.1254902 ]]
          [[0.6156863  0.9019608 ]
           [0.35686275 0.9019608 ]
           [0.05882353 0.6509804 ]
           [0.20784314 0.7490196 ]]]]
        <NDArray 2x3x4x2 @cpu(0)>

Defined in src/operator/image/image_random.cc:L92

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
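The conversion consists of two steps: move the channel axis from last to first, and divide by 255. A pure-Python sketch for a single H x W x C nested list follows. It illustrates the layout change and scaling only, and the function name is hypothetical.

```python
def to_tensor(image):
    """Convert an H x W x C nested list of 0..255 values to C x H x W floats in [0, 1]."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [
        [[image[r][c][ch] / 255.0 for c in range(width)] for r in range(height)]
        for ch in range(channels)
    ]
```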

source

# MXNet.mx._npx_activationMethod.

_npx_activation(data, act_type)

_npx_activation is an alias of Activation.

Applies an activation function element-wise to the input.

The following activation functions are supported:

  • relu: Rectified Linear Unit, $y = \max(x, 0)$
  • sigmoid: $y = \frac{1}{1 + \exp(-x)}$
  • tanh: Hyperbolic tangent, $y = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}$
  • softrelu: Soft ReLU, or SoftPlus, $y = \log(1 + \exp(x))$
  • softsign: $y = \frac{x}{1 + |x|}$

Defined in src/operator/nn/activation.cc:L164

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
  • act_type::{'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required: Activation function to be applied.
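The five formulas translate directly into scalar Python functions, applied here element-wise to a flat list. This is a reference sketch, not MXNet's kernel. The `sigmoid` and `softrelu` variants are rearranged into numerically stable forms that are algebraically equal to the formulas above.

```python
import math


def activation(values, act_type):
    """Apply one of the documented activation functions element-wise to a flat list."""
    def relu(x):
        return max(x, 0.0)

    def sigmoid(x):
        # Split on sign so math.exp never receives a large positive argument.
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        e = math.exp(x)
        return e / (1.0 + e)

    def softrelu(x):
        # log(1 + exp(x)) rewritten as max(x, 0) + log1p(exp(-|x|)) to avoid overflow.
        return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

    def softsign(x):
        return x / (1.0 + abs(x))

    funcs = {"relu": relu, "sigmoid": sigmoid, "tanh": math.tanh,
             "softrelu": softrelu, "softsign": softsign}
    if act_type not in funcs:
        raise ValueError(f"unsupported act_type: {act_type}")
    return [funcs[act_type](x) for x in values]
```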

source

# MXNet.mx._npx_arange_likeMethod.

_npx_arange_like(data, start, step, repeat, ctx, axis)

_npx_arange_like is an alias of _contrib_arange_like.

Return an array with evenly spaced values. If axis is not given, the output will have the same shape as the input array. Otherwise, the output will be a 1-D array with size of the specified axis in input shape.

Examples:

x = [[0.14883883 0.7772398  0.94865847 0.7225052 ]
     [0.23729339 0.6112595  0.66538996 0.5132841 ]
     [0.30822644 0.9912457  0.15502319 0.7043658 ]]

out = mx.nd.contrib.arange_like(x, start=0)

[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]]
 <NDArray 3x4 @cpu(0)>

out = mx.nd.contrib.arange_like(x, start=0, axis=-1)

[0. 1. 2. 3.]
<NDArray 4 @cpu(0)>

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • start::double, optional, default=0: Start of interval. The interval includes this value. The default start value is 0.
  • step::double, optional, default=1: Spacing between values.
  • repeat::int, optional, default='1': The number of times each element is repeated. For example, with repeat=3 the element a is repeated three times: a, a, a.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • axis::int or None, optional, default='None': Arange elements according to the size of a certain axis of the input array. Negative numbers are interpreted as counting from the last axis. If not provided, elements are arranged according to the input shape.
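Only the shape of the input matters here, not its values. The sketch below therefore takes a shape tuple instead of an array, which is an assumption of this illustration and not the operator's signature. The `repeat` behavior is read from the parameter description: each value is emitted `repeat` times before advancing by `step`.

```python
def arange_like(shape, start=0.0, step=1.0, repeat=1, axis=None):
    """Evenly spaced values shaped like `shape`, or 1-D along `axis` if given."""
    def value(i):
        # Each value is emitted `repeat` times before advancing by `step`.
        return start + step * (i // repeat)

    if axis is not None:
        return [value(i) for i in range(shape[axis])]

    total = 1
    for dim in shape:
        total *= dim
    flat = [value(i) for i in range(total)]

    def reshape(values, dims):
        # Row-major reshape of a flat list into nested lists.
        if not dims:
            return values[0]
        size = len(values) // dims[0]
        return [reshape(values[i * size:(i + 1) * size], dims[1:]) for i in range(dims[0])]

    return reshape(flat, tuple(shape))
```

With the 3x4 input from the example, `arange_like((3, 4))` yields the rows 0..3, 4..7 and 8..11, and `axis=-1` yields the four values 0..3.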

source

# MXNet.mx._npx_batch_dotMethod.

_npx_batch_dot(lhs, rhs, transpose_a, transpose_b, forward_stype)

_npx_batch_dot is an alias of batch_dot.

Batchwise dot product.

`batch_dot` is used to compute the dot product of `x` and `y` when `x` and `y` are data in batch, namely N-D (N >= 3) arrays of shape (B_0, ..., B_i, :, :).

For example, given `x` with shape (B_0, ..., B_i, N, M) and `y` with shape (B_0, ..., B_i, M, K), the result array will have shape (B_0, ..., B_i, N, K), which is computed by:

    batch_dot(x, y)[b_0, ..., b_i, :, :] = dot(x[b_0, ..., b_i, :, :], y[b_0, ..., b_i, :, :])

Defined in src/operator/tensor/dot.cc:L127

Arguments

  • lhs::NDArray-or-SymbolicNode: The first input
  • rhs::NDArray-or-SymbolicNode: The second input
  • transpose_a::boolean, optional, default=0: If true then transpose the first input before dot.
  • transpose_b::boolean, optional, default=0: If true then transpose the second input before dot.
  • forward_stype::{None, 'csr', 'default', 'row_sparse'}, optional, default='None': The desired storage type of the forward output given by the user. If the combination of input storage types and this hint does not match any implemented ones, the dot operator will perform a fallback operation and still produce an output of the desired storage type.
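In the 3-D case (a single batch axis), the formula above is one ordinary matrix product per batch entry. `transpose_a`/`transpose_b` transpose each 2-D slice first. The pure-Python sketch below illustrates this with nested lists. Dense storage is assumed, so `forward_stype` is not modeled, and the helper names are hypothetical.

```python
def transpose(m):
    """Transpose a 2-D nested list."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Plain 2-D matrix product of nested lists: (N, M) x (M, K) -> (N, K)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def batch_dot(x, y, transpose_a=False, transpose_b=False):
    """Batched product for 3-D inputs (B, N, M) and (B, M, K) -> (B, N, K)."""
    if len(x) != len(y):
        raise ValueError("batch sizes differ")
    out = []
    for xb, yb in zip(x, y):
        if transpose_a:
            xb = transpose(xb)
        if transpose_b:
            yb = transpose(yb)
        out.append(matmul(xb, yb))
    return out
```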

source

# MXNet.mx._npx_batch_flattenMethod.

_npx_batch_flatten(data)

_npx_batch_flatten is an alias of Flatten.

Flattens the input array into a 2-D array by collapsing the higher dimensions.

Note: `Flatten` is deprecated. Use `flatten` instead.

For an input array with shape (d1, d2, ..., dk), the flatten operation reshapes the input array into an output array of shape (d1, d2*...*dk).

Note that the behavior of this function is different from numpy.ndarray.flatten, which behaves similarly to mxnet.ndarray.reshape((-1,)).

Example:

    x = [[[1,2,3], [4,5,6], [7,8,9]],
         [[1,2,3], [4,5,6], [7,8,9]]]

    flatten(x) = [[ 1., 2., 3., 4., 5., 6., 7., 8., 9.],
                  [ 1., 2., 3., 4., 5., 6., 7., 8., 9.]]

Defined in src/operator/tensor/matrix_op.cc:L249

Arguments

  • data::NDArray-or-SymbolicNode: Input array.
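The (d1, d2, ..., dk) -> (d1, d2*...*dk) rule keeps the first axis and flattens everything beneath it in row-major order. The nested-list sketch below illustrates this and is not the operator implementation.

```python
def flatten(x):
    """Collapse all dimensions after the first: (d1, d2, ..., dk) -> (d1, d2*...*dk)."""
    def leaves(v):
        # Yield scalars in row-major order.
        if isinstance(v, list):
            for e in v:
                yield from leaves(e)
        else:
            yield v

    return [list(leaves(item)) for item in x]
```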

source

# MXNet.mx._npx_batch_normMethod.

_npx_batch_norm(data, gamma, beta, moving_mean, moving_var, eps, momentum, fix_gamma, use_global_stats, output_mean_var, axis, cudnn_off, min_calib_range, max_calib_range)

_npx_batch_norm is an alias of BatchNorm.

Batch normalization.

Normalizes a data batch by mean and variance, and applies a scale `gamma` as well as an offset `beta`.

Assume the input has more than one dimension and we normalize along axis 1. We first compute the mean and variance along this axis:

$$\mathrm{data\_mean}[i] = \mathrm{mean}(\mathrm{data}[:,i,:,\ldots]), \qquad \mathrm{data\_var}[i] = \mathrm{var}(\mathrm{data}[:,i,:,\ldots])$$

Then compute the normalized output, which has the same shape as the input, as follows:

$$\mathrm{out}[:,i,:,\ldots] = \frac{\mathrm{data}[:,i,:,\ldots] - \mathrm{data\_mean}[i]}{\sqrt{\mathrm{data\_var}[i] + \epsilon}} \cdot \mathrm{gamma}[i] + \mathrm{beta}[i]$$

Both mean and var return a scalar by treating the input as a vector.

Assume the input has size k on axis 1; then both `gamma` and `beta` have shape (k,). If `output_mean_var` is set to true, the operator outputs both `data_mean` and the inverse of `data_var`, which are needed for the backward pass. Note that the gradients of these two outputs are blocked.

Besides the inputs and the outputs, this operator accepts two auxiliary states, `moving_mean` and `moving_var`, which are k-length vectors. They are global statistics for the whole dataset, which are updated by:

    moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
    moving_var = moving_var * momentum + data_var * (1 - momentum)

If `use_global_stats` is set to true, then `moving_mean` and `moving_var` are used instead of `data_mean` and `data_var` to compute the output. It is often used during inference.

The parameter $axis$ specifies which axis of the input shape denotes the 'channel' (separately normalized groups). The default is 1. Specifying -1 sets the channel axis to be the last item in the input shape.

Both `gamma` and `beta` are learnable parameters. But if `fix_gamma` is true, then `gamma` is set to 1 and its gradient to 0.

Note: When `fix_gamma` is set to True, no sparse support is provided. If `fix_gamma` is set to False, the sparse tensors will fall back.

Defined in src/operator/nn/batch_norm.cc:L608

Arguments

  • data::NDArray-or-SymbolicNode: Input data to batch normalization
  • gamma::NDArray-or-SymbolicNode: gamma array
  • beta::NDArray-or-SymbolicNode: beta array
  • moving_mean::NDArray-or-SymbolicNode: running mean of input
  • moving_var::NDArray-or-SymbolicNode: running variance of input
  • eps::double, optional, default=0.0010000000474974513: Epsilon to prevent division by 0. Must be no less than CUDNN_BN_MIN_EPSILON defined in cudnn.h when using cudnn (usually 1e-5)
  • momentum::float, optional, default=0.899999976: Momentum for moving average
  • fix_gamma::boolean, optional, default=1: Fix gamma while training
  • use_global_stats::boolean, optional, default=0: Whether to use global moving statistics instead of local batch-norm. This will force change batch-norm into a scale shift operator.
  • output_mean_var::boolean, optional, default=0: Output the mean and inverse std
  • axis::int, optional, default='1': Specifies which axis of the input shape denotes the channel
  • cudnn_off::boolean, optional, default=0: Do not select CUDNN operator, if available
  • min_calib_range::float or None, optional, default=None: The minimum scalar value in the form of float32 obtained through calibration. If present, it will be used by the quantized batch norm op to calculate the primitive scale. Note: this calib_range is used to calibrate the batch norm output.
  • max_calib_range::float or None, optional, default=None: The maximum scalar value in the form of float32 obtained through calibration. If present, it will be used by the quantized batch norm op to calculate the primitive scale. Note: this calib_range is used to calibrate the batch norm output.
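This training-mode sketch covers a 2-D (N, C) input with the channel on axis 1. It computes the per-channel batch statistics, normalizes with `gamma`/`beta`, and applies the `moving_mean`/`moving_var` update rule described above. It rests on several assumptions: the batch variance is the biased (divide-by-N) estimate, and `fix_gamma`, `use_global_stats` and `output_mean_var` are not modeled. The function name is hypothetical.

```python
import math


def batch_norm_train(data, gamma, beta, moving_mean, moving_var,
                     eps=1e-3, momentum=0.9):
    """Training-mode batch norm for an (N, C) nested list, channel axis 1.

    Returns (out, new_moving_mean, new_moving_var).
    """
    n, c = len(data), len(data[0])
    mean = [sum(row[i] for row in data) / n for i in range(c)]
    # Assumption: the biased (divide-by-N) batch variance is used.
    var = [sum((row[i] - mean[i]) ** 2 for row in data) / n for i in range(c)]
    out = [
        [(row[i] - mean[i]) / math.sqrt(var[i] + eps) * gamma[i] + beta[i]
         for i in range(c)]
        for row in data
    ]
    new_mean = [m * momentum + mu * (1 - momentum) for m, mu in zip(moving_mean, mean)]
    new_var = [v * momentum + s * (1 - momentum) for v, s in zip(moving_var, var)]
    return out, new_mean, new_var
```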

source

# MXNet.mx._npx_castMethod.

_npx_cast(data, dtype)

_npx_cast is an alias of Cast.

Casts all elements of the input to a new type.

Note: `Cast` is deprecated. Use `cast` instead.

Example:

    cast([0.9, 1.3], dtype='int32') = [0, 1]
    cast([1e20, 11.1], dtype='float16') = [inf, 11.09375]
    cast([300, 11.1, 10.9, -1, -3], dtype='uint8') = [44, 11, 10, 255, 253]

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L664

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • dtype::{'bfloat16', 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, required: Output data type.

source

# MXNet.mx._npx_constraint_checkMethod.

_npx_constraint_check(input, msg)

This operator checks whether all the elements in a boolean tensor are true. If not, a ValueError exception is raised in the backend with the given error message. To evaluate this operator, multiply the original tensor by the return value of this operator so that the operator becomes part of the computation graph; otherwise the check will not work in symbolic mode.

Example:

loc = np.zeros((2,2))
scale = np.array(#somevalue)
constraint = (scale > 0)
np.random.normal(loc, scale * npx.constraint_check(constraint, 'Scale should be larger than zero'))

If all elements in the scale tensor are greater than zero, npx.constraint_check returns np.array(True), which does not change the value of scale when multiplied by it. If some elements in the scale tensor violate the constraint, i.e. the boolean tensor constraint contains False, a ValueError exception with the message 'Scale should be larger than zero' is raised.
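The multiply-by-the-result pattern can be sketched without MXNet. `constraint_check` below is a hypothetical pure-Python stand-in that only mirrors the described behaviour (return true or raise `ValueError`); it is not the MXNet implementation.

```python
def constraint_check(flags, msg="Constraint violated."):
    # Hypothetical stand-in: return True when every flag is true,
    # otherwise raise ValueError with the given message.
    if not all(flags):
        raise ValueError(msg)
    return True

scale = [0.5, 2.0]
ok = constraint_check([s > 0 for s in scale], "Scale should be larger than zero")
# Multiplying by True (which equals 1) keeps the check in the data path
# without changing the values.
scaled = [s * ok for s in scale]
print(scaled)  # [0.5, 2.0]
```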

Defined in src/operator/numpy/np_constraint_check.cc:L79

Arguments

  • input::NDArray-or-SymbolicNode: Input boolean array
  • msg::string, optional, default='Constraint violated.': Error message raised when constraint violated

source

# MXNet.mx._npx_convolutionMethod.

_npx_convolution(data, weight, bias, kernel, stride, dilate, pad, num_filter, num_group, workspace, no_bias, cudnn_tune, cudnn_off, layout)

_npx_convolution is an alias of Convolution.

Compute N-D convolution on (N+2)-D input.

In the 2-D convolution, given input data with shape (batch_size, channel, height, width), the output is computed by

$$\mathrm{out}[n,i,:,:] = \mathrm{bias}[i] + \sum_{j=0}^{\mathrm{channel}-1} \mathrm{data}[n,j,:,:] \star \mathrm{weight}[i,j,:,:]$$

where $\star$ is the 2-D cross-correlation operator.

For general 2-D convolution, the shapes are

  • data: (batch_size, channel, height, width)
  • weight: (num_filter, channel, kernel[0], kernel[1])
  • bias: (num_filter,)
  • out: (batch_size, num_filter, out_height, out_width).

Define::

f(x,k,p,s,d) = floor((x+2p-d(k-1)-1)/s)+1

then we have::

out_height = f(height, kernel[0], pad[0], stride[0], dilate[0])
out_width = f(width, kernel[1], pad[1], stride[1], dilate[1])
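The shape rule f(x, k, p, s, d) can be checked with a short sketch. `conv_out_size` is a hypothetical helper name that encodes only the formula above.

```python
def conv_out_size(x, k, p=0, s=1, d=1):
    """Output length along one spatial axis: floor((x + 2p - d(k-1) - 1) / s) + 1."""
    # Python's // is floor division, which matches floor() for the
    # non-negative numerators produced by valid configurations.
    return (x + 2 * p - d * (k - 1) - 1) // s + 1

height, width = 32, 28
kernel, pad, stride, dilate = (3, 5), (1, 0), (2, 1), (1, 1)
out_height = conv_out_size(height, kernel[0], pad[0], stride[0], dilate[0])
out_width = conv_out_size(width, kernel[1], pad[1], stride[1], dilate[1])
print(out_height, out_width)  # 16 24
```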

If `no_bias` is set to true, the `bias` term is ignored.

The default data `layout` is NCHW, namely (batch_size, channel, height, width). Other layouts, such as NHWC, can be chosen.

If `num_group` is larger than 1, denoted by g, the input `data` is split evenly into g parts along the channel axis, and `weight` is split evenly along its first dimension. The convolution is then computed on the i-th part of the data with the i-th weight part, and the output is obtained by concatenating all g results.

1-D convolution has no height dimension, only width. The shapes are

  • data: (batch_size, channel, width)
  • weight: (num_filter, channel, kernel[0])
  • bias: (num_filter,)
  • out: (batch_size, num_filter, out_width).

3-D convolution adds an additional depth dimension besides height and width. The shapes are

  • data: (batch_size, channel, depth, height, width)
  • weight: (num_filter, channel, kernel[0], kernel[1], kernel[2])
  • bias: (num_filter,)
  • out: (batch_size, num_filter, out_depth, out_height, out_width).

Both `weight` and `bias` are learnable parameters.

There are other options to tune the performance.

  • cudnn_tune: enabling this option leads to higher startup time but may give faster speed. Options are

    • off: no tuning
    • limited_workspace: run tests and pick the fastest algorithm that doesn't exceed the workspace limit.
    • fastest: pick the fastest algorithm and ignore the workspace limit.
    • None (default): the behavior is determined by the environment variable `MXNET_CUDNN_AUTOTUNE_DEFAULT`: 0 for off, 1 for limited workspace (default), 2 for fastest.
  • workspace: a larger number leads to more (GPU) memory usage but may improve performance.

Defined in src/operator/nn/convolution.cc:L475

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the ConvolutionOp.
  • weight::NDArray-or-SymbolicNode: Weight matrix.
  • bias::NDArray-or-SymbolicNode: Bias parameter.
  • kernel::Shape(tuple), required: Convolution kernel size: (w,), (h, w) or (d, h, w)
  • stride::Shape(tuple), optional, default=[]: Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • dilate::Shape(tuple), optional, default=[]: Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • pad::Shape(tuple), optional, default=[]: Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding.
  • num_filter::int (non-negative), required: Number of convolution filters (output channels).
  • num_group::int (non-negative), optional, default=1: Number of group partitions.
  • workspace::long (non-negative), optional, default=1024: Maximum temporary workspace allowed (MB) in convolution. This parameter has two usages. When CUDNN is not used, it determines the effective batch size of the convolution kernel. When CUDNN is used, it controls the maximum temporary storage used for tuning the best CUDNN kernel when the limited_workspace strategy is used.
  • no_bias::boolean, optional, default=0: Whether to disable bias parameter.
  • cudnn_tune::{None, 'fastest', 'limited_workspace', 'off'},optional, default='None': Whether to pick convolution algo by running performance test.
  • cudnn_off::boolean, optional, default=0: Turn off cudnn for this layer.
  • layout::{None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'},optional, default='None': Set layout for input, output and weight. Empty for default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d. NHWC and NDHWC are only supported on GPU.

source

# MXNet.mx._npx_deconvolutionMethod.

_npx_deconvolution(data, weight, bias, kernel, stride, dilate, pad, adj, target_shape, num_filter, num_group, workspace, no_bias, cudnn_tune, cudnn_off, layout)

_npx_deconvolution is an alias of Deconvolution.

Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of the input tensor. This operation can be seen as the gradient of Convolution operation with respect to its input. Convolution usually reduces the size of the input. Transposed convolution works the other way, going from a smaller input to a larger output while preserving the connectivity pattern.

Arguments

  • data::NDArray-or-SymbolicNode: Input tensor to the deconvolution operation.
  • weight::NDArray-or-SymbolicNode: Weights representing the kernel.
  • bias::NDArray-or-SymbolicNode: Bias added to the result after the deconvolution operation.
  • kernel::Shape(tuple), required: Deconvolution kernel size: (w,), (h, w) or (d, h, w). This is the same as the kernel size used for the corresponding convolution.
  • stride::Shape(tuple), optional, default=[]: The stride used for the corresponding convolution: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • dilate::Shape(tuple), optional, default=[]: Dilation factor for each dimension of the input: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • pad::Shape(tuple), optional, default=[]: The amount of implicit zero padding added during convolution for each dimension of the input: (w,), (h, w) or (d, h, w). `(kernel-1)/2` is usually a good choice. If target_shape is set, pad will be ignored and a padding that will generate the target shape will be used. Defaults to no padding.
  • adj::Shape(tuple), optional, default=[]: Adjustment for output shape: (w,), (h, w) or (d, h, w). If target_shape is set, adj will be ignored and computed accordingly.
  • target_shape::Shape(tuple), optional, default=[]: Shape of the output tensor: (w,), (h, w) or (d, h, w).
  • num_filter::int (non-negative), required: Number of output filters.
  • num_group::int (non-negative), optional, default=1: Number of group partitions.
  • workspace::long (non-negative), optional, default=512: Maximum temporary workspace allowed (MB) in deconvolution. This parameter has two usages. When CUDNN is not used, it determines the effective batch size of the deconvolution kernel. When CUDNN is used, it controls the maximum temporary storage used for tuning the best CUDNN kernel when the limited_workspace strategy is used.
  • no_bias::boolean, optional, default=1: Whether to disable bias parameter.
  • cudnn_tune::{None, 'fastest', 'limited_workspace', 'off'},optional, default='None': Whether to pick convolution algorithm by running performance test.
  • cudnn_off::boolean, optional, default=0: Turn off cudnn for this layer.
  • layout::{None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'},optional, default='None': Set layout for input, output and weight. Empty for default layout, NCW for 1d, NCHW for 2d and NCDHW for 3d. NHWC and NDHWC are only supported on GPU.
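The roles of `stride`, `pad`, `dilate` and `adj` are easier to see as a size calculation. The sketch below assumes the output length is the inverse of the convolution formula, s(x - 1) + d(k - 1) + 1 - 2p + adj; that inversion is an assumption derived from the convolution rule, not quoted from this page. The helper names are hypothetical.

```python
def conv_out_size(x, k, p=0, s=1, d=1):
    # Forward convolution: floor((x + 2p - d(k-1) - 1) / s) + 1
    return (x + 2 * p - d * (k - 1) - 1) // s + 1

def deconv_out_size(x, k, p=0, s=1, d=1, adj=0):
    # Assumed inverse of the rule above, plus the `adj` adjustment.
    return s * (x - 1) + d * (k - 1) + 1 - 2 * p + adj

up = deconv_out_size(16, k=4, p=1, s=2)
print(up, conv_out_size(up, k=4, p=1, s=2))          # 32 16
# A stride-2 convolution maps both 32 and 33 to 16; adj selects the larger size.
up_odd = deconv_out_size(16, k=4, p=1, s=2, adj=1)
print(up_odd, conv_out_size(up_odd, k=4, p=1, s=2))  # 33 16
```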

source

# MXNet.mx._npx_dropoutMethod.

_npx_dropout(data, p, mode, axes, cudnn_off)

_npx_dropout is an alias of Dropout.

Applies dropout operation to input array.

  • During training, each element of the input is set to zero with probability p. The whole array is rescaled by 1/(1-p) to keep the expected sum of the input unchanged.
  • During testing, this operator does not change the input if mode is 'training'. If mode is 'always', the same computation as during training is applied.
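The two rules can be sketched in pure Python. `dropout` here is a hypothetical helper that assumes 0 <= p < 1 and uses Python's `random` module, so the dropped positions will not match MXNet's random number generator.

```python
import random

def dropout(values, p, training=True, mode="training", seed=None):
    # Outside training with mode='training', the input passes through untouched.
    if not training and mode == "training":
        return list(values)
    # Otherwise zero each element with probability p and rescale survivors
    # by 1/(1-p), so the expected sum is unchanged (assumes 0 <= p < 1).
    rng = random.Random(seed)
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]

row = [3.0, 0.5, -0.5, 2.0, 7.0]
print(dropout(row, p=0.2, training=False))  # unchanged
print(dropout(row, p=0.2, seed=998))        # zeros, or values scaled by 1.25
```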

Example::

random.seed(998)
input_array = array([[3., 0.5, -0.5, 2., 7.],
                     [2., -0.4, 7., 3., 0.2]])
a = symbol.Variable('a')
dropout = symbol.Dropout(a, p = 0.2)
executor = dropout.simple_bind(a = input_array.shape)

If training

executor.forward(is_train = True, a = input_array)
executor.outputs
[[ 3.75   0.625 -0.     2.5    8.75 ]
 [ 2.5   -0.5    8.75   3.75   0.   ]]

If testing

executor.forward(is_train = False, a = input_array)
executor.outputs
[[ 3.    0.5  -0.5   2.    7.  ]
 [ 2.   -0.4   7.    3.    0.2 ]]

Defined in src/operator/nn/dropout.cc:L95

Arguments

  • data::NDArray-or-SymbolicNode: Input array to which dropout will be applied.
  • p::float, optional, default=0.5: Fraction of the input that gets dropped out during training time.
  • mode::{'always', 'training'},optional, default='training': Whether to only turn on dropout during training or to also turn on for inference.
  • axes::Shape(tuple), optional, default=[]: Axes for variational dropout kernel.
  • cudnn_off::boolean or None, optional, default=0: Whether to turn off cudnn in dropout operator. This option is ignored if axes is specified.

source

# MXNet.mx._npx_embeddingMethod.

_npx_embedding(data, weight, input_dim, output_dim, dtype, sparse_grad)

_npx_embedding is an alias of Embedding.

Maps integer indices to vector representations (embeddings).

This operator maps words to real-valued vectors in a high-dimensional space, called word embeddings. These embeddings can capture semantic and syntactic properties of the words. For example, it has been noted that in the learned embedding spaces, similar words tend to be close to each other and dissimilar words far apart.

For an input array of shape (d1, ..., dK), the shape of the output array is (d1, ..., dK, output_dim). All the input values should be integers in the range [0, input_dim).

If input_dim is ip0 and output_dim is op0, then the shape of the embedding weight matrix must be (ip0, op0).

When "sparse_grad" is False, any index that is too large is replaced by the index that addresses the last vector in the embedding matrix. When "sparse_grad" is True, an error is raised if invalid indices are found.

Examples::

input_dim = 4
output_dim = 5

// Each row in weight matrix y represents a word. So, y = (w0,w1,w2,w3)
y = [[  0.,   1.,   2.,   3.,   4.],
     [  5.,   6.,   7.,   8.,   9.],
     [ 10.,  11.,  12.,  13.,  14.],
     [ 15.,  16.,  17.,  18.,  19.]]

// Input array x represents n-grams(2-gram). So, x = [(w1,w3), (w0,w2)]
x = [[ 1.,  3.],
     [ 0.,  2.]]

// Mapped input x to its vector representation y.
Embedding(x, y, 4, 5) = [[[  5.,   6.,   7.,   8.,   9.],
                          [ 15.,  16.,  17.,  18.,  19.]],

                         [[  0.,   1.,   2.,   3.,   4.],
                          [ 10.,  11.,  12.,  13.,  14.]]]
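The lookup and the out-of-range rule can be sketched with nested lists. `embedding` is a hypothetical helper; negative indices fall outside the documented input range and are not handled here.

```python
def embedding(indices, weight, sparse_grad=False):
    # Replace every index in a nested list with the matching row of `weight`,
    # which adds one trailing axis of length output_dim.
    input_dim = len(weight)

    def lookup(i):
        i = int(i)  # indices may arrive as floats, e.g. 1.0
        if i >= input_dim:
            if sparse_grad:
                raise IndexError(f"index {i} is out of range for input_dim={input_dim}")
            i = input_dim - 1  # clip to the last embedding vector
        return list(weight[i])

    def walk(node):
        if isinstance(node, list):
            return [walk(child) for child in node]
        return lookup(node)

    return walk(indices)

y = [[float(5 * r + c) for c in range(5)] for r in range(4)]  # the 4x5 weight above
x = [[1.0, 3.0], [0.0, 2.0]]
print(embedding(x, y))
```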

The storage type of weight can be either row_sparse or default.

Note:

If "sparse_grad" is set to True, the storage type of the gradient w.r.t. weights will be
"row_sparse". Only a subset of optimizers support sparse gradients, including SGD, AdaGrad
and Adam. Note that lazy updates are turned on by default, which may perform differently
from standard updates. For more details, please check the Optimization API at:
https://mxnet.incubator.apache.org/api/python/optimization/optimization.html

Defined in src/operator/tensor/indexing_op.cc:L597

Arguments

  • data::NDArray-or-SymbolicNode: The input array to the embedding operator.
  • weight::NDArray-or-SymbolicNode: The embedding weight matrix.
  • input_dim::int, required: Vocabulary size of the input indices.
  • output_dim::int, required: Dimension of the embedding vectors.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'},optional, default='float32': Data type of weight.
  • sparse_grad::boolean, optional, default=0: Compute row sparse gradient in the backward calculation. If set to True, the grad's storage type is row_sparse.

source

# MXNet.mx._npx_erfMethod.

_npx_erf(data)

_npx_erf is an alias of erf.

Returns the element-wise Gauss error function of the input.

Example::

erf([0, -1., 10.]) = [0., -0.8427, 1.]
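Python's standard library provides the same function as `math.erf`, which reproduces the example values.

```python
import math

values = [math.erf(v) for v in [0.0, -1.0, 10.0]]
print([round(v, 4) for v in values])  # [0.0, -0.8427, 1.0]
```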

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L886

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npx_erfinvMethod.

_npx_erfinv(data)

_npx_erfinv is an alias of erfinv.

Returns the element-wise inverse Gauss error function of the input.

Example::

erfinv([0, 0.5, -1.]) = [0., 0.4769, -inf]
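The standard library has no `erfinv`, but it can be sketched through the inverse normal CDF, since erfinv(y) = Φ⁻¹((y + 1)/2) / √2. The helper name and the NaN result outside [-1, 1] are assumptions of this sketch.

```python
import math
from statistics import NormalDist

def erfinv(y):
    # erfinv(y) = Phi^{-1}((y + 1) / 2) / sqrt(2), Phi = standard normal CDF.
    if y == 1.0:
        return math.inf
    if y == -1.0:
        return -math.inf
    if not -1.0 < y < 1.0:
        return math.nan  # undefined outside [-1, 1] (assumed NaN here)
    return NormalDist().inv_cdf((y + 1.0) / 2.0) / math.sqrt(2.0)

print([round(erfinv(v), 4) for v in [0.0, 0.5]], erfinv(-1.0))  # [0.0, 0.4769] -inf
```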

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L908

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npx_fully_connectedMethod.

_npx_fully_connected(data, weight, bias, num_hidden, no_bias, flatten)

_npx_fully_connected is an alias of FullyConnected.

Applies a linear transformation: $Y = XW^T + b$.

If `flatten` is set to true, then the shapes are:

  • data: (batch_size, x1, x2, ..., xn)
  • weight: (num_hidden, x1 * x2 * ... * xn)
  • bias: (num_hidden,)
  • out: (batch_size, num_hidden)

If `flatten` is set to false, then the shapes are:

  • data: (x1, x2, ..., xn, input_dim)
  • weight: (num_hidden, input_dim)
  • bias: (num_hidden,)
  • out: (x1, x2, ..., xn, num_hidden)

The learnable parameters include both `weight` and `bias`.

If `no_bias` is set to true, the `bias` term is ignored.
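Both shape conventions can be sketched on nested lists. `fully_connected` is a hypothetical helper in which `bias=None` plays the role of `no_bias`; it illustrates Y = XW^T + b and is not the MXNet kernel.

```python
def fully_connected(data, weight, bias=None, flatten=True):
    """Y = X W^T + b on nested lists (pure-Python sketch)."""
    def flat(node):
        # Collapse every axis of one sample into a single vector.
        if isinstance(node, list):
            out = []
            for item in node:
                out.extend(flat(item))
            return out
        return [node]

    def dense(x):
        # One output per hidden unit: dot(x, weight[j]) + bias[j].
        return [sum(a * w for a, w in zip(x, w_row)) + (bias[j] if bias is not None else 0.0)
                for j, w_row in enumerate(weight)]

    if flatten:
        return [dense(flat(sample)) for sample in data]  # (batch_size, num_hidden)

    def apply_last(node):
        # flatten=False: only the last axis (input_dim) is contracted.
        if isinstance(node[0], list):
            return [apply_last(child) for child in node]
        return dense(node)

    return apply_last(data)

# flatten=True: one sample of shape (2, 2) is treated as a vector of length 4.
print(fully_connected([[[1, 2], [3, 4]]], [[1, 0, 0, 0], [0, 1, 1, 1]], [0.5, -1]))
# flatten=False: shape (2, input_dim=2) -> (2, num_hidden=2).
print(fully_connected([[1, 2], [3, 4]], [[1, 1], [0, 2]], flatten=False))
```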

Note:

The sparse support for FullyConnected is limited to forward evaluation with `row_sparse`
weight and bias, where the length of `weight.indices` and `bias.indices` must be equal
to `num_hidden`. This could be useful for model inference with `row_sparse` weights
trained with importance sampling or noise contrastive estimation.

To compute linear transformation with 'csr' sparse data, sparse.dot is recommended instead
of sparse.FullyConnected.

Defined in src/operator/nn/fully_connected.cc:L286

Arguments

  • data::NDArray-or-SymbolicNode: Input data.
  • weight::NDArray-or-SymbolicNode: Weight matrix.
  • bias::NDArray-or-SymbolicNode: Bias parameter.
  • num_hidden::int, required: Number of hidden nodes of the output.
  • no_bias::boolean, optional, default=0: Whether to disable bias parameter.
  • flatten::boolean, optional, default=1: Whether to collapse all but the first axis of the input data tensor.

source

# MXNet.mx._npx_gammaMethod.

_npx_gamma(data)

_npx_gamma is an alias of gamma.

Returns the gamma function (extension of the factorial function to the reals), computed element-wise on the input array.

The storage type of `gamma` output is always dense.
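The "extension of the factorial" can be checked with Python's `math.gamma`, which computes the same mathematical function: Γ(n) = (n - 1)! for positive integers, and the function is also defined between them.

```python
import math

print(math.gamma(5), math.factorial(4))  # 24.0 24
print(math.gamma(0.5) ** 2)              # approximately pi
```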

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npx_gammalnMethod.

_npx_gammaln(data)

_npx_gammaln is an alias of gammaln.

Returns element-wise log of the absolute value of the gamma function of the input.

The storage type of `gammaln` output is always dense.
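The absolute value matters where Γ(x) is negative, for example at x = -0.5. Python's `math.lgamma` computes the same log |Γ(x)|, which this sketch compares against the direct formula.

```python
import math

x = -0.5
print(math.gamma(x))                                  # negative, about -3.5449
print(math.lgamma(x), math.log(abs(math.gamma(x))))   # equal, about 1.2655
```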

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npx_gather_ndMethod.

_npx_gather_nd(data, indices)

_npx_gather_nd is an alias of gather_nd.

Gathers elements or slices from data and stores them in a tensor whose shape is defined by indices.

Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with shape (M, Y_0, ..., Y_{K-1}), the output will have shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N, output shape will simply be (Y_0, ..., Y_{K-1}).

The elements in output are defined as follows::

output[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] = data[indices[0, y_0, ..., y_{K-1}], ..., indices[M-1, y_0, ..., y_{K-1}], x_M, ..., x_{N-1}]

Examples::

data = [[0, 1], [2, 3]]
indices = [[1, 1, 0], [0, 1, 0]]
gather_nd(data, indices) = [2, 3, 0]

data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
indices = [[0, 1], [1, 0]]
gather_nd(data, indices) = [[3, 4], [5, 6]]
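The indexing rule can be sketched on nested lists. `gather_nd` below is a hypothetical pure-Python version: the M index lists are walked in lockstep over the Y axes, and each resulting M-tuple addresses a leading sub-block of `data`.

```python
def gather_nd(data, indices):
    # indices has shape (M, Y_0, ..., Y_{K-1}).
    def lookup(coords):
        node = data
        for c in coords:  # the M coordinates address the first M axes of data
            node = node[c]
        return node       # a scalar when M == N, otherwise a sub-array

    def build(slices):
        # `slices` holds M parallel nested lists with the same remaining Y shape.
        if not isinstance(slices[0], list):
            return lookup(slices)
        return [build([s[i] for s in slices]) for i in range(len(slices[0]))]

    return build(list(indices))

print(gather_nd([[0, 1], [2, 3]], [[1, 1, 0], [0, 1, 0]]))                # [2, 3, 0]
print(gather_nd([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[0, 1], [1, 0]]))  # [[3, 4], [5, 6]]
```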

Arguments

  • data::NDArray-or-SymbolicNode: The input array to gather from.
  • indices::NDArray-or-SymbolicNode: Index array of shape (M, Y_0, ..., Y_{K-1}).

source

# MXNet.mx._npx_intgemm_fully_connectedMethod.

_npx_intgemm_fully_connected(data, weight, scaling, bias, num_hidden, no_bias, flatten, out_type)

_npx_intgemm_fully_connected is an alias of _contrib_intgemm_fully_connected.

Multiplies matrices using 8-bit integers: data * weight.

Input tensor arguments are: data, weight, [scaling], [bias].

data: either float32 or prepared using intgemm_prepare_data (in which case it is int8).

weight: must be prepared using intgemm_prepare_weight.

scaling: present if and only if out_type is float32. If so, this is multiplied by the result before adding bias. Typically:

scaling = (max passed to intgemm_prepare_weight)/127.0    if data is in float32
scaling = (max passed to intgemm_prepare_data)/127.0 * (max passed to intgemm_prepare_weight)/127.0    if data is in int8

bias: present if and only if !no_bias. This is added to the output after scaling and has the same number of columns as the output.

out_type: type of the output.
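The int8 scaling formula follows from undoing both quantizations. The sketch below is plain arithmetic, not intgemm: it assumes values are scaled by 127/maxabs, rounded to nearest and clamped to [-127, 127], and it ignores the multiple-of-64 layout requirements listed in the arguments.

```python
def quantize(values, maxabs):
    # Scale so maxabs maps to 127, round (assumed), clamp so -128 never appears.
    return [max(-127, min(127, round(v * 127.0 / maxabs))) for v in values]

data = [0.5, -1.0, 0.25]
weight = [1.0, 0.5, -0.75]
max_d = max(abs(v) for v in data)
max_w = max(abs(v) for v in weight)

q_data, q_weight = quantize(data, max_d), quantize(weight, max_w)
acc = sum(a * b for a, b in zip(q_data, q_weight))  # integer accumulator
scaling = (max_d / 127.0) * (max_w / 127.0)         # both operands were int8
approx = acc * scaling
exact = sum(a * b for a, b in zip(data, weight))
print(acc, round(approx, 4), exact)  # -3040 -0.1885 -0.1875
```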

Defined in src/operator/contrib/intgemm/intgemm_fully_connected_op.cc:L283

Arguments

  • data::NDArray-or-SymbolicNode: First argument to multiplication. Tensor of float32 (quantized on the fly) or int8 from intgemm_prepare_data. If you use a different quantizer, be sure to ban -128. The last dimension must be a multiple of 64.
  • weight::NDArray-or-SymbolicNode: Second argument to multiplication. Tensor of int8 from intgemm_prepare_weight. The last dimension must be a multiple of 64. The product of non-last dimensions must be a multiple of 8.
  • scaling::NDArray-or-SymbolicNode: Scaling factor to apply if output type is float32.
  • bias::NDArray-or-SymbolicNode: Bias term.
  • num_hidden::int, required: Number of hidden nodes of the output.
  • no_bias::boolean, optional, default=0: Whether to disable bias parameter.
  • flatten::boolean, optional, default=1: Whether to collapse all but the first axis of the input data tensor.
  • out_type::{'float32', 'int32'},optional, default='float32': Output data type.

source

# MXNet.mx._npx_intgemm_maxabsoluteMethod.

_npx_intgemm_maxabsolute(data)

_npx_intgemm_maxabsolute is an alias of _contrib_intgemm_maxabsolute.

Computes the maximum absolute value in a float32 tensor quickly on a CPU. The tensor's total size must be a multiple of 16 and aligned to a multiple of 64 bytes. mxnet.nd.contrib.intgemm_maxabsolute(arr) == arr.abs().max()

Defined in src/operator/contrib/intgemm/max_absolute_op.cc:L101

Arguments

  • data::NDArray-or-SymbolicNode: Tensor to compute maximum absolute value of

source

# MXNet.mx._npx_intgemm_prepare_dataMethod.

_npx_intgemm_prepare_data(data, maxabs)

_npx_intgemm_prepare_data is an alias of _contrib_intgemm_prepare_data.

This operator quantizes float32 to int8 while also banning -128.

It is suitable for preparing a data matrix for use by intgemm's C = data * weights operation.

The float32 values are scaled such that maxabs maps to 127. Typically maxabs = maxabsolute(A).
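The mapping of maxabs to 127 and the ban on -128 can be sketched as plain arithmetic. Both helper names are hypothetical, and round-to-nearest is an assumption of this sketch rather than a documented detail.

```python
def max_absolute(values):
    # What intgemm_maxabsolute computes: max |x|.
    return max(abs(v) for v in values)

def prepare_data(values, maxabs):
    # Multiply by 127 / maxabs, round to nearest (assumed), clamp to [-127, 127].
    return [max(-127, min(127, round(v * 127.0 / maxabs))) for v in values]

a = [0.2, -0.8, 0.4, -1.0]
print(prepare_data(a, max_absolute(a)))  # [25, -102, 51, -127]
print(prepare_data([2.0, -2.0], 1.0))    # [127, -127]: clamped, -128 never appears
```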

Defined in src/operator/contrib/intgemm/prepare_data_op.cc:L112

Arguments

  • data::NDArray-or-SymbolicNode: Activation matrix to be prepared for multiplication.
  • maxabs::NDArray-or-SymbolicNode: Maximum absolute value to be used for scaling. The values will be multiplied by 127.0 / maxabs.

source

# MXNet.mx._npx_intgemm_prepare_weightMethod.

_npx_intgemm_prepare_weight(weight, maxabs, already_quantized)

_npx_intgemm_prepare_weight is an alias of _contrib_intgemm_prepare_weight.

This operator converts a weight matrix in column-major format to intgemm's internal fast representation of weight matrices. MXNet customarily stores weight matrices in column-major (transposed) format. This operator is not meant to be fast; it is meant to be run offline to quantize a model.

In other words, it prepares weight for the operation C = data * weight^T.

If the provided weight matrix is float32, it will be quantized first. The quantization function is (int8_t)(127.0 / max * weight), where max is provided as argument 1 (maxabs; the weight matrix is argument 0). Then the matrix will be rearranged into the CPU-dependent format.

If the provided weight matrix is already int8, the matrix will only be rearranged into the CPU-dependent format. This way one can quantize with intgemm_prepare_data (which just quantizes), store to disk in a consistent format, then at load time convert to the CPU-dependent format with intgemm_prepare_weight.

The internal representation depends on register length. So AVX512, AVX2, and SSSE3 have different formats. AVX512BW and AVX512VNNI have the same representation.

Defined in src/operator/contrib/intgemm/prepare_weight_op.cc:L153

Arguments

  • weight::NDArray-or-SymbolicNode: Parameter matrix to be prepared for multiplication.
  • maxabs::NDArray-or-SymbolicNode: Maximum absolute value for scaling. The weights will be multiplied by 127.0 / maxabs.
  • already_quantized::boolean, optional, default=0: Is the weight matrix already quantized?

source

# MXNet.mx._npx_intgemm_take_weightMethod.

_npx_intgemm_take_weight(weight, indices)

_npx_intgemm_take_weight is an alias of _contrib_intgemm_take_weight.

Index a weight matrix stored in intgemm's weight format. The indices select the outputs of matrix multiplication, not the inner dot product dimension.

Defined in src/operator/contrib/intgemm/take_weight_op.cc:L128

Arguments

  • weight::NDArray-or-SymbolicNode: Tensor already in intgemm weight format to select from
  • indices::NDArray-or-SymbolicNode: indices to select on the 0th dimension of weight

source

# MXNet.mx._npx_layer_normMethod.

_npx_layer_norm(data, gamma, beta, axis, eps, output_mean_var)

_npx_layer_norm is an alias of LayerNorm.

Layer normalization.

Normalizes the channels of the input tensor by mean and variance, and applies a scale `gamma` as well as offset `beta`.

Assume the input has more than one dimension and we normalize along axis 1. We first compute the mean and variance along this axis and then compute the normalized output, which has the same shape as input, as follows:

$$\mathrm{out} = \frac{\mathrm{data} - \mathrm{mean}(\mathrm{data}, \mathrm{axis})}{\sqrt{\mathrm{var}(\mathrm{data}, \mathrm{axis}) + \epsilon}} \cdot \mathrm{gamma} + \mathrm{beta}$$

Both `gamma` and `beta` are learnable parameters.

Unlike BatchNorm and InstanceNorm, the mean and var are computed along the channel dimension.

Assume the input has size k on axis 1; then both `gamma` and `beta` have shape (k,). If `output_mean_var` is set to true, both `data_mean` and `data_std` are also output. Note that no gradient will be passed through these two outputs.

The parameter `axis` specifies which axis of the input shape denotes the 'channel' (separately normalized groups). The default is -1, which sets the channel axis to be the last item in the input shape.
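The formula can be sketched for the default axis=-1, normalizing each row independently. `layer_norm` is a hypothetical helper, and the biased (population) variance is an assumption of this sketch. The constant row shows eps preventing a division by zero.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize one vector along its channel axis (the last axis, axis=-1).
    k = len(x)
    mean = sum(x) / k
    var = sum((v - mean) ** 2 for v in x) / k  # population variance (assumed)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]

batch = [[1.0, 2.0, 3.0], [10.0, 10.0, 10.0]]
gamma, beta = [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]
result = [layer_norm(row, gamma, beta) for row in batch]
print(result)  # first row about [-1.2247, 0.0, 1.2247]; constant row -> zeros
```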

Defined in src/operator/nn/layer_norm.cc:L201

Arguments

  • data::NDArray-or-SymbolicNode: Input data to layer normalization
  • gamma::NDArray-or-SymbolicNode: gamma array
  • beta::NDArray-or-SymbolicNode: beta array
  • axis::int, optional, default='-1': The axis to perform layer normalization. Usually, this should be the axis of the channel dimension. Negative values mean indexing from right to left.
  • eps::float, optional, default=9.99999975e-06: An epsilon parameter to prevent division by 0.
  • output_mean_var::boolean, optional, default=0: Output the mean and std calculated along the given axis.

source

# MXNet.mx._npx_leaky_reluMethod.

_npx_leaky_relu(data, gamma, act_type, slope, lower_bound, upper_bound)

_npx_leaky_relu is an alias of LeakyReLU.

Applies Leaky rectified linear unit activation element-wise to the input.

Leaky ReLUs attempt to fix the "dying ReLU" problem by using a small slope for negative inputs and a slope of one for positive inputs.

The following modified ReLU Activation functions are supported:

  • elu: Exponential Linear Unit. y = x > 0 ? x : slope * (exp(x)-1)
  • selu: Scaled Exponential Linear Unit. y = lambda * (x > 0 ? x : alpha * (exp(x) - 1)) where lambda = 1.0507009873554804934193349852946 and alpha = 1.6732632423543772848170429916717.
  • leaky: Leaky ReLU. y = x > 0 ? x : slope * x
  • prelu: Parametric ReLU. The same as leaky, except that the slope is learnt during training.
  • rrelu: Randomized ReLU. The same as leaky, but the slope is chosen uniformly at random from [lower_bound, upper_bound) during training, while it is fixed to (lower_bound+upper_bound)/2 for inference.
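The formulas in the list can be sketched as scalar Python functions. The function names are hypothetical; the defaults mirror the argument list below, with the rrelu upper bound rounded to 0.334.

```python
import math

LAMBDA = 1.0507009873554804934193349852946
ALPHA = 1.6732632423543772848170429916717

def leaky(x, slope=0.25):
    return x if x > 0 else slope * x

def elu(x, slope=0.25):
    return x if x > 0 else slope * (math.exp(x) - 1.0)

def selu(x):
    return LAMBDA * (x if x > 0 else ALPHA * (math.exp(x) - 1.0))

def rrelu_inference(x, lower_bound=0.125, upper_bound=0.334):
    # At inference the slope is fixed to the midpoint of the bounds.
    return leaky(x, (lower_bound + upper_bound) / 2.0)

print(leaky(-2.0), elu(-1.0), selu(1.0), rrelu_inference(-1.0))
```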

Defined in src/operator/leaky_relu.cc:L162

Arguments

  • data::NDArray-or-SymbolicNode: Input data to activation function.
  • gamma::NDArray-or-SymbolicNode: Slope parameter for PReLU. Only required when act_type is 'prelu'.
  • act_type::{'elu', 'gelu', 'leaky', 'prelu', 'rrelu', 'selu'},optional, default='leaky': Activation function to be applied.
  • slope::float, optional, default=0.25: Init slope for the activation. (For leaky and elu only)
  • lower_bound::float, optional, default=0.125: Lower bound of random slope. (For rrelu only)
  • upper_bound::float, optional, default=0.333999991: Upper bound of random slope. (For rrelu only)

source

# MXNet.mx._npx_log_softmaxMethod.

_npx_log_softmax(data, axis, temperature, dtype, use_length)

_npx_log_softmax is an alias of log_softmax.

Computes the log softmax of the input. This is equivalent to computing softmax followed by log.

Examples::

x = mx.nd.array([1, 2, .1])
mx.nd.log_softmax(x).asnumpy()

array([-1.41702998, -0.41702995, -2.31702995], dtype=float32)

x = mx.nd.array( [[1, 2, .1],[.1, 2, 1]] )
mx.nd.log_softmax(x, axis=0).asnumpy()

array([[-0.34115392, -0.69314718, -1.24115396],
       [-1.24115396, -0.69314718, -0.34115392]], dtype=float32)
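"Softmax followed by log" simplifies to x_i - logsumexp(x). The pure-Python sketch below (hypothetical helper, no temperature or length masking) subtracts the maximum first for numerical stability and reproduces both examples; axis=0 is handled by applying the function column by column.

```python
import math

def log_softmax(xs):
    # log(softmax(x))_i = x_i - logsumexp(x); subtracting the max keeps exp() stable.
    m = max(xs)
    lse = m + math.log(sum(math.exp(v - m) for v in xs))
    return [v - lse for v in xs]

print([round(v, 5) for v in log_softmax([1.0, 2.0, 0.1])])
# [-1.41703, -0.41703, -2.31703]

x = [[1.0, 2.0, 0.1], [0.1, 2.0, 1.0]]
cols = [log_softmax([row[j] for row in x]) for j in range(3)]
by_axis0 = [[round(cols[j][i], 5) for j in range(3)] for i in range(2)]
print(by_axis0)
# [[-0.34115, -0.69315, -1.24115], [-1.24115, -0.69315, -0.34115]]
```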

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
  • axis::int, optional, default='-1': The axis along which to compute softmax.
  • temperature::double or None, optional, default=None: Temperature parameter in softmax
  • dtype::{None, 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None).
  • use_length::boolean or None, optional, default=0: Whether to use the length input as a mask over the data input.

source

# MXNet.mx._npx_multibox_detectionMethod.

_npx_multibox_detection(cls_prob, loc_pred, anchor, clip, threshold, background_id, nms_threshold, force_suppress, variances, nms_topk)

_npx_multibox_detection is an alias of _contrib_MultiBoxDetection.

Converts multibox detection predictions.

Arguments

  • cls_prob::NDArray-or-SymbolicNode: Class probabilities.
  • loc_pred::NDArray-or-SymbolicNode: Location regression predictions.
  • anchor::NDArray-or-SymbolicNode: Multibox prior anchor boxes
  • clip::boolean, optional, default=1: Clip out-of-boundary boxes.
  • threshold::float, optional, default=0.00999999978: Threshold to be a positive prediction.
  • background_id::int, optional, default='0': Background id.
  • nms_threshold::float, optional, default=0.5: Non-maximum suppression threshold.
  • force_suppress::boolean, optional, default=0: Suppress all detections regardless of class_id.
  • variances::tuple of <float>, optional, default=[0.1,0.1,0.2,0.2]: Variances to be decoded from box regression output.
  • nms_topk::int, optional, default='-1': Keep maximum top k detections before nms, -1 for no limit.

source

# MXNet.mx._npx_multibox_priorMethod.

_npx_multibox_prior(data, sizes, ratios, clip, steps, offsets)

_npx_multibox_prior is an alias of _contrib_MultiBoxPrior.

Generates prior (anchor) boxes from data, sizes and ratios.

Arguments

  • data::NDArray-or-SymbolicNode: Input data.
  • sizes::tuple of <float>, optional, default=[1]: List of sizes of generated prior boxes.
  • ratios::tuple of <float>, optional, default=[1]: List of aspect ratios of generated prior boxes.
  • clip::boolean, optional, default=0: Whether to clip out-of-boundary boxes.
  • steps::tuple of <float>, optional, default=[-1,-1]: Priorbox step across y and x, -1 for auto calculation.
  • offsets::tuple of <float>, optional, default=[0.5,0.5]: Priorbox center offsets, y and x respectively

source

# MXNet.mx._npx_multibox_targetMethod.

_npx_multibox_target(anchor, label, cls_pred, overlap_threshold, ignore_label, negative_mining_ratio, negative_mining_thresh, minimum_negative_samples, variances)

_npx_multibox_target is an alias of _contrib_MultiBoxTarget.

Computes MultiBox training targets.

Arguments

  • anchor::NDArray-or-SymbolicNode: Generated anchor boxes.
  • label::NDArray-or-SymbolicNode: Object detection labels.
  • cls_pred::NDArray-or-SymbolicNode: Class predictions.
  • overlap_threshold::float, optional, default=0.5: Anchor-GT overlap threshold to be regarded as a positive match.
  • ignore_label::float, optional, default=-1: Label for ignored anchors.
  • negative_mining_ratio::float, optional, default=-1: Max negative to positive samples ratio, use -1 to disable mining
  • negative_mining_thresh::float, optional, default=0.5: Threshold used for negative mining.
  • minimum_negative_samples::int, optional, default='0': Minimum number of negative samples.
  • variances::tuple of <float>, optional, default=[0.1,0.1,0.2,0.2]: Variances to be encoded in box regression target.

source

# MXNet.mx._npx_nonzeroMethod.

_npx_nonzero(x)

Arguments

  • x::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npx_one_hotMethod.

_npx_one_hot(indices, depth, on_value, off_value, dtype)

_npx_one_hot is an alias of one_hot.

Returns a one-hot array.

The locations represented by indices take value on_value, while all other locations take value off_value.

The one_hot operation with indices of shape `(i0, i1)` and depth of `d` results in an output array of shape `(i0, i1, d)` with::

output[i,j,:] = off_value
output[i,j,indices[i,j]] = on_value

Examples::

one_hot([1,0,2,0], 3) = [[ 0.  1.  0.]
                         [ 1.  0.  0.]
                         [ 0.  0.  1.]
                         [ 1.  0.  0.]]

one_hot([1,0,2,0], 3, on_value=8, off_value=1, dtype='int32') = [[1 8 1]
                                                                 [8 1 1]
                                                                 [1 1 8]
                                                                 [8 1 1]]

one_hot([[1,0],[1,0],[2,0]], 3) = [[[ 0.  1.  0.]
                                    [ 1.  0.  0.]]

                                   [[ 0.  1.  0.]
                                    [ 1.  0.  0.]]

                                   [[ 0.  0.  1.]
                                    [ 1.  0.  0.]]]
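The two assignment rules translate directly into a recursive pure-Python sketch. `one_hot` is a hypothetical helper: each index grows one trailing axis of length `depth`, and the nesting of the input is preserved.

```python
def one_hot(indices, depth, on_value=1.0, off_value=0.0):
    # Nested lists keep their shape; each scalar index becomes a row of
    # off_value with on_value at position `indices`.
    if isinstance(indices, list):
        return [one_hot(i, depth, on_value, off_value) for i in indices]
    return [on_value if j == indices else off_value for j in range(depth)]

print(one_hot([1, 0, 2, 0], 3))
print(one_hot([1, 0, 2, 0], 3, on_value=8, off_value=1))
```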

Defined in src/operator/tensor/indexing_op.cc:L882

Arguments

  • indices::NDArray-or-SymbolicNode: array of locations where to set on_value
  • depth::int, required: Depth of the one hot dimension.
  • on_value::double, optional, default=1: The value assigned to the locations represented by indices.
  • off_value::double, optional, default=0: The value assigned to the locations not represented by indices.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'},optional, default='float32': DType of the output

source

# MXNet.mx._npx_pickMethod.

_npx_pick(data, index, axis, keepdims, mode)

_npx_pick is an alias of pick.

Picks elements from an input array according to the input indices along the given axis.

Given an input array of shape $(d0, d1)$ and indices of shape $(i0,)$, the result will be an output array of shape $(i0,)$ with::

output[i] = input[i, indices[i]]

By default, if any index mentioned is too large, it is replaced by the index that addresses the last element along an axis (the clip mode).

This function supports n-dimensional input and (n-1)-dimensional indices arrays.

Examples::

    x = [[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]]

    // picks elements with specified indices along axis 0
    pick(x, y=[0,1], 0) = [ 1.,  4.]

    // picks elements with specified indices along axis 1
    pick(x, y=[0,1,0], 1) = [ 1.,  4.,  5.]

    // picks elements with specified indices along axis 1 using 'wrap' mode
    // to place indices that would normally be out of bounds
    pick(x, y=[2,-1,-2], 1, mode='wrap') = [ 1.,  4.,  5.]

    y = [[ 1.],
         [ 0.],
         [ 2.]]

    // picks elements with specified indices along axis 1 and dims are maintained
    pick(x, y, 1, keepdims=True) = [[ 2.],
                                    [ 3.],
                                    [ 6.]]

Defined in src/operator/tensor/broadcast_reduce_op_index.cc:L150

Arguments

  • data::NDArray-or-SymbolicNode: The input array
  • index::NDArray-or-SymbolicNode: The index array
  • axis::int or None, optional, default='-1': int or None. The axis along which to pick the elements. Negative values index from right to left. If None, the indices are interpreted with respect to the flattened input.
  • keepdims::boolean, optional, default=0: If true, the axis where we pick the elements is left in the result as dimension with size one.
  • mode::{'clip', 'wrap'},optional, default='clip': Specify how out-of-bound indices behave. Default is "clip". "clip" means clip to the range, so any index that is too large is replaced by the index that addresses the last element along the axis. "wrap" means to wrap around.
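The 2-D case of pick, including both out-of-bound modes, can be sketched in plain Python on nested lists. The helper `pick_ref` is hypothetical and only handles axis 0 or 1. It also assumes that "clip" clamps negative indices to 0, which is one reading of "clip to the range".

```python
def pick_ref(x, index, axis=1, mode="clip"):
    """Hypothetical sketch of pick for a 2-D nested list.

    axis=1: out[i] = x[i][index[i]];  axis=0: out[j] = x[index[j]][j].
    """
    def fix(k, n):
        if mode == "clip":
            return min(max(k, 0), n - 1)  # clamp into [0, n-1]
        if mode == "wrap":
            return k % n  # Python's % is non-negative for positive n
        raise ValueError(f"unknown mode {mode!r}")

    if axis == 1:
        return [row[fix(k, len(row))] for row, k in zip(x, index)]
    if axis == 0:
        return [x[fix(k, len(x))][j] for j, k in enumerate(index)]
    raise ValueError("this sketch only handles axis 0 or 1")
```

With `x` from the examples, `pick_ref(x, [2, -1, -2], mode="wrap")` gives the same `[1., 4., 5.]` as the 'wrap' example. With the default clip mode, `pick_ref(x, [1, 0, 2])` matches the keepdims example once the extra axis is dropped.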

source

# MXNet.mx._npx_poolingMethod.

_npx_pooling(data, kernel, pool_type, global_pool, cudnn_off, pooling_convention, stride, pad, p_value, count_include_pad, layout)

_npx_pooling is an alias of Pooling.

Performs pooling on the input.

The shapes for 1-D pooling are

  • data and out: (batch_size, channel, width) (NCW layout) or (batch_size, width, channel) (NWC layout),

The shapes for 2-D pooling are

  • data and out: (batch_size, channel, height, width) (NCHW layout) or (batch_size, height, width, channel) (NHWC layout),

    out_height = f(height, kernel[0], pad[0], stride[0])
    out_width = f(width, kernel[1], pad[1], stride[1])

The definition of f depends on $pooling_convention$; the two conventions documented here are:

  • valid (default)::

    f(x, k, p, s) = floor((x+2*p-k)/s)+1

  • full, which is compatible with Caffe::

    f(x, k, p, s) = ceil((x+2*p-k)/s)+1
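The output-size function f can be sketched for the 'valid' and 'full' conventions listed above. Integer arithmetic keeps floor and ceil exact. `pooled_len` is a hypothetical helper. The 'same' value accepted by pooling_convention is not covered because its formula is not given here.

```python
def pooled_len(x, k, p, s, convention="valid"):
    """Output length along one axis: input x, kernel k, padding p, stride s."""
    span = x + 2 * p - k
    if convention == "valid":
        return span // s + 1          # floor((x + 2p - k) / s) + 1
    if convention == "full":
        return -(-span // s) + 1      # ceil((x + 2p - k) / s) + 1
    raise ValueError("only 'valid' and 'full' are sketched here")
```

The two conventions differ only when x + 2p - k is not a multiple of s. For x=7, k=2, p=0, s=2, 'valid' drops the partial last window (3 outputs), while 'full' keeps it (4 outputs).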

When $global_pool$ is set to true, global pooling is performed: the kernel is reset to $kernel=(height, width)$ and the appropriate padding is set to 0.

Four pooling options are supported by $pool_type$:

  • avg: average pooling
  • max: max pooling
  • sum: sum pooling
  • lp: Lp pooling

For 3-D pooling, an additional depth dimension is added before height. Namely the input data and output will have shape (batch_size, channel, depth, height, width) (NCDHW layout) or (batch_size, depth, height, width, channel) (NDHWC layout).

Notes on Lp pooling:

Lp pooling was first introduced in this paper: https://arxiv.org/pdf/1204.3968.pdf. L-1 pooling is simply sum pooling, while L-inf pooling is simply max pooling. Lp pooling lies between these two; in practice, the most common value for p is 2.

For each window $X$, the mathematical expression for Lp pooling is:

$$f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}$$
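A one-window sketch of the Lp pooling expression, with the hypothetical helper `lp_pool`, shows how p moves the result between sum pooling and max pooling. It assumes non-negative window values, for example after a ReLU. Only then does p=1 coincide with sum pooling and a large p approach the maximum.

```python
def lp_pool(window, p):
    """f(X) = (sum of x**p over the window) ** (1/p), for non-negative x."""
    return sum(v ** p for v in window) ** (1.0 / p)
```

For instance, `lp_pool([3.0, 4.0], 2)` is the Euclidean norm 5.0. `lp_pool([1.0, 2.0, 3.0], 50)` is already very close to the window maximum 3.0.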

Defined in src/operator/nn/pooling.cc:L416

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the pooling operator.
  • kernel::Shape(tuple), optional, default=[]: Pooling kernel size: (y, x) or (d, y, x)
  • pool_type::{'avg', 'lp', 'max', 'sum'},optional, default='max': Pooling type to be applied.
  • global_pool::boolean, optional, default=0: Ignore kernel size, do global pooling based on current input feature map.
  • cudnn_off::boolean, optional, default=0: Turn off cudnn pooling and use MXNet pooling operator.
  • pooling_convention::{'full', 'same', 'valid'},optional, default='valid': Pooling convention to be applied.
  • stride::Shape(tuple), optional, default=[]: Stride: for pooling (y, x) or (d, y, x). Defaults to 1 for each dimension.
  • pad::Shape(tuple), optional, default=[]: Pad for pooling: (y, x) or (d, y, x). Defaults to no padding.
  • p_value::int or None, optional, default='None': Value of p for Lp pooling, can be 1 or 2, required for Lp Pooling.
  • count_include_pad::boolean or None, optional, default=None: Only used for AvgPool, specify whether to count padding elements for averagecalculation. For example, with a 55 kernel on a 33 corner of a image,the sum of the 9 valid elements will be divided by 25 if this is set to true,or it will be divided by 9 if this is set to false. Defaults to true.
  • layout::{None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC', 'NWC'},optional, default='None': Set layout for input and output. Empty for default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.

source

# MXNet.mx._npx_reluMethod.

_npx_relu(data)

Computes the rectified linear activation element-wise: $\max(x, 0)$.

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L34

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npx_reshapeMethod.

_npx_reshape(a, newshape, reverse, order)

Defined in src/operator/numpy/np_matrix_op.cc:L381

Arguments

  • a::NDArray-or-SymbolicNode: Array to be reshaped.
  • newshape::Shape(tuple), required: The new shape should be compatible with the original shape. If an integer, the result will be a 1-D array of that length. One shape dimension can be -1; in this case, its value is inferred from the length of the array and the remaining dimensions. The values -2 to -6 have special meanings: -2 copies this dimension from the input to the output shape; -3 skips the current dimension if and only if its size is one; -4 copies all remaining input dimensions to the output shape; -5 uses the product of two consecutive input dimensions as the output dimension; -6 splits one input dimension into the two dimensions that follow -6 in the new shape.
  • reverse::boolean, optional, default=0: If true then the special values are inferred from right to left
  • order::string, optional, default='C': Read the elements of a using this index order, and place the elements into the reshaped array using this index order. 'C' means to read/write the elements using C-like index order, with the last axis index changing fastest, back to the first axis index changing slowest. Note that currently only C-like order is supported
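A sketch of how the special codes -1 to -5 in newshape might be resolved against an input shape. The helper `npx_reshape_shape` is hypothetical. It assumes codes are consumed left to right and that every code except -4 and -5 uses up exactly one input dimension. -6 is not sketched. This is an illustration of the rules listed for newshape, not MXNet's shape-inference code.

```python
from math import prod


def npx_reshape_shape(in_shape, newshape):
    """Hypothetical sketch: resolve special codes in newshape against in_shape."""
    out, src, infer_at = [], 0, None
    for code in newshape:
        if code > 0:                      # explicit size
            out.append(code)
            src += 1
        elif code == -1:                  # inferred after the loop
            if infer_at is not None:
                raise ValueError("at most one -1 is allowed")
            infer_at = len(out)
            out.append(1)                 # placeholder
            src += 1
        elif code == -2:                  # copy this input dimension
            out.append(in_shape[src])
            src += 1
        elif code == -3:                  # drop a size-one input dimension
            if in_shape[src] != 1:
                raise ValueError("-3 needs an input dimension of size 1")
            src += 1
        elif code == -4:                  # copy all remaining input dimensions
            out.extend(in_shape[src:])
            src = len(in_shape)
        elif code == -5:                  # merge two consecutive input dimensions
            out.append(in_shape[src] * in_shape[src + 1])
            src += 2
        else:
            raise NotImplementedError(f"code {code} is not sketched")
    total = prod(in_shape)
    if infer_at is not None:
        known = prod(out)                 # the placeholder contributes 1
        if known == 0 or total % known:
            raise ValueError("cannot infer -1")
        out[infer_at] = total // known
    if prod(out) != total:
        raise ValueError(f"{tuple(out)} is incompatible with {tuple(in_shape)}")
    return tuple(out)
```

For an input of shape (2, 3, 4), `(-5, -2)` merges the first two dimensions into (6, 4). `(-2, -1)` keeps the first dimension and infers (2, 12).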

source

# MXNet.mx._npx_reshape_likeMethod.

_npx_reshape_like(lhs, rhs, lhs_begin, lhs_end, rhs_begin, rhs_end)

_npx_reshape_like is an alias of reshape_like.

Reshape some or all dimensions of lhs to have the same shape as some or all dimensions of rhs.

Returns a view of the lhs array with a new shape without altering any data.

Example::

    x = [1, 2, 3, 4, 5, 6]
    y = [[0, -4], [3, 2], [2, 2]]
    reshape_like(x, y) = [[1, 2],
                          [3, 4],
                          [5, 6]]

More precise control over how dimensions are inherited is achieved by specifying slices over the lhs and rhs array dimensions. Only the sliced lhs dimensions are reshaped to the rhs sliced dimensions, with the non-sliced lhs dimensions staying the same.

Examples::

  • lhs shape = (30,7), rhs shape = (15,2,4), lhs_begin=0, lhs_end=1, rhs_begin=0, rhs_end=2, output shape = (15,2,7)
  • lhs shape = (3, 5), rhs shape = (1,15,4), lhs_begin=0, lhs_end=2, rhs_begin=1, rhs_end=2, output shape = (15)

Negative indices are supported, and None can be used for either lhs_end or rhs_end to indicate the end of the range.

Example::

  • lhs shape = (30, 12), rhs shape = (4, 2, 2, 3), lhs_begin=-1, lhs_end=None, rhs_begin=1, rhs_end=None, output shape = (30, 2, 2, 3)

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L511

Arguments

  • lhs::NDArray-or-SymbolicNode: First input.
  • rhs::NDArray-or-SymbolicNode: Second input.
  • lhs_begin::int or None, optional, default='None': Defaults to 0. The beginning index along which the lhs dimensions are to be reshaped. Supports negative indices.
  • lhs_end::int or None, optional, default='None': Defaults to None. The ending index along which the lhs dimensions are to be used for reshaping. Supports negative indices.
  • rhs_begin::int or None, optional, default='None': Defaults to 0. The beginning index along which the rhs dimensions are to be used for reshaping. Supports negative indices.
  • rhs_end::int or None, optional, default='None': Defaults to None. The ending index along which the rhs dimensions are to be used for reshaping. Supports negative indices.
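The shape rule for reshape_like can be sketched on plain shape tuples. Python's `slice` already follows the described conventions for negative indices and for None as "to the end". `reshape_like_shape` is a hypothetical helper that only computes the output shape.

```python
from math import prod


def reshape_like_shape(lhs, rhs, lhs_begin=None, lhs_end=None,
                       rhs_begin=None, rhs_end=None):
    """Hypothetical sketch: output shape of reshape_like for shape tuples."""
    # A None begin means 0; a None end means "through the last dimension".
    lo, hi, _ = slice(0 if lhs_begin is None else lhs_begin, lhs_end).indices(len(lhs))
    rhs_part = tuple(rhs[slice(0 if rhs_begin is None else rhs_begin, rhs_end)])
    if prod(lhs[lo:hi]) != prod(rhs_part):
        raise ValueError("sliced dimensions must hold the same number of elements")
    # Non-sliced lhs dimensions stay in place around the borrowed rhs dimensions.
    return tuple(lhs[:lo]) + rhs_part + tuple(lhs[hi:])
```

The three listed examples come out as `(15, 2, 7)`, `(15,)` and `(30, 2, 2, 3)`.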

source

# MXNet.mx._npx_rnnMethod.

_npx_rnn(data, parameters, state, state_cell, sequence_length, state_size, num_layers, bidirectional, mode, p, state_outputs, projection_size, lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan, use_sequence_length)

_npx_rnn is an alias of RNN.

Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support.

When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use pseudo-float16 precision (float32 math with float16 I/O) in order to use Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.

Vanilla RNN

Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported: ReLU and Tanh.

With ReLU activation function:

$$h_t = \mathrm{relu}(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})$$

With Tanh activation function:

$$h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})$$

Reference paper: Finding structure in time - Elman, 1988. https://crl.ucsd.edu/~elman/Papers/fsit.pdf

LSTM

Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf

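$$\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t = f_t * c_{t-1} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}$$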

When the projection size is set, the LSTM can use the projection feature to reduce the number of parameters and give some speedup without significantly hurting accuracy.

Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech Recognition - Sak et al. 2014. https://arxiv.org/abs/1402.1128

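$$\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{ri} r_{t-1} + b_{ri}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{rf} r_{t-1} + b_{rf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rg} r_{t-1} + b_{rg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ro} r_{t-1} + b_{ro}) \\
c_t = f_t * c_{t-1} + i_t * g_t \\
h_t = o_t * \tanh(c_t) \\
r_t = W_{hr} h_t
\end{array}$$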

GRU

Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078

The definition of GRU here is slightly different from the paper but compatible with cuDNN.

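$$\begin{array}{ll}
r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{t-1} + b_{hn})) \\
h_t = (1 - z_t) * n_t + z_t * h_{t-1}
\end{array}$$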

Defined in src/operator/rnn.cc:L375

Arguments

  • data::NDArray-or-SymbolicNode: Input data to RNN
  • parameters::NDArray-or-SymbolicNode: Vector of all RNN trainable parameters concatenated
  • state::NDArray-or-SymbolicNode: initial hidden state of the RNN
  • state_cell::NDArray-or-SymbolicNode: initial cell state for LSTM networks (only for LSTM)
  • sequence_length::NDArray-or-SymbolicNode: Vector of valid sequence lengths for each element in batch. (Only used if use_sequence_length kwarg is True)
  • state_size::int (non-negative), required: size of the state for each layer
  • num_layers::int (non-negative), required: number of stacked layers
  • bidirectional::boolean, optional, default=0: whether to use bidirectional recurrent layers
  • mode::{'gru', 'lstm', 'rnn_relu', 'rnn_tanh'}, required: the type of RNN to compute
  • p::float, optional, default=0: drop rate of the dropout on the outputs of each RNN layer, except the last layer.
  • state_outputs::boolean, optional, default=0: Whether to have the states as symbol outputs.
  • projection_size::int or None, optional, default='None': size of the LSTM projection
  • lstm_state_clip_min::double or None, optional, default=None: Minimum clip value of LSTM states. This option must be used together with lstm_state_clip_max.
  • lstm_state_clip_max::double or None, optional, default=None: Maximum clip value of LSTM states. This option must be used together with lstm_state_clip_min.
  • lstm_state_clip_nan::boolean, optional, default=0: Whether to stop NaN from propagating in state by clipping it to min/max. If clipping range is not specified, this option is ignored.
  • use_sequence_length::boolean, optional, default=0: If set to true, this layer takes in an extra input parameter sequence_length to specify variable length sequence
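To make the gate arithmetic concrete, here is a scalar (hidden size 1) sketch of one LSTM step and one GRU step. Both follow the standard cuDNN-compatible gate equations: the LSTM input, forget, cell and output gates, and a GRU whose reset gate scales only the hidden-side term. The helpers `lstm_step` and `gru_step`, and the dictionary layout of weights and biases, are hypothetical. Real RNN parameters are matrices packed into the `parameters` vector, which this sketch does not model.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, h_prev, c_prev, w, b):
    """One LSTM step with hidden size 1.

    w[g] = (input weight, hidden weight), b[g] = (input bias, hidden bias)
    for gate g in "i", "f", "g", "o".
    """
    def pre(g):
        (wx, wh), (bx, bh) = w[g], b[g]
        return wx * x + bx + wh * h_prev + bh

    i, f, o = sigmoid(pre("i")), sigmoid(pre("f")), sigmoid(pre("o"))
    g = math.tanh(pre("g"))
    c = f * c_prev + i * g          # new cell state
    h = o * math.tanh(c)            # new hidden state
    return h, c


def gru_step(x, h_prev, w, b):
    """One cuDNN-style GRU step with hidden size 1 (gates "r", "z", "n")."""
    def parts(g):
        (wx, wh), (bx, bh) = w[g], b[g]
        return wx * x + bx, wh * h_prev + bh   # input part, hidden part

    r = sigmoid(sum(parts("r")))
    z = sigmoid(sum(parts("z")))
    xn, hn = parts("n")
    n = math.tanh(xn + r * hn)      # reset gate scales only the hidden part
    return (1.0 - z) * n + z * h_prev
```

With all weights and biases set to zero, every sigmoid gate is 0.5 and every tanh pre-activation is 0. The LSTM therefore halves the previous cell state, and the GRU returns half of the previous hidden state.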

source

# MXNet.mx._npx_roi_poolingMethod.

_npx_roi_pooling(data, rois, pooled_size, spatial_scale)

_npx_roi_pooling is an alias of ROIPooling.

Performs region of interest (ROI) pooling on the input array.

ROI pooling is a variant of a max pooling layer, in which the output size is fixed and region of interest is a parameter. Its purpose is to perform max pooling on the inputs of non-uniform sizes to obtain fixed-size feature maps. ROI pooling is a neural-net layer mostly used in training a Fast R-CNN network for object detection.

This operator takes a 4D feature map as an input array and region proposals as rois, then it pools over sub-regions of input and produces a fixed-sized output array regardless of the ROI size.

To crop the feature map accordingly, you can resize the bounding box coordinates by changing the parameters rois and spatial_scale.

The cropped feature maps are pooled by standard max pooling operation to a fixed size output indicated by a pooled_size parameter. batch_size will change to the number of region bounding boxes after ROIPooling.

The size of each region of interest doesn't have to be perfectly divisible by the number of pooling sections (pooled_size).

Example::

    x = [[[[  0.,   1.,   2.,   3.,   4.,   5.],
           [  6.,   7.,   8.,   9.,  10.,  11.],
           [ 12.,  13.,  14.,  15.,  16.,  17.],
           [ 18.,  19.,  20.,  21.,  22.,  23.],
           [ 24.,  25.,  26.,  27.,  28.,  29.],
           [ 30.,  31.,  32.,  33.,  34.,  35.],
           [ 36.,  37.,  38.,  39.,  40.,  41.],
           [ 42.,  43.,  44.,  45.,  46.,  47.]]]]

    // region of interest i.e. bounding box coordinates.
    y = [[0,0,0,4,4]]

    // returns array of shape (2,2) according to the given roi with max pooling.
    ROIPooling(x, y, (2,2), 1.0) = [[[[ 14.,  16.],
                                      [ 26.,  28.]]]]

    // region of interest is changed due to the change in spatial_scale parameter.
    ROIPooling(x, y, (2,2), 0.7) = [[[[  7.,   9.],
                                      [ 19.,  21.]]]]

Defined in src/operator/roi_pooling.cc:L224

Arguments

  • data::NDArray-or-SymbolicNode: The input array to the pooling operator, a 4D Feature maps
  • rois::NDArray-or-SymbolicNode: Bounding box coordinates, a 2D array of [[batch_index, x1, y1, x2, y2]], where (x1, y1) and (x2, y2) are the top-left and bottom-right corners of the designated region of interest. batch_index indicates the index of the corresponding image in the input array.
  • pooled_size::Shape(tuple), required: ROI pooling output shape (h,w)
  • spatial_scale::float, required: Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal of total stride in convolutional layers
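A plain-Python sketch of ROI max pooling on a single-channel 2-D feature map reproduces both ROIPooling examples. The sketch is modeled on the Caffe-style Fast R-CNN scheme. The specific choices below are assumptions of this sketch, not taken from the text:

- coordinates are scaled and rounded half up;
- the ROI is treated as inclusive of its end corner;
- bin edges are floored and ceiled;
- empty bins produce 0.

`roi_max_pool` is a hypothetical helper.

```python
import math


def roi_max_pool(fmap, roi, pooled_size, spatial_scale):
    """Hypothetical sketch of ROI max pooling on one 2-D feature map.

    roi = (batch_index, x1, y1, x2, y2) in input-image coordinates.
    """
    _, x1, y1, x2, y2 = roi
    height, width = len(fmap), len(fmap[0])

    def rnd(v):
        # Scale into feature-map coordinates, rounding half up (assumed).
        return int(math.floor(v * spatial_scale + 0.5))

    w0, h0, w1, h1 = rnd(x1), rnd(y1), rnd(x2), rnd(y2)
    roi_h, roi_w = max(h1 - h0 + 1, 1), max(w1 - w0 + 1, 1)
    out_h, out_w = pooled_size

    def bounds(k, n_bins, roi_len, start, limit):
        # Bin k covers [floor(k*len/n), ceil((k+1)*len/n)), clipped to the map.
        lo = math.floor(k * roi_len / n_bins) + start
        hi = math.ceil((k + 1) * roi_len / n_bins) + start
        return min(max(lo, 0), limit), min(max(hi, 0), limit)

    out = []
    for ph in range(out_h):
        hs, he = bounds(ph, out_h, roi_h, h0, height)
        row = []
        for pw in range(out_w):
            ws, we = bounds(pw, out_w, roi_w, w0, width)
            vals = [fmap[i][j] for i in range(hs, he) for j in range(ws, we)]
            row.append(max(vals) if vals else 0.0)  # empty bin -> 0
        out.append(row)
    return out
```

With spatial_scale=0.7, the corner (4, 4) maps to (3, 3). The 4x4 region then splits evenly into 2x2 bins, which gives the smaller maxima shown in the second example.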

source

# MXNet.mx._npx_sequence_maskMethod.

_npx_sequence_mask(data, sequence_length, use_sequence_length, value, axis)

_npx_sequence_mask is an alias of SequenceMask.

Sets all elements outside the sequence to a constant value.

This function takes an n-dimensional input array of the form [max_sequence_length, batch_size, other_feature_dims] and returns an array of the same shape.

Parameter sequence_length is used to handle variable-length sequences. sequence_length should be an input array of positive ints of dimension [batch_size]. To use this parameter, set use_sequence_length to True; otherwise each example in the batch is assumed to have the max sequence length and this operator works as the identity operator.

Example::

    x = [[[  1.,   2.,   3.],
          [  4.,   5.,   6.]],

         [[  7.,   8.,   9.],
          [ 10.,  11.,  12.]],

         [[ 13.,  14.,  15.],
          [ 16.,  17.,  18.]]]

    // Batch 1
    B1 = [[  1.,   2.,   3.],
          [  7.,   8.,   9.],
          [ 13.,  14.,  15.]]

    // Batch 2
    B2 = [[  4.,   5.,   6.],
          [ 10.,  11.,  12.],
          [ 16.,  17.,  18.]]

    // works as identity operator when sequence_length parameter is not used
    SequenceMask(x) = [[[  1.,   2.,   3.],
                        [  4.,   5.,   6.]],

                       [[  7.,   8.,   9.],
                        [ 10.,  11.,  12.]],

                       [[ 13.,  14.,  15.],
                        [ 16.,  17.,  18.]]]

    // sequence_length [1,1] means 1 of each batch will be kept
    // and other rows are masked with default mask value = 0
    SequenceMask(x, sequence_length=[1,1], use_sequence_length=True) =
                      [[[  1.,   2.,   3.],
                        [  4.,   5.,   6.]],

                       [[  0.,   0.,   0.],
                        [  0.,   0.,   0.]],

                       [[  0.,   0.,   0.],
                        [  0.,   0.,   0.]]]

    // sequence_length [2,3] means 2 of batch B1 and 3 of batch B2 will be kept
    // and other rows are masked with value = 1
    SequenceMask(x, sequence_length=[2,3], use_sequence_length=True, value=1) =
                      [[[  1.,   2.,   3.],
                        [  4.,   5.,   6.]],

                       [[  7.,   8.,   9.],
                        [ 10.,  11.,  12.]],

                       [[  1.,   1.,   1.],
                        [ 16.,  17.,  18.]]]

Defined in src/operator/sequence_mask.cc:L185

Arguments

  • data::NDArray-or-SymbolicNode: n-dimensional input array of the form [max_sequence_length, batch_size, other_feature_dims] where n>2
  • sequence_length::NDArray-or-SymbolicNode: vector of sequence lengths of the form [batch_size]
  • use_sequence_length::boolean, optional, default=0: If set to true, this layer takes in an extra input parameter sequence_length to specify variable length sequence
  • value::float, optional, default=0: The value to be used as a mask.
  • axis::int, optional, default='0': The sequence axis. Only values of 0 and 1 are currently supported.
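The masking rule for the default axis=0 can be sketched on nested lists laid out as [time][batch][features]. `sequence_mask_ref` is a hypothetical helper. Passing `sequence_length=None` stands in for use_sequence_length=False, and the sketch assumes a single trailing feature axis.

```python
def sequence_mask_ref(x, sequence_length=None, value=0.0):
    """Hypothetical sketch for axis=0: x is nested as [time][batch][features]."""
    if sequence_length is None:       # use_sequence_length=False: identity
        return [[list(feat) for feat in step] for step in x]
    # Keep time step t of batch entry b only while t < sequence_length[b].
    return [
        [list(feat) if t < sequence_length[b] else [value] * len(feat)
         for b, feat in enumerate(step)]
        for t, step in enumerate(x)
    ]
```

Called with `sequence_length=[2, 3]` and `value=1.0` on the example `x`, only the third time step of the first batch entry is replaced with ones, as in the last example.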

source

# MXNet.mx._npx_shape_arrayMethod.

_npx_shape_array(data)

_npx_shape_array is an alias of shape_array.

Returns a 1D int64 array containing the shape of data.

Example::

shape_array([[1,2,3,4], [5,6,7,8]]) = [2,4]

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L573

Arguments

  • data::NDArray-or-SymbolicNode: Input Array.

source

# MXNet.mx._npx_sigmoidMethod.

_npx_sigmoid(data)

Computes the sigmoid of x element-wise: $y = 1 / (1 + \exp(-x))$.

Defined in src/operator/numpy/np_elemwise_unary_op_basic.cc:L42

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._npx_sliceMethod.

_npx_slice(data, begin, end, step)

_npx_slice is an alias of slice.

Slices a region of the array.

.. note:: $crop$ is deprecated. Use $slice$ instead.

This function returns a sliced array between the indices given by begin and end with the corresponding step. For an input array of $shape=(d_0, d_1, ..., d_{n-1})$, a slice operation with $begin=(b_0, b_1, ..., b_{m-1})$, $end=(e_0, e_1, ..., e_{m-1})$, and $step=(s_0, s_1, ..., s_{m-1})$, where m <= n, results in an array with the shape $(\lceil |e_0-b_0|/|s_0| \rceil, ..., \lceil |e_{m-1}-b_{m-1}|/|s_{m-1}| \rceil, d_m, ..., d_{n-1})$.

The resulting array's k-th dimension contains elements from the k-th dimension of the input array starting from index $b_k$ (inclusive) with step $s_k$ until reaching $e_k$ (exclusive).

If the k-th elements of begin, end, or step are None, the following rule sets default values: if s_k is None, set s_k=1. If s_k > 0, set b_k=0, e_k=d_k; otherwise, set b_k=d_k-1, e_k=-1.

The storage type of $slice$ output depends on storage types of inputs:

  • slice(csr) = csr
  • otherwise, $slice$ generates output with default storage

.. note:: When input data storage type is csr, it only supports step=(), or step=(None,), or step=(1,) to generate a csr output. For other step parameter values, it falls back to slicing a dense tensor.

Example::

    x = [[  1.,   2.,   3.,   4.],
         [  5.,   6.,   7.,   8.],
         [  9.,  10.,  11.,  12.]]

    slice(x, begin=(0,1), end=(2,4)) = [[ 2.,  3.,  4.],
                                        [ 6.,  7.,  8.]]

    slice(x, begin=(None, 0), end=(None, 3), step=(-1, 2)) = [[  9.,  11.],
                                                              [  5.,   7.],
                                                              [  1.,   3.]]

Defined in src/operator/tensor/matrix_op.cc:L481

Arguments

  • data::NDArray-or-SymbolicNode: Source input
  • begin::Shape(tuple), required: starting indices for the slice operation, supports negative indices.
  • end::Shape(tuple), required: ending indices for the slice operation, supports negative indices.
  • step::Shape(tuple), optional, default=[]: step for the slice operation, supports negative values.
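The begin/end/step semantics can be sketched by mapping each axis onto Python's built-in `slice` and recursing into nested lists. A None begin or end is passed through as None. For a negative step, Python's own defaults then start at the last element and run through index 0 inclusive, which is the effect of b_k = d_k - 1 and e_k = -1 in the rule above. (Passing a literal -1 as a Python end would mean "the last element" instead.) `slice_ref` is a hypothetical helper.

```python
def slice_ref(x, begin, end, step=()):
    """Hypothetical sketch of slice on a nested list, one axis at a time."""
    def apply(arr, axis):
        if axis == len(begin):
            return arr                      # remaining axes are kept whole
        s = step[axis] if axis < len(step) else None   # None means step 1
        sliced = arr[slice(begin[axis], end[axis], s)]
        return [apply(sub, axis + 1) for sub in sliced]
    return apply(x, 0)
```

Both examples come out as listed. Only the first m axes are sliced, and later axes are returned unchanged.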

source

# MXNet.mx._npx_smooth_l1Method.

_npx_smooth_l1(data, scalar)

_npx_smooth_l1 is an alias of smooth_l1.

Calculates the Smooth L1 loss element-wise:

$$f(x) = \begin{cases} (\sigma x)^2/2, & \text{if } |x| < 1/\sigma^2 \\ |x| - 0.5/\sigma^2, & \text{otherwise} \end{cases}$$

where $x$ is an element of the input tensor data and $\sigma$ is the scalar.

Example::

    smooth_l1([1, 2, 3, 4]) = [0.5, 1.5, 2.5, 3.5]
    smooth_l1([1, 2, 3, 4], scalar=1) = [0.5, 1.5, 2.5, 3.5]

Defined in src/operator/tensor/elemwise_binary_scalar_op_extended.cc:L108

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::float: scalar input
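An element-wise sketch of the Smooth L1 formula, using |x| in the branch condition so the loss is symmetric in x. `smooth_l1_ref` is a hypothetical helper operating on a flat list.

```python
def smooth_l1_ref(values, scalar=1.0):
    """Element-wise Smooth L1 with sigma = scalar."""
    inv = 1.0 / (scalar * scalar)              # 1 / sigma^2
    out = []
    for x in values:
        if abs(x) < inv:
            out.append((scalar * x) ** 2 / 2)  # quadratic near zero
        else:
            out.append(abs(x) - 0.5 * inv)     # linear in the tails
    return out
```

With sigma = 1, the two branches meet at |x| = 1, where both give 0.5. This is why the example output for `[1, 2, 3, 4]` is simply each value minus 0.5.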

source

# MXNet.mx._npx_softmaxMethod.

_npx_softmax(data, length, axis, temperature, dtype, use_length)

_npx_softmax is an alias of softmax.

Applies the softmax function.

The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1.

$$\mathrm{softmax}(\mathbf{z}/t)_j = \frac{e^{z_j/t}}{\sum_{k=1}^{K} e^{z_k/t}}$$

for $j = 1, \ldots, K$.

t is the temperature parameter in the softmax function. By default, t equals 1.0.

Example::

    x = [[ 1.  1.  1.]
         [ 1.  1.  1.]]

    softmax(x,axis=0) = [[ 0.5  0.5  0.5]
                         [ 0.5  0.5  0.5]]

    softmax(x,axis=1) = [[ 0.33333334,  0.33333334,  0.33333334],
                         [ 0.33333334,  0.33333334,  0.33333334]]

Defined in src/operator/nn/softmax.cc:L135

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
  • length::NDArray-or-SymbolicNode: The length array.
  • axis::int, optional, default='-1': The axis along which to compute softmax.
  • temperature::double or None, optional, default=None: Temperature parameter in softmax
  • dtype::{None, 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None).
  • use_length::boolean or None, optional, default=0: Whether to use the length input as a mask over the data input.
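A 1-D sketch of softmax with a temperature t follows the formula above. `softmax_ref` is a hypothetical helper. Subtracting the maximum before exponentiating is a common numerical-stability choice made in this sketch. It leaves the result mathematically unchanged but avoids overflow for large inputs.

```python
import math


def softmax_ref(z, temperature=1.0):
    """softmax(z / t) for a 1-D list."""
    scaled = [v / temperature for v in z]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]   # shifting does not change the ratios
    total = sum(exps)
    return [e / total for e in exps]
```

Raising t flattens the distribution. For example, `[0, ln 4]` at t = 2 gives the same result as `[0, ln 2]` at t = 1, namely roughly `[1/3, 2/3]`.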

source

# MXNet.mx._npx_topkMethod.

_npx_topk(data, axis, k, ret_typ, is_ascend, dtype)

_npx_topk is an alias of topk.

By default, returns the indices of the top k elements in an input array along the given axis. If ret_typ is set to 'value', it returns the values of the top k elements instead of their indices. If ret_typ = 'both', both values and indices are returned. The returned elements are sorted.

Examples::

    x = [[ 0.3,  0.2,  0.4],
         [ 0.1,  0.3,  0.2]]

    // returns an index of the largest element on last axis
    topk(x) = [[ 2.],
               [ 1.]]

    // returns the value of top-2 largest elements on last axis
    topk(x, ret_typ='value', k=2) = [[ 0.4,  0.3],
                                     [ 0.3,  0.2]]

    // returns the value of top-2 smallest elements on last axis
    topk(x, ret_typ='value', k=2, is_ascend=1) = [[ 0.2,  0.3],
                                                  [ 0.1,  0.2]]

    // returns the value of top-2 largest elements on axis 0
    topk(x, axis=0, ret_typ='value', k=2) = [[ 0.3,  0.3,  0.4],
                                             [ 0.1,  0.2,  0.2]]

    // returns both values and indices of the top-2 largest elements on last axis
    topk(x, ret_typ='both', k=2) = [[[ 0.4,  0.3], [ 0.3,  0.2]],
                                    [[ 2.,  0.], [ 1.,  2.]]]

Defined in src/operator/tensor/ordering_op.cc:L67

Arguments

  • data::NDArray-or-SymbolicNode: The input array
  • axis::int or None, optional, default='-1': Axis along which to choose the top k indices. If None, the flattened array is used. Default is -1.
  • k::int, optional, default='1': Number of top elements to select. It should always be smaller than or equal to the number of elements along the given axis. A global sort is performed if k < 1.
  • ret_typ::{'both', 'indices', 'mask', 'value'},optional, default='indices': The return type. "value" means to return the top k values, "indices" means to return the indices of the top k values, "mask" means to return a mask array containing 0 and 1, where 1 marks the top k values, and "both" means to return a list of both values and indices of the top k elements.

  • is_ascend::boolean, optional, default=0: Whether to choose k largest or k smallest elements. Top K largest elements will be chosen if set to false.
  • dtype::{'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'},optional, default='float32': DType of the output indices when ret_typ is "indices" or "both". An error will be raised if the selected data type cannot precisely represent the indices.
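The ret_typ, k and is_ascend options can be sketched on a single row, which corresponds to the default last axis. `topk_row` is a hypothetical helper. It returns Python ints for indices, whereas the operator defaults to float32 indices. Its tie-breaking (stable sort, earlier position first) is an assumption of this sketch and may differ from MXNet.

```python
def topk_row(row, k=1, ret_typ="indices", is_ascend=False):
    """Hypothetical sketch of topk on one 1-D row (the last axis)."""
    # Stable sort of positions by value; reverse=True still keeps ties in order.
    order = sorted(range(len(row)), key=lambda i: row[i], reverse=not is_ascend)
    idx = order[:k] if k >= 1 else order      # k < 1: full sort
    if ret_typ == "indices":
        return idx
    if ret_typ == "value":
        return [row[i] for i in idx]
    if ret_typ == "both":
        return [row[i] for i in idx], idx
    if ret_typ == "mask":
        chosen = set(idx)
        return [1 if i in chosen else 0 for i in range(len(row))]
    raise ValueError(f"unknown ret_typ {ret_typ!r}")
```

Applying `topk_row` to each row of the example `x` reproduces the last-axis examples. For instance, `topk_row([0.1, 0.3, 0.2], k=2, ret_typ="both")` gives values `[0.3, 0.2]` and indices `[1, 2]`.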

source

# MXNet.mx._onehot_encodeMethod.

_onehot_encode(lhs, rhs)

Arguments

  • lhs::NDArray: Left operand to the function.
  • rhs::NDArray: Right operand to the function.

source

# MXNet.mx._plus_scalarMethod.

_plus_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._powerMethod.

_power(lhs, rhs)

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._power_scalarMethod.

_power_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._random_exponentialMethod.

_random_exponential(lam, shape, ctx, dtype)

Draw random samples from an exponential distribution.

Samples are distributed according to an exponential distribution parametrized by lambda (rate).

Example::

exponential(lam=4, shape=(2,2)) = [[ 0.0097189 , 0.08999364], [ 0.04146638, 0.31715935]]

Defined in src/operator/random/sample_op.cc:L136

Arguments

  • lam::float, optional, default=1: Lambda parameter (rate) of the exponential distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
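As a sanity check on the rate parametrization, Python's standard `random.expovariate` also takes the rate (lambda) directly, so the sample mean should approach 1 / lam. `exponential_samples` is a hypothetical helper used only to illustrate the parametrization. It is not the MXNet sampler.

```python
import random


def exponential_samples(lam, n, seed=0):
    """Draw n exponential samples with rate lam; the mean is about 1 / lam."""
    rng = random.Random(seed)   # seeded for reproducibility
    return [rng.expovariate(lam) for _ in range(n)]
```

With lam=4, as in the example, draws are typically small: the mean is 0.25.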

source

# MXNet.mx._random_exponential_likeMethod.

_random_exponential_like(lam, data)

Draw random samples from an exponential distribution according to the input array shape.

Samples are distributed according to an exponential distribution parametrized by lambda (rate).

Example::

exponential(lam=4, data=ones(2,2)) = [[ 0.0097189 , 0.08999364], [ 0.04146638, 0.31715935]]

Defined in src/operator/random/sample_op.cc:L242

Arguments

  • lam::float, optional, default=1: Lambda parameter (rate) of the exponential distribution.
  • data::NDArray-or-SymbolicNode: The input

source

# MXNet.mx._random_gammaMethod.

_random_gamma(alpha, beta, shape, ctx, dtype)

Draw random samples from a gamma distribution.

Samples are distributed according to a gamma distribution parametrized by alpha (shape) and beta (scale).

Example::

gamma(alpha=9, beta=0.5, shape=(2,2)) = [[ 7.10486984, 3.37695289], [ 3.91697288, 3.65933681]]

Defined in src/operator/random/sample_op.cc:L124

Arguments

  • alpha::float, optional, default=1: Alpha parameter (shape) of the gamma distribution.
  • beta::float, optional, default=1: Beta parameter (scale) of the gamma distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
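Python's `random.gammavariate(alpha, beta)` uses the same shape/scale convention described here. That makes it a convenient cross-check: the mean is alpha * beta and the variance is alpha * beta**2. `gamma_samples` is a hypothetical helper for illustration, not the MXNet sampler.

```python
import random


def gamma_samples(alpha, beta, n, seed=0):
    """Draw n gamma samples with shape alpha and scale beta."""
    rng = random.Random(seed)   # seeded for reproducibility
    return [rng.gammavariate(alpha, beta) for _ in range(n)]
```

For alpha=9 and beta=0.5, as in the example, samples center on 4.5 with variance 2.25.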

source

# MXNet.mx._random_gamma_likeMethod.

_random_gamma_like(alpha, beta, data)

Draw random samples from a gamma distribution according to the input array shape.

Samples are distributed according to a gamma distribution parametrized by alpha (shape) and beta (scale).

Example::

gamma(alpha=9, beta=0.5, data=ones(2,2)) = [[ 7.10486984, 3.37695289], [ 3.91697288, 3.65933681]]

Defined in src/operator/random/sample_op.cc:L231

Arguments

  • alpha::float, optional, default=1: Alpha parameter (shape) of the gamma distribution.
  • beta::float, optional, default=1: Beta parameter (scale) of the gamma distribution.
  • data::NDArray-or-SymbolicNode: The input

source

# MXNet.mx._random_generalized_negative_binomialMethod.

_random_generalized_negative_binomial(mu, alpha, shape, ctx, dtype)

Draw random samples from a generalized negative binomial distribution.

Samples are distributed according to a generalized negative binomial distribution parametrized by mu (mean) and alpha (dispersion). alpha is defined as 1/k where k is the failure limit of the number of unsuccessful experiments (generalized to real numbers). Samples will always be returned as a floating point data type.

Example::

generalized_negative_binomial(mu=2.0, alpha=0.3, shape=(2,2)) = [[ 2., 1.], [ 6., 4.]]

Defined in src/operator/random/sample_op.cc:L178

Arguments

  • mu::float, optional, default=1: Mean of the negative binomial distribution.
  • alpha::float, optional, default=1: Alpha (dispersion) parameter of the negative binomial distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
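One standard way to realize the (mu, alpha) parametrization is a gamma-Poisson mixture. First draw a rate from a gamma distribution with shape k = 1/alpha and scale alpha * mu, then draw a Poisson count with that rate. The result has mean mu and variance mu + alpha * mu**2. This is a sketch of that construction, not necessarily how MXNet implements the sampler. `poisson` and `gen_neg_binomial` are hypothetical helpers, and the Poisson draw uses Knuth's multiplication method, which is only suitable for small rates.

```python
import math
import random


def poisson(rng, lam):
    """Knuth's multiplication method; adequate for the small rates used here."""
    limit, k, product = math.exp(-lam), 0, rng.random()
    while product > limit:
        k += 1
        product *= rng.random()
    return k


def gen_neg_binomial(mu, alpha, n, seed=0):
    """Gamma-Poisson mixture: rate ~ Gamma(shape=1/alpha, scale=alpha*mu)."""
    rng = random.Random(seed)
    shape, scale = 1.0 / alpha, alpha * mu   # mean of the rate is shape * scale = mu
    # Counts are returned as floats, mirroring the floating point output above.
    return [float(poisson(rng, rng.gammavariate(shape, scale))) for _ in range(n)]
```

With mu=2.0 and alpha=0.3, as in the example, the sample mean should be near 2.0 and the variance near 3.2.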

source

# MXNet.mx._random_generalized_negative_binomial_likeMethod.

_random_generalized_negative_binomial_like(mu, alpha, data)

Draw random samples from a generalized negative binomial distribution according to the input array shape.

Samples are distributed according to a generalized negative binomial distribution parametrized by mu (mean) and alpha (dispersion). alpha is defined as 1/k where k is the failure limit of the number of unsuccessful experiments (generalized to real numbers). Samples will always be returned as a floating point data type.

Example::

generalized_negative_binomial(mu=2.0, alpha=0.3, data=ones(2,2)) = [[ 2., 1.], [ 6., 4.]]

Defined in src/operator/random/sample_op.cc:L283

Arguments

  • mu::float, optional, default=1: Mean of the negative binomial distribution.
  • alpha::float, optional, default=1: Alpha (dispersion) parameter of the negative binomial distribution.
  • data::NDArray-or-SymbolicNode: The input

source

# MXNet.mx._random_negative_binomialMethod.

_random_negative_binomial(k, p, shape, ctx, dtype)

Draw random samples from a negative binomial distribution.

Samples are distributed according to a negative binomial distribution parametrized by k (limit of unsuccessful experiments) and p (failure probability in each experiment). Samples will always be returned as a floating point data type.

Example::

negative_binomial(k=3, p=0.4, shape=(2,2)) = [[ 4., 7.], [ 2., 5.]]

Defined in src/operator/random/sample_op.cc:L163

Arguments

  • k::int, optional, default='1': Limit of unsuccessful experiments.
  • p::float, optional, default=1: Failure probability in each experiment.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx._random_negative_binomial_likeMethod.

_random_negative_binomial_like(k, p, data)

Draw random samples from a negative binomial distribution according to the input array shape.

Samples are distributed according to a negative binomial distribution parametrized by k (limit of unsuccessful experiments) and p (failure probability in each experiment). Samples will always be returned as a floating point data type.

Example::

negative_binomial(k=3, p=0.4, data=ones(2,2)) = [[ 4., 7.], [ 2., 5.]]

Defined in src/operator/random/sample_op.cc:L267

Arguments

  • k::int, optional, default='1': Limit of unsuccessful experiments.
  • p::float, optional, default=1: Failure probability in each experiment.
  • data::NDArray-or-SymbolicNode: The input

source

# MXNet.mx._random_normalMethod.

_random_normal(loc, scale, shape, ctx, dtype)

Draw random samples from a normal (Gaussian) distribution.

Note: The existing alias normal is deprecated.

Samples are distributed according to a normal distribution parametrized by loc (mean) and scale (standard deviation).

Example::

normal(loc=0, scale=1, shape=(2,2)) = [[ 1.89171135, -1.16881478], [-1.23474145, 1.55807114]]

Defined in src/operator/random/sample_op.cc:L112

Arguments

  • loc::float, optional, default=0: Mean of the distribution.
  • scale::float, optional, default=1: Standard deviation of the distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx._random_normal_likeMethod.

_random_normal_like(loc, scale, data)

Draw random samples from a normal (Gaussian) distribution according to the input array shape.

Samples are distributed according to a normal distribution parametrized by loc (mean) and scale (standard deviation).

Example::

normal(loc=0, scale=1, data=ones(2,2)) = [[ 1.89171135, -1.16881478], [-1.23474145, 1.55807114]]

Defined in src/operator/random/sample_op.cc:L220

Arguments

  • loc::float, optional, default=0: Mean of the distribution.
  • scale::float, optional, default=1: Standard deviation of the distribution.
  • data::NDArray-or-SymbolicNode: The input

source

# MXNet.mx._random_pdf_dirichletMethod.

_random_pdf_dirichlet(sample, alpha, is_log)

Computes the value of the PDF of sample of Dirichlet distributions with parameter alpha.

The shape of alpha must match the leftmost subshape of sample. That is, sample can have the same shape as alpha, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the value of alpha at index i.

Examples::

random_pdf_dirichlet(sample=[[1,2],[2,3],[3,4]], alpha=[2.5, 2.5]) =
    [38.413498, 199.60245, 564.56085]

sample = [[[1, 2, 3], [10, 20, 30], [100, 200, 300]],
          [[0.1, 0.2, 0.3], [0.01, 0.02, 0.03], [0.001, 0.002, 0.003]]]

random_pdf_dirichlet(sample=sample, alpha=[0.1, 0.4, 0.9]) =
    [[2.3257459e-02, 5.8420084e-04, 1.4674458e-05],
     [9.2589635e-01, 3.6860607e+01, 1.4674468e+03]]

Defined in src/operator/random/pdf_op.cc:L315

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • alpha::NDArray-or-SymbolicNode: Concentration parameters of the distributions.
  • is_log::boolean, optional, default=0: If set, compute the log of the density instead of the density.

source
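
The Dirichlet density Gamma(sum(alpha)) / prod(Gamma(alpha_i)) * prod(x_i^(alpha_i - 1)) can be checked against the first example with a short plain-Python sketch. The helpers are hypothetical and are not the MXNet.jl API. The sketch works in log space with math.lgamma to avoid overflow. Like the example, it simply evaluates the formula, so it also returns a value for vectors that do not lie on the probability simplex.

```python
import math


def dirichlet_log_pdf(x, alpha):
    """log[ Gamma(sum(alpha)) / prod(Gamma(a_i)) * prod(x_i ** (a_i - 1)) ]."""
    log_norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    return log_norm + sum((a - 1.0) * math.log(xi) for xi, a in zip(x, alpha))


def dirichlet_pdf(x, alpha, is_log=False):
    """Density of one sample vector x; with is_log=True, its natural log."""
    lp = dirichlet_log_pdf(x, alpha)
    return lp if is_log else math.exp(lp)


# First example from the text: one density per sample vector.
print([dirichlet_pdf(x, [2.5, 2.5]) for x in [[1, 2], [2, 3], [3, 4]]])
```

The printed values agree with the example output up to float32 rounding.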

# MXNet.mx._random_pdf_exponentialMethod.

_random_pdf_exponential(sample, lam, is_log)

Computes the value of the PDF of sample of exponential distributions with parameters lam (rate).

The shape of lam must match the leftmost subshape of sample. That is, sample can have the same shape as lam, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the value of lam at index i.

Examples::

random_pdf_exponential(sample=[[1, 2, 3]], lam=[1]) =
    [[0.36787945, 0.13533528, 0.04978707]]

sample = [[1,2,3],
          [1,2,3],
          [1,2,3]]

random_pdf_exponential(sample=sample, lam=[1,0.5,0.25]) =
    [[0.36787945, 0.13533528, 0.04978707],
     [0.30326533, 0.18393973, 0.11156508],
     [0.1947002,  0.15163267, 0.11809164]]

Defined in src/operator/random/pdf_op.cc:L304

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • lam::NDArray-or-SymbolicNode: Lambda (rate) parameters of the distributions.
  • is_log::boolean, optional, default=0: If set, compute the log of the density instead of the density.

source
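
For the common case of a 1-D parameter array, the leftmost-subshape rule means that row i of sample is evaluated with lam[i]. The plain-Python sketch below applies that rule with the exponential log-density log(lam) - lam*x. The helper name is hypothetical, and the sketch handles only 2-D samples.

```python
import math


def exponential_pdf(sample, lam, is_log=False):
    """Hypothetical helper: row i of a 2-D sample is evaluated with rate lam[i]."""
    if len(sample) != len(lam):
        raise ValueError("the leading dimension of sample must equal len(lam)")
    out = []
    for row, rate in zip(sample, lam):
        # log f(x) = log(rate) - rate * x, valid for x >= 0
        logs = [math.log(rate) - rate * x for x in row]
        out.append(logs if is_log else [math.exp(v) for v in logs])
    return out


sample = [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
print(exponential_pdf(sample, [1, 0.5, 0.25]))
```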

# MXNet.mx._random_pdf_gammaMethod.

_random_pdf_gamma(sample, alpha, is_log, beta)

Computes the value of the PDF of sample of gamma distributions with parameters alpha (shape) and beta (rate).

alpha and beta must have the same shape, which must match the leftmost subshape of sample. That is, sample can have the same shape as alpha and beta, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the values of alpha and beta at index i.

Examples::

random_pdf_gamma(sample=[[1,2,3,4,5]], alpha=[5], beta=[1]) =
    [[0.01532831, 0.09022352, 0.16803136, 0.19536681, 0.17546739]]

sample = [[1, 2, 3, 4, 5],
          [2, 3, 4, 5, 6],
          [3, 4, 5, 6, 7]]

random_pdf_gamma(sample=sample, alpha=[5,6,7], beta=[1,1,1]) =
    [[0.01532831, 0.09022352, 0.16803136, 0.19536681, 0.17546739],
     [0.03608941, 0.10081882, 0.15629345, 0.17546739, 0.16062315],
     [0.05040941, 0.10419563, 0.14622283, 0.16062315, 0.14900276]]

Defined in src/operator/random/pdf_op.cc:L302

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • alpha::NDArray-or-SymbolicNode: Alpha (shape) parameters of the distributions.
  • is_log::boolean, optional, default=0: If set, compute the log of the density instead of the density.
  • beta::NDArray-or-SymbolicNode: Beta (scale) parameters of the distributions.

source
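
A plain-Python sketch of the gamma density, with row i of sample using alpha[i] and beta[i]. The helper names are hypothetical. The sketch treats beta as a rate, following the summary line of this entry. The argument list above calls beta a scale instead. With beta = 1, as in both examples, the two readings give the same densities, so treat the rate parameterization as an assumption.

```python
import math


def gamma_log_pdf(x, alpha, beta):
    """Log-density with shape alpha and *rate* beta (an assumption, see above):
    alpha*log(beta) + (alpha - 1)*log(x) - beta*x - lgamma(alpha)."""
    return alpha * math.log(beta) + (alpha - 1.0) * math.log(x) - beta * x - math.lgamma(alpha)


def gamma_pdf(sample, alpha, beta, is_log=False):
    """Hypothetical helper: row i of a 2-D sample uses alpha[i] and beta[i]."""
    if not (len(sample) == len(alpha) == len(beta)):
        raise ValueError("sample, alpha and beta must share the leading dimension")
    out = []
    for row, a, b in zip(sample, alpha, beta):
        logs = [gamma_log_pdf(x, a, b) for x in row]
        out.append(logs if is_log else [math.exp(v) for v in logs])
    return out


print(gamma_pdf([[1, 2, 3, 4, 5]], [5], [1]))
```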

# MXNet.mx._random_pdf_generalized_negative_binomialMethod.

_random_pdf_generalized_negative_binomial(sample, mu, is_log, alpha)

Computes the value of the PDF of sample of generalized negative binomial distributions with parameters mu (mean) and alpha (dispersion). This can be understood as a reparameterization of the negative binomial, where k = 1 / alpha and p = 1 / (mu * alpha + 1).

mu and alpha must have the same shape, which must match the leftmost subshape of sample. That is, sample can have the same shape as mu and alpha, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the values of mu and alpha at index i.

Examples::

random_pdf_generalized_negative_binomial(sample=[[1, 2, 3, 4]], alpha=[1], mu=[1]) =
    [[0.25, 0.125, 0.0625, 0.03125]]

sample = [[1,2,3,4],
          [1,2,3,4]]
random_pdf_generalized_negative_binomial(sample=sample, alpha=[1, 0.6666], mu=[1, 1.5]) =
    [[0.25,       0.125,      0.0625,     0.03125   ],
     [0.26517063, 0.16573331, 0.09667706, 0.05437994]]

Defined in src/operator/random/pdf_op.cc:L313

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • mu::NDArray-or-SymbolicNode: Means of the distributions.
  • is_log::boolean, optional, default=0: If set, compute the log of the density instead of the density.
  • alpha::NDArray-or-SymbolicNode: Alpha (dispersion) parameters of the distributions.

source

# MXNet.mx._random_pdf_negative_binomialMethod.

_random_pdf_negative_binomial(sample, k, is_log, p)

Computes the value of the PDF of samples of negative binomial distributions with parameters k (failure limit) and p (failure probability).

k and p must have the same shape, which must match the leftmost subshape of sample. That is, sample can have the same shape as k and p, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the values of k and p at index i.

Examples::

random_pdf_negative_binomial(sample=[[1,2,3,4]], k=[1], p=[0.5]) =
    [[0.25, 0.125, 0.0625, 0.03125]]

# Note that k may be real-valued
sample = [[1,2,3,4],
          [1,2,3,4]]
random_pdf_negative_binomial(sample=sample, k=[1, 1.5], p=[0.5, 0.5]) =
    [[0.25,       0.125,      0.0625,     0.03125   ],
     [0.26516506, 0.16572815, 0.09667476, 0.05437956]]

Defined in src/operator/random/pdf_op.cc:L309

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • k::NDArray-or-SymbolicNode: Limits of unsuccessful experiments.
  • is_log::boolean, optional, default=0: If set, compute the log of the density instead of the density.
  • p::NDArray-or-SymbolicNode: Failure probabilities in each experiment.

source
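
The real-valued k in the second example works because the binomial coefficient can be written with gamma functions: C(x + k - 1, x) = Gamma(x + k) / (Gamma(x + 1) * Gamma(k)). The plain-Python sketch below computes the log of C(x + k - 1, x) * p^k * (1 - p)^x. The helper names are hypothetical and are not the MXNet.jl API. A second helper adds the (mu, alpha) reparameterization stated for _random_pdf_generalized_negative_binomial. Together they should reproduce the examples of both entries up to float32 rounding.

```python
import math


def neg_binomial_log_pmf(x, k, p):
    """log C(x + k - 1, x) + k*log(p) + x*log(1 - p); lgamma lets k be real."""
    return (math.lgamma(x + k) - math.lgamma(x + 1) - math.lgamma(k)
            + k * math.log(p) + x * math.log1p(-p))


def gen_neg_binomial_log_pmf(x, mu, alpha):
    """Reparameterization from the text: k = 1/alpha, p = 1/(mu*alpha + 1)."""
    return neg_binomial_log_pmf(x, 1.0 / alpha, 1.0 / (mu * alpha + 1.0))


xs = [1, 2, 3, 4]
print([math.exp(neg_binomial_log_pmf(x, 1.5, 0.5)) for x in xs])
print([math.exp(gen_neg_binomial_log_pmf(x, 1.5, 0.6666)) for x in xs])
```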

# MXNet.mx._random_pdf_normalMethod.

_random_pdf_normal(sample, mu, is_log, sigma)

Computes the value of the PDF of sample of normal distributions with parameters mu (mean) and sigma (standard deviation).

mu and sigma must have the same shape, which must match the leftmost subshape of sample. That is, sample can have the same shape as mu and sigma, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the values of mu and sigma at index i.

Examples::

sample = [[-2, -1, 0, 1, 2]]
random_pdf_normal(sample=sample, mu=[0], sigma=[1]) =
    [[0.05399097, 0.24197073, 0.3989423, 0.24197073, 0.05399097]]

random_pdf_normal(sample=sample*2, mu=[0,0], sigma=[1,2]) =
    [[0.05399097, 0.24197073, 0.3989423,  0.24197073, 0.05399097],
     [0.12098537, 0.17603266, 0.19947115, 0.17603266, 0.12098537]]

Defined in src/operator/random/pdf_op.cc:L299

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • mu::NDArray-or-SymbolicNode: Means of the distributions.
  • is_log::boolean, optional, default=0: If set, compute the log of the density instead of the density.
  • sigma::NDArray-or-SymbolicNode: Standard deviations of the distributions.

source

# MXNet.mx._random_pdf_poissonMethod.

_random_pdf_poisson(sample, lam, is_log)

Computes the value of the PDF of sample of Poisson distributions with parameters lam (rate).

The shape of lam must match the leftmost subshape of sample. That is, sample can have the same shape as lam, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the value of lam at index i.

Examples::

random_pdf_poisson(sample=[[0,1,2,3]], lam=[1]) =
    [[0.36787945, 0.36787945, 0.18393973, 0.06131324]]

sample = [[0,1,2,3],
          [0,1,2,3],
          [0,1,2,3]]

random_pdf_poisson(sample=sample, lam=[1,2,3]) =
    [[0.36787945, 0.36787945, 0.18393973, 0.06131324],
     [0.13533528, 0.27067056, 0.27067056, 0.18044704],
     [0.04978707, 0.14936121, 0.22404182, 0.22404182]]

Defined in src/operator/random/pdf_op.cc:L306

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • lam::NDArray-or-SymbolicNode: Lambda (rate) parameters of the distributions.
  • is_log::boolean, optional, default=0: If set, compute the log of the density instead of the density.

source

# MXNet.mx._random_pdf_uniformMethod.

_random_pdf_uniform(sample, low, is_log, high)

Computes the value of the PDF of sample of uniform distributions on the intervals given by [low,high).

low and high must have the same shape, which must match the leftmost subshape of sample. That is, sample can have the same shape as low and high, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the values of low and high at index i.

Examples::

random_pdf_uniform(sample=[[1,2,3,4]], low=[0], high=[10]) = [[0.1, 0.1, 0.1, 0.1]]

sample = [[[1, 2, 3],
           [1, 2, 3]],
          [[1, 2, 3],
           [1, 2, 3]]]
low  = [[0, 0],
        [0, 0]]
high = [[ 5, 10],
        [15, 20]]
random_pdf_uniform(sample=sample, low=low, high=high) =
    [[[0.2,        0.2,        0.2    ],
      [0.1,        0.1,        0.1    ]],
     [[0.06667,    0.06667,    0.06667],
      [0.05,       0.05,       0.05   ]]]

Defined in src/operator/random/pdf_op.cc:L297

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • low::NDArray-or-SymbolicNode: Lower bounds of the distributions.
  • is_log::boolean, optional, default=0: If set, compute the log of the density instead of the density.
  • high::NDArray-or-SymbolicNode: Upper bounds of the distributions.

source
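
When sample has more dimensions than low and high, each (low, high) pair applies to a whole trailing block of sample. The plain-Python sketch below reproduces the second example. The helper is hypothetical and is written only for a 3-D sample with 2-D low and high. It returns 0 outside [low, high), as the textbook density does. The text above does not say how the operator itself treats such points.

```python
import math


def uniform_pdf(sample, low, high, is_log=False):
    """Hypothetical helper: sample[i][j] is a list of points evaluated on
    [low[i][j], high[i][j])."""
    out = []
    for s_row, lo_row, hi_row in zip(sample, low, high):
        out_row = []
        for points, lo, hi in zip(s_row, lo_row, hi_row):
            dens = [1.0 / (hi - lo) if lo <= x < hi else 0.0 for x in points]
            if is_log:
                dens = [math.log(d) if d > 0 else -math.inf for d in dens]
            out_row.append(dens)
        out.append(out_row)
    return out


sample = [[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]]
print(uniform_pdf(sample, [[0, 0], [0, 0]], [[5, 10], [15, 20]]))
```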

# MXNet.mx._random_poissonMethod.

_random_poisson(lam, shape, ctx, dtype)

Draw random samples from a Poisson distribution.

Samples are distributed according to a Poisson distribution parametrized by lambda (rate). Samples will always be returned as a floating point data type.

Example::

poisson(lam=4, shape=(2,2)) = [[ 5., 2.], [ 4., 6.]]

Defined in src/operator/random/sample_op.cc:L149

Arguments

  • lam::float, optional, default=1: Lambda parameter (rate) of the Poisson distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx._random_poisson_likeMethod.

_random_poisson_like(lam, data)

Draw random samples from a Poisson distribution according to the input array shape.

Samples are distributed according to a Poisson distribution parametrized by lambda (rate). Samples will always be returned as a floating point data type.

Example::

poisson(lam=4, data=ones(2,2)) = [[ 5., 2.], [ 4., 6.]]

Defined in src/operator/random/sample_op.cc:L254

Arguments

  • lam::float, optional, default=1: Lambda parameter (rate) of the Poisson distribution.
  • data::NDArray-or-SymbolicNode: The input

source

# MXNet.mx._random_randintMethod.

_random_randint(low, high, shape, ctx, dtype)

Draw random samples from a discrete uniform distribution.

Samples are uniformly distributed over the half-open interval [low, high) (includes low, but excludes high).

Example::

randint(low=0, high=5, shape=(2,2)) = [[ 0, 2], [ 3, 1]]

Defined in src/operator/random/sample_op.cc:L193

Arguments

  • low::long, required: Lower bound of the distribution.
  • high::long, required: Upper bound of the distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'int32', 'int64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to int32 if not defined (dtype=None).

source

# MXNet.mx._random_uniformMethod.

_random_uniform(low, high, shape, ctx, dtype)

Draw random samples from a uniform distribution.

Note: The existing alias uniform is deprecated.

Samples are uniformly distributed over the half-open interval [low, high) (includes low, but excludes high).

Example::

uniform(low=0, high=1, shape=(2,2)) = [[ 0.60276335, 0.85794562], [ 0.54488319, 0.84725171]]

Defined in src/operator/random/sample_op.cc:L95

Arguments

  • low::float, optional, default=0: Lower bound of the distribution.
  • high::float, optional, default=1: Upper bound of the distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx._random_uniform_likeMethod.

_random_uniform_like(low, high, data)

Draw random samples from a uniform distribution according to the input array shape.

Samples are uniformly distributed over the half-open interval [low, high) (includes low, but excludes high).

Example::

uniform(low=0, high=1, data=ones(2,2)) = [[ 0.60276335, 0.85794562], [ 0.54488319, 0.84725171]]

Defined in src/operator/random/sample_op.cc:L208

Arguments

  • low::float, optional, default=0: Lower bound of the distribution.
  • high::float, optional, default=1: Upper bound of the distribution.
  • data::NDArray-or-SymbolicNode: The input

source

# MXNet.mx._ravel_multi_indexMethod.

_ravel_multi_index(data, shape)

Converts a batch of index arrays into an array of flat indices. The operator follows NumPy conventions, so a single multi-index is given by a column of the input matrix. The leading dimension may be left unspecified by using -1 as a placeholder.

Examples::

A = [[3,6,6],[4,5,1]]
ravel(A, shape=(7,6)) = [22,41,37]
ravel(A, shape=(-1,6)) = [22,41,37]

Defined in src/operator/tensor/ravel.cc:L41

Arguments

  • data::NDArray-or-SymbolicNode: Batch of multi-indices
  • shape::Shape(tuple), optional, default=None: Shape of the array into which the multi-indices apply.

source
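
The row-major (NumPy-style) arithmetic behind the example can be sketched in plain Python with a hypothetical helper, which is not the MXNet.jl binding. In Horner form the leading extent is multiplied by zero, which is why it may be given as -1. The sketch does not check that each index lies within its extent.

```python
def ravel_multi_index(data, shape):
    """Hypothetical helper: data holds one row per dimension, and column j is the
    j-th multi-index. Returns row-major flat offsets."""
    if any(extent <= 0 for extent in shape[1:]):
        raise ValueError("only the leading dimension may be left unspecified")
    flat = []
    for idx in zip(*data):  # iterate over columns, i.e. over multi-indices
        offset = 0
        for i, extent in zip(idx, shape):
            offset = offset * extent + i  # the first extent is multiplied by 0
        flat.append(offset)
    return flat


A = [[3, 6, 6], [4, 5, 1]]
print(ravel_multi_index(A, (7, 6)))   # [22, 41, 37]
print(ravel_multi_index(A, (-1, 6)))  # [22, 41, 37]
```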

# MXNet.mx._rdiv_scalarMethod.

_rdiv_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._rminus_scalarMethod.

_rminus_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._rnn_param_concatMethod.

_rnn_param_concat(data, num_args, dim)

Note: _rnn_param_concat takes a variable number of positional inputs. So instead of calling it as _rnn_param_concat([x, y, z], num_args=3), call it as _rnn_param_concat(x, y, z), and num_args will be determined automatically.

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to concatenate
  • num_args::int, required: Number of inputs to be concatenated.
  • dim::int, optional, default='1': The dimension along which to concatenate.

source

# MXNet.mx._rpower_scalarMethod.

_rpower_scalar(data, scalar, is_int)

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._sample_exponentialMethod.

_sample_exponential(lam, shape, dtype)

Concurrent sampling from multiple exponential distributions with parameters lambda (rate).

The parameters of the distributions are provided as an input array. Let [s] be the shape of the input array, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input array, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input value at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input array.

Examples::

lam = [ 1.0, 8.5 ]

// Draw a single sample for each distribution
sample_exponential(lam) = [ 0.51837951, 0.09994757]

// Draw a vector containing two samples for each distribution
sample_exponential(lam, shape=(2)) = [[ 0.51837951, 0.19866663], [ 0.09994757, 0.50447971]]

Defined in src/operator/random/multisample_op.cc:L283

Arguments

  • lam::NDArray-or-SymbolicNode: Lambda (rate) parameters of the distributions.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source
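
The [s]x[t] output-shape rule shared by the _sample_* operators can be illustrated with a plain-Python sketch. The helpers are hypothetical, and only 1-D parameter arrays (n = 1) are handled. Each rate in lam produces one block of samples of shape [t], and an empty shape yields a single scalar per distribution. Python's random.expovariate takes the rate lambda directly.

```python
import random


def _reshape(flat, shape):
    """Split a flat list into nested lists with the given (non-empty) shape."""
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [_reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def nested_shape(x):
    """Shape of a regular nested list, e.g. [[0, 0], [0, 0]] -> (2, 2)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        if not x:
            break
        x = x[0]
    return tuple(dims)


def sample_exponential(lam, shape=(), rng=None):
    """Hypothetical helper: a 1-D lam of shape [s] yields a result of shape [s] x [t]."""
    rng = random.Random() if rng is None else rng
    count = 1
    for extent in shape:
        count *= extent
    out = []
    for rate in lam:
        draws = [rng.expovariate(rate) for _ in range(count)]
        # With an empty shape, a single scalar is drawn per distribution.
        out.append(draws[0] if not shape else _reshape(draws, shape))
    return out


lam = [1.0, 8.5]
rng = random.Random(0)
print(nested_shape(sample_exponential(lam, rng=rng)))                # (2,)
print(nested_shape(sample_exponential(lam, shape=(2,), rng=rng)))    # (2, 2)
print(nested_shape(sample_exponential(lam, shape=(3, 4), rng=rng)))  # (2, 3, 4)
```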

# MXNet.mx._sample_gammaMethod.

_sample_gamma(alpha, shape, dtype, beta)

Concurrent sampling from multiple gamma distributions with parameters alpha (shape) and beta (scale).

The parameters of the distributions are provided as input arrays. Let [s] be the shape of the input arrays, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input arrays, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input values at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input arrays.

Examples::

alpha = [ 0.0, 2.5 ]
beta = [ 1.0, 0.7 ]

// Draw a single sample for each distribution
sample_gamma(alpha, beta) = [ 0. , 2.25797319]

// Draw a vector containing two samples for each distribution
sample_gamma(alpha, beta, shape=(2)) = [[ 0. , 0. ], [ 2.25797319, 1.70734084]]

Defined in src/operator/random/multisample_op.cc:L281

Arguments

  • alpha::NDArray-or-SymbolicNode: Alpha (shape) parameters of the distributions.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
  • beta::NDArray-or-SymbolicNode: Beta (scale) parameters of the distributions.

source

# MXNet.mx._sample_generalized_negative_binomialMethod.

_sample_generalized_negative_binomial(mu, shape, dtype, alpha)

Concurrent sampling from multiple generalized negative binomial distributions with parameters mu (mean) and alpha (dispersion).

The parameters of the distributions are provided as input arrays. Let [s] be the shape of the input arrays, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input arrays, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input values at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input arrays.

Samples will always be returned as a floating point data type.

Examples::

mu = [ 2.0, 2.5 ]
alpha = [ 1.0, 0.1 ]

// Draw a single sample for each distribution
sample_generalized_negative_binomial(mu, alpha) = [ 0., 3.]

// Draw a vector containing two samples for each distribution
sample_generalized_negative_binomial(mu, alpha, shape=(2)) = [[ 0., 3.], [ 3., 1.]]

Defined in src/operator/random/multisample_op.cc:L292

Arguments

  • mu::NDArray-or-SymbolicNode: Means of the distributions.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
  • alpha::NDArray-or-SymbolicNode: Alpha (dispersion) parameters of the distributions.

source

# MXNet.mx._sample_multinomialMethod.

_sample_multinomial(data, shape, get_prob, dtype)

Concurrent sampling from multiple multinomial distributions.

data is an n dimensional array whose last dimension has length k, where k is the number of possible outcomes of each multinomial distribution. This operator will draw shape samples from each distribution. If shape is empty one sample will be drawn from each distribution.

If get_prob is true, a second array containing log likelihood of the drawn samples will also be returned. This is usually used for reinforcement learning where you can provide reward as head gradient for this array to estimate gradient.

Note that the input distribution must be normalized, i.e. data must sum to 1 along its last axis.

Examples::

probs = [[0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0]]

// Draw a single sample for each distribution
sample_multinomial(probs) = [3, 0]

// Draw a vector containing two samples for each distribution
sample_multinomial(probs, shape=(2)) = [[4, 2], [0, 0]]

// Also request the log likelihood of the drawn samples
sample_multinomial(probs, get_prob=True) = [2, 1], [log(0.2), log(0.3)]

Arguments

  • data::NDArray-or-SymbolicNode: Distribution probabilities. Must sum to one on the last axis.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • get_prob::boolean, optional, default=0: Whether to also return the log probability of sampled result. This is usually used for differentiating through stochastic variables, e.g. in reinforcement learning.
  • dtype::{'float16', 'float32', 'float64', 'int32', 'uint8'},optional, default='int32': DType of the output in case this can't be inferred.

source
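
A plain-Python sketch of drawing one outcome per distribution together with its log likelihood, as get_prob requests. It uses inverse-CDF sampling, which is one of several ways to sample a categorical distribution. The text does not say which method MXNet uses, and the helper is hypothetical. Outcomes with zero probability are never selected, and the second returned list holds natural-log probabilities.

```python
import bisect
import itertools
import math
import random


def sample_multinomial(probs, get_prob=False, rng=None):
    """Hypothetical helper: one outcome per row of probs, drawn by inverting the
    row's cumulative distribution. Rows are assumed to sum to 1."""
    rng = random.Random() if rng is None else rng
    samples, log_probs = [], []
    for row in probs:
        cdf = list(itertools.accumulate(row))
        # Scale by the computed total so rounding cannot push u past the last bin.
        u = rng.random() * cdf[-1]
        j = bisect.bisect_right(cdf, u)  # first index whose cumulative mass exceeds u
        samples.append(j)
        log_probs.append(math.log(row[j]))  # natural log of the chosen outcome's probability
    return (samples, log_probs) if get_prob else samples


probs = [[0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0]]
samples, log_probs = sample_multinomial(probs, get_prob=True, rng=random.Random(0))
print(samples, log_probs)  # log_probs[i] == math.log(probs[i][samples[i]])
```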

# MXNet.mx._sample_negative_binomialMethod.

_sample_negative_binomial(k, shape, dtype, p)

Concurrent sampling from multiple negative binomial distributions with parameters k (failure limit) and p (failure probability).

The parameters of the distributions are provided as input arrays. Let [s] be the shape of the input arrays, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input arrays, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input values at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input arrays.

Samples will always be returned as a floating point data type.

Examples::

k = [ 20, 49 ]
p = [ 0.4 , 0.77 ]

// Draw a single sample for each distribution
sample_negative_binomial(k, p) = [ 15., 16.]

// Draw a vector containing two samples for each distribution
sample_negative_binomial(k, p, shape=(2)) = [[ 15., 50.], [ 16., 12.]]

Defined in src/operator/random/multisample_op.cc:L288

Arguments

  • k::NDArray-or-SymbolicNode: Limits of unsuccessful experiments.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
  • p::NDArray-or-SymbolicNode: Failure probabilities in each experiment.

source

# MXNet.mx._sample_normalMethod.

_sample_normal(mu, shape, dtype, sigma)

Concurrent sampling from multiple normal distributions with parameters mu (mean) and sigma (standard deviation).

The parameters of the distributions are provided as input arrays. Let [s] be the shape of the input arrays, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input arrays, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input values at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input arrays.

Examples::

mu = [ 0.0, 2.5 ]
sigma = [ 1.0, 3.7 ]

// Draw a single sample for each distribution
sample_normal(mu, sigma) = [-0.56410581, 0.95934606]

// Draw a vector containing two samples for each distribution
sample_normal(mu, sigma, shape=(2)) = [[-0.56410581, 0.2928229 ], [ 0.95934606, 4.48287058]]

Defined in src/operator/random/multisample_op.cc:L278

Arguments

  • mu::NDArray-or-SymbolicNode: Means of the distributions.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
  • sigma::NDArray-or-SymbolicNode: Standard deviations of the distributions.

source

# MXNet.mx._sample_poissonMethod.

_sample_poisson(lam, shape, dtype)

Concurrent sampling from multiple Poisson distributions with parameters lambda (rate).

The parameters of the distributions are provided as an input array. Let [s] be the shape of the input array, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input array, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input value at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input array.

Samples will always be returned as a floating point data type.

Examples::

lam = [ 1.0, 8.5 ]

// Draw a single sample for each distribution
sample_poisson(lam) = [ 0., 13.]

// Draw a vector containing two samples for each distribution
sample_poisson(lam, shape=(2)) = [[ 0., 4.], [ 13., 8.]]

Defined in src/operator/random/multisample_op.cc:L285

Arguments

  • lam::NDArray-or-SymbolicNode: Lambda (rate) parameters of the distributions.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx._sample_uniformMethod.

_sample_uniform(low, shape, dtype, high)

Concurrent sampling from multiple uniform distributions on the intervals given by [low,high).

The parameters of the distributions are provided as input arrays. Let [s] be the shape of the input arrays, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input arrays, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input values at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input arrays.

Examples::

low = [ 0.0, 2.5 ]
high = [ 1.0, 3.7 ]

// Draw a single sample for each distribution
sample_uniform(low, high) = [ 0.40451524, 3.18687344]

// Draw a vector containing two samples for each distribution
sample_uniform(low, high, shape=(2)) = [[ 0.40451524, 0.18017688], [ 3.18687344, 3.68352246]]

Defined in src/operator/random/multisample_op.cc:L276

Arguments

  • low::NDArray-or-SymbolicNode: Lower bounds of the distributions.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
  • high::NDArray-or-SymbolicNode: Upper bounds of the distributions.

source

# MXNet.mx._sample_unique_zipfianMethod.

_sample_unique_zipfian(range_max, shape)

Draw random samples from an approximately log-uniform or Zipfian distribution without replacement.

This operation takes a 2-D shape (batch_size, num_sampled), and randomly generates num_sampled samples from the range of integers [0, range_max) for each instance in the batch.

The elements in each instance are drawn without replacement from the base distribution. The base distribution for this operator is an approximately log-uniform or Zipfian distribution:

P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)

Additionally, it returns the number of trials used to obtain num_sampled samples for each instance in the batch.
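The probabilities above telescope, so they sum to 1 over [0, range_max), and the cumulative distribution is log(k + 2) / log(range_max + 1). That yields a simple inverse-CDF sampler. The plain-Python sketch below covers a single batch row; `sample_unique_zipfian_sketch` is hypothetical and makes no claim about how MXNet implements the operator.

```python
import math
import random


def zipfian_prob(c, range_max):
    """P(class = c) for the approximately log-uniform base distribution."""
    return (math.log(c + 2) - math.log(c + 1)) / math.log(range_max + 1)


def sample_unique_zipfian_sketch(range_max, num_sampled, seed=0):
    """Draw `num_sampled` distinct classes from [0, range_max) for one row.

    Inverse-CDF sampling: with u ~ Uniform[0, 1), k = floor((range_max + 1) ** u) - 1.
    Returns the sorted unique samples and the number of draws (trials) used.
    """
    if num_sampled > range_max:
        raise ValueError("cannot draw more unique classes than range_max")
    rng = random.Random(seed)
    seen, trials = set(), 0
    while len(seen) < num_sampled:
        u = rng.random()
        k = math.floor((range_max + 1) ** u) - 1
        seen.add(min(k, range_max - 1))  # guard against float round-off at the top end
        trials += 1
    return sorted(seen), trials


print(sum(zipfian_prob(c, 10) for c in range(10)))  # telescopes to 1 (up to rounding)
samples, trials = sample_unique_zipfian_sketch(750000, 100)
print(len(samples), trials)
```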

Example::

samples, trials = _sample_unique_zipfian(750000, shape=(4, 8192))
unique(samples[0]) = 8192
unique(samples[3]) = 8192
trials[0] = 16435

Defined in src/operator/random/unique_sample_op.cc:L65

Arguments

  • range_max::int, required: The number of possible classes.
  • shape::Shape(tuple), optional, default=None: 2-D shape of the output, where shape[0] is the batch size, and shape[1] is the number of candidates to sample for each batch.

source

# MXNet.mx._scatter_elemwise_divMethod.

_scatter_elemwise_div(lhs, rhs)

Divides arguments element-wise. If the left-hand-side input is 'row_sparse', then only the values which exist in the left-hand sparse array are computed. The 'missing' values are ignored.

The storage type of $_scatter_elemwise_div$ output depends on storage types of inputs

  • _scatter_elemwise_div(row_sparse, row_sparse) = row_sparse
  • _scatter_elemwise_div(row_sparse, dense) = row_sparse
  • _scatter_elemwise_div(row_sparse, csr) = row_sparse
  • otherwise, $_scatter_elemwise_div$ behaves exactly like elemwise_div and generates output with default storage

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._scatter_minus_scalarMethod.

_scatter_minus_scalar(data, scalar, is_int)

Subtracts a scalar from a tensor element-wise. If the left-hand-side input is 'row_sparse' or 'csr', then only the values which exist in the left-hand sparse array are computed. The 'missing' values are ignored.

The storage type of $_scatter_minus_scalar$ output depends on storage types of inputs

  • _scatter_minus_scalar(row_sparse, scalar) = row_sparse
  • _scatter_minus_scalar(csr, scalar) = csr
  • otherwise, $_scatter_minus_scalar$ behaves exactly like _minus_scalar and generates output with default storage

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._scatter_plus_scalarMethod.

_scatter_plus_scalar(data, scalar, is_int)

Adds a scalar to a tensor element-wise. If the left-hand-side input is 'row_sparse' or 'csr', then only the values which exist in the left-hand sparse array are computed. The 'missing' values are ignored.

The storage type of $_scatter_plus_scalar$ output depends on storage types of inputs

  • _scatter_plus_scalar(row_sparse, scalar) = row_sparse
  • _scatter_plus_scalar(csr, scalar) = csr
  • otherwise, $_scatter_plus_scalar$ behaves exactly like _plus_scalar and generates output with default storage
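To make "only the stored values are computed" concrete, here is a plain-Python sketch of the row_sparse case. The dict-of-rows representation and all three helper names are hypothetical stand-ins for a row_sparse NDArray, not MXNet objects.

```python
def scatter_plus_scalar_rsp(rsp, scalar):
    """Add `scalar` only to the rows stored in a row_sparse-like dict.

    `rsp` maps stored row indices to lists of values. Rows that are not
    stored are the 'missing' values: they stay missing (implicitly zero).
    """
    return {row: [v + scalar for v in vals] for row, vals in rsp.items()}


def dense_plus_scalar(dense, scalar):
    """Ordinary element-wise addition: every element changes, zeros included."""
    return [[v + scalar for v in row] for row in dense]


def to_dense(rsp, num_rows, num_cols):
    """Expand the dict representation into a dense list of lists."""
    return [list(rsp.get(r, [0.0] * num_cols)) for r in range(num_rows)]


rsp = {0: [1.0, 2.0], 2: [3.0, 4.0]}  # rows 1 and 3 are not stored
print(to_dense(scatter_plus_scalar_rsp(rsp, 10.0), 4, 2))
# [[11.0, 12.0], [0.0, 0.0], [13.0, 14.0], [0.0, 0.0]]
print(dense_plus_scalar(to_dense(rsp, 4, 2), 10.0))
# [[11.0, 12.0], [10.0, 10.0], [13.0, 14.0], [10.0, 10.0]]
```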

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::double, optional, default=1: Scalar input value
  • is_int::boolean, optional, default=1: Indicate whether scalar input is int type

source

# MXNet.mx._scatter_set_ndMethod.

_scatter_set_nd(lhs, rhs, indices, shape)

This operator has the same functionality as scatter_nd except that it does not reset the elements not indexed by the input index NDArray in the input data NDArray. output should be explicitly given and be the same as lhs.

.. note:: This operator is for internal use only.

Examples::

data = [2, 3, 0]
indices = [[1, 1, 0], [0, 1, 0]]
out = [[1, 1], [1, 1]]
_scatter_set_nd(lhs=out, rhs=data, indices=indices, out=out)
out = [[0, 1], [2, 3]]
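The index layout in this example can be reproduced with a short plain-Python sketch: column j of `indices` holds the coordinates that receive `data[j]`. `scatter_set_nd_sketch` is a hypothetical helper; it writes into `out` in place, mirroring the requirement that the output be the same array as `lhs`, and leaves every unaddressed element untouched.

```python
def scatter_set_nd_sketch(out, data, indices):
    """In place: out[indices[0][j]][indices[1][j]]... = data[j] for each j.

    `indices` has one row per output dimension, i.e. shape (ndim, N).
    Elements of `out` not addressed by any column keep their old values.
    """
    ndim = len(indices)
    for j, value in enumerate(data):
        target = out
        for d in range(ndim - 1):
            target = target[indices[d][j]]
        target[indices[ndim - 1][j]] = value
    return out


out = [[1, 1], [1, 1]]
scatter_set_nd_sketch(out, [2, 3, 0], [[1, 1, 0], [0, 1, 0]])
print(out)  # [[0, 1], [2, 3]]
```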

Arguments

  • lhs::NDArray-or-SymbolicNode: source input
  • rhs::NDArray-or-SymbolicNode: value to assign
  • indices::NDArray-or-SymbolicNode: indices
  • shape::Shape(tuple), required: Shape of output.

source

# MXNet.mx._set_valueMethod.

_set_value(src)

Arguments

  • src::real_t: Source input to the function.

source

# MXNet.mx._shuffleMethod.

_shuffle(data)

Randomly shuffle the elements.

This shuffles the array along the first axis. The order of the elements in each subarray does not change. For example, if a 2D array is given, the order of the rows randomly changes, but the order of the elements in each row does not change.
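A plain-Python sketch of the same behaviour on a 2-D nested list. `shuffle_first_axis` is a hypothetical helper, not the MXNet operator, and it uses Python's `random.shuffle` to permute the rows.

```python
import random


def shuffle_first_axis(rows, seed=0):
    """Return a copy of `rows` with the first axis randomly permuted.

    Each inner row is copied intact; only the order of the rows changes.
    """
    rng = random.Random(seed)
    out = [list(r) for r in rows]
    rng.shuffle(out)
    return out


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(shuffle_first_axis(x))
```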

Arguments

  • data::NDArray-or-SymbolicNode: Data to be shuffled.

source

# MXNet.mx._slice_assignMethod.

_slice_assign(lhs, rhs, begin, end, step)

Assign the rhs to a cropped subset of lhs.

Requirements

  • output should be explicitly given and be the same as lhs.
  • lhs and rhs are of the same data type, and on the same device.

From: src/operator/tensor/matrix_op.cc:514

Arguments

  • lhs::NDArray-or-SymbolicNode: Source input
  • rhs::NDArray-or-SymbolicNode: value to assign
  • begin::Shape(tuple), required: starting indices for the slice operation, supports negative indices.
  • end::Shape(tuple), required: ending indices for the slice operation, supports negative indices.
  • step::Shape(tuple), optional, default=[]: step for the slice operation, supports negative values.
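The begin/end/step arguments select a region per axis, much like Python slices. The sketch below assumes Python slicing semantics for negative values (with `None` meaning an open end) and handles only 2-D nested lists; `slice_assign_2d` is hypothetical and returns a modified copy instead of writing into `lhs` in place as the operator requires.

```python
def slice_assign_2d(lhs, rhs, begin, end, step=(None, None)):
    """Return a copy of `lhs` with `rhs` written into
    lhs[begin[0]:end[0]:step[0], begin[1]:end[1]:step[1]].

    Negative values and None follow Python slicing rules; `rhs` must have
    exactly the shape of the selected region.
    """
    out = [list(row) for row in lhs]
    rows = range(len(out))[slice(begin[0], end[0], step[0])]
    if len(rhs) != len(rows):
        raise ValueError("rhs has the wrong number of rows")
    for i, r in enumerate(rows):
        cols = range(len(out[r]))[slice(begin[1], end[1], step[1])]
        if len(rhs[i]) != len(cols):
            raise ValueError("rhs has the wrong number of columns")
        for j, c in enumerate(cols):
            out[r][c] = rhs[i][j]
    return out


lhs = [[0] * 4 for _ in range(3)]
print(slice_assign_2d(lhs, [[1, 2], [3, 4]], begin=(1, 1), end=(3, 3)))
# [[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0]]
print(slice_assign_2d(lhs, [[7, 8]], begin=(-1, -2), end=(None, None)))
# [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 7, 8]]
```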

source

# MXNet.mx._slice_assign_scalarMethod.

_slice_assign_scalar(data, scalar, begin, end, step)

Assign the scalar to a cropped subset of the input.

Requirements

  • output should be explicitly given and be the same as input

From: src/operator/tensor/matrix_op.cc:540

Arguments

  • data::NDArray-or-SymbolicNode: Source input
  • scalar::double, optional, default=0: The scalar value for assignment.
  • begin::Shape(tuple), required: starting indices for the slice operation, supports negative indices.
  • end::Shape(tuple), required: ending indices for the slice operation, supports negative indices.
  • step::Shape(tuple), optional, default=[]: step for the slice operation, supports negative values.

source

# MXNet.mx._sparse_ElementWiseSumMethod.

_sparse_ElementWiseSum(args)

_sparse_ElementWiseSum is an alias of add_n.

Note: _sparse_ElementWiseSum takes a variable number of positional inputs. So instead of calling it as _sparse_ElementWiseSum([x, y, z], num_args=3), call it as _sparse_ElementWiseSum(x, y, z); num_args will be determined automatically.

Adds all input arguments element-wise.

.. math:: add_n(a1, a2, ..., an) = a1 + a2 + ... + an

$add_n$ is potentially more efficient than calling $add$ n times.
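A plain-Python sketch of the semantics, assuming all inputs are 1-D lists of equal length. `add_n_sketch` is a hypothetical name. It sums each position across all inputs in a single pass, which is one plausible reason a fused `add_n` can beat a chain of pairwise `add` calls, since the chain materialises n - 1 intermediate arrays.

```python
def add_n_sketch(*arrays):
    """Element-wise sum of any number of equally shaped 1-D lists."""
    if not arrays:
        raise ValueError("add_n needs at least one input")
    length = len(arrays[0])
    if any(len(a) != length for a in arrays):
        raise ValueError("all inputs must have the same shape")
    # zip(*arrays) yields one tuple per position, holding that element of every input
    return [sum(values) for values in zip(*arrays)]


print(add_n_sketch([1, 2], [3, 4], [5, 6]))  # [9, 12]
```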

The storage type of $add_n$ output depends on storage types of inputs

  • add_n(row_sparse, row_sparse, ..) = row_sparse
  • add_n(default, csr, default) = default
  • add_n(any input combinations longer than 4 (>4) with at least one default type) = default
  • otherwise, $add_n$ falls all inputs back to default storage and generates default storage

Defined in src/operator/tensor/elemwise_sum.cc:L155

Arguments

  • args::NDArray-or-SymbolicNode[]: Positional input arguments

source

# MXNet.mx._sparse_EmbeddingMethod.

_sparse_Embedding(data, weight, input_dim, output_dim, dtype, sparse_grad)

_sparse_Embedding is an alias of Embedding.

Maps integer indices to vector representations (embeddings).

This operator maps words to real-valued vectors in a high-dimensional space, called word embeddings. These embeddings can capture semantic and syntactic properties of the words. For example, it has been noted that in the learned embedding spaces, similar words tend to be close to each other and dissimilar words far apart.

For an input array of shape (d1, ..., dK), the shape of the output array is (d1, ..., dK, output_dim). All the input values should be integers in the range [0, input_dim).

If input_dim is ip0 and output_dim is op0, then the shape of the embedding weight matrix must be (ip0, op0).

When "sparse_grad" is False, any index that is too large is replaced by the index that addresses the last vector in the embedding matrix. When "sparse_grad" is True, an error is raised if invalid indices are found.

Examples::

input_dim = 4
output_dim = 5

// Each row in weight matrix y represents a word. So, y = (w0,w1,w2,w3)
y = [[  0.,   1.,   2.,   3.,   4.],
     [  5.,   6.,   7.,   8.,   9.],
     [ 10.,  11.,  12.,  13.,  14.],
     [ 15.,  16.,  17.,  18.,  19.]]

// Input array x represents n-grams(2-gram). So, x = [(w1,w3), (w0,w2)]
x = [[ 1.,  3.],
     [ 0.,  2.]]

// Mapped input x to its vector representation y.
Embedding(x, y, 4, 5) = [[[  5.,   6.,   7.,   8.,   9.],
                          [ 15.,  16.,  17.,  18.,  19.]],

                         [[  0.,   1.,   2.,   3.,   4.],
                          [ 10.,  11.,  12.,  13.,  14.]]]
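The lookup in this example, together with the out-of-range rule described above, can be sketched in plain Python. `embedding_lookup` is a hypothetical helper; rejecting negative indices is an assumption of the sketch, since the text only specifies the valid range [0, input_dim).

```python
def embedding_lookup(x, weight, sparse_grad=False):
    """Replace every index in the nested list `x` by a copy of the
    corresponding row of `weight` (shape (input_dim, output_dim))."""
    if isinstance(x, list):
        return [embedding_lookup(v, weight, sparse_grad) for v in x]
    idx = int(x)
    if idx < 0:
        raise IndexError(f"negative index {idx}")
    if idx >= len(weight):
        if sparse_grad:
            raise IndexError(f"index {idx} out of range for input_dim={len(weight)}")
        idx = len(weight) - 1  # clamp to the last embedding vector
    return list(weight[idx])


y = [[float(5 * r + c) for c in range(5)] for r in range(4)]  # input_dim=4, output_dim=5
x = [[1.0, 3.0], [0.0, 2.0]]
print(embedding_lookup(x, y))  # nested list of shape (2, 2, 5)
```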

The storage type of weight can be either row_sparse or default.

.. Note::

If "sparse_grad" is set to True, the storage type of gradient w.r.t weights will be
"row_sparse". Only a subset of optimizers support sparse gradients, including SGD, AdaGrad
and Adam. Note that by default lazy updates are turned on, which may perform differently
from standard updates. For more details, please check the Optimization API at:
https://mxnet.incubator.apache.org/api/python/optimization/optimization.html

Defined in src/operator/tensor/indexing_op.cc:L597

Arguments

  • data::NDArray-or-SymbolicNode: The input array to the embedding operator.
  • weight::NDArray-or-SymbolicNode: The embedding weight matrix.
  • input_dim::int, required: Vocabulary size of the input indices.
  • output_dim::int, required: Dimension of the embedding vectors.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'},optional, default='float32': Data type of weight.
  • sparse_grad::boolean, optional, default=0: Compute row sparse gradient in the backward calculation. If set to True, the grad's storage type is row_sparse.

source

# MXNet.mx._sparse_FullyConnectedMethod.

_sparse_FullyConnected(data, weight, bias, num_hidden, no_bias, flatten)

_sparse_FullyConnected is an alias of FullyConnected.

Applies a linear transformation: Y = XW^T + b.

If $flatten$ is set to true, then the shapes are:

  • data: (batch_size, x1, x2, ..., xn)
  • weight: (num_hidden, x1 * x2 * ... * xn)
  • bias: (num_hidden,)
  • out: (batch_size, num_hidden)

If $flatten$ is set to false, then the shapes are:

  • data: (x1, x2, ..., xn, input_dim)
  • weight: (num_hidden, input_dim)
  • bias: (num_hidden,)
  • out: (x1, x2, ..., xn, num_hidden)

The learnable parameters include both $weight$ and $bias$.

If $no_bias$ is set to true, then the $bias$ term is ignored.
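A plain-Python sketch of both shape conventions. `fully_connected_sketch` and `flatten_nested` are hypothetical helpers; passing `bias=None` stands in for `no_bias=True`, and the `flatten=True` path assumes row-major flattening of each sample.

```python
def flatten_nested(x):
    """Flatten a nested list into a flat row-major list of numbers."""
    if isinstance(x, list):
        return [v for item in x for v in flatten_nested(item)]
    return [x]


def fully_connected_sketch(data, weight, bias=None, flatten=True):
    """Y = X W^T + b on nested lists; weight has shape (num_hidden, in_features).

    flatten=True: each sample data[i] is flattened to x1 * ... * xn features.
    flatten=False: the transform is applied along the last axis only.
    """
    def affine(vec):
        if len(vec) != len(weight[0]):
            raise ValueError("feature size does not match weight")
        out = [sum(w * v for w, v in zip(row, vec)) for row in weight]
        if bias is not None:
            out = [o + b for o, b in zip(out, bias)]
        return out

    def last_axis(x):
        return [last_axis(item) for item in x] if isinstance(x[0], list) else affine(x)

    if flatten:
        return [affine(flatten_nested(sample)) for sample in data]
    return last_axis(data)


data = [[[1.0, 2.0], [3.0, 4.0]]]            # batch of one 2x2 sample
weight = [[1, 0, 0, 0], [0, 0, 0, 1]]        # num_hidden=2, 4 input features
print(fully_connected_sketch(data, weight, bias=[10, 20]))  # [[11.0, 24.0]]
```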

.. Note::

The sparse support for FullyConnected is limited to forward evaluation with `row_sparse`
weight and bias, where the length of `weight.indices` and `bias.indices` must be equal
to `num_hidden`. This could be useful for model inference with `row_sparse` weights
trained with importance sampling or noise contrastive estimation.

To compute linear transformation with 'csr' sparse data, sparse.dot is recommended instead
of sparse.FullyConnected.

Defined in src/operator/nn/fully_connected.cc:L286

Arguments

  • data::NDArray-or-SymbolicNode: Input data.
  • weight::NDArray-or-SymbolicNode: Weight matrix.
  • bias::NDArray-or-SymbolicNode: Bias parameter.
  • num_hidden::int, required: Number of hidden nodes of the output.
  • no_bias::boolean, optional, default=0: Whether to disable bias parameter.
  • flatten::boolean, optional, default=1: Whether to collapse all but the first axis of the input data tensor.

source

# MXNet.mx._sparse_LinearRegressionOutputMethod.

_sparse_LinearRegressionOutput(data, label, grad_scale)

_sparse_LinearRegressionOutput is an alias of LinearRegressionOutput.

Computes and optimizes for squared loss during backward propagation. Just outputs $data$ during forward propagation.

If ŷ_i is the predicted value of the i-th sample, and y_i is the corresponding target value, then the squared loss estimated over n samples is defined as

.. math:: \text{SquaredLoss}(\textbf{Y}, \hat{\textbf{Y}} ) = \frac{1}{n} \sum_{i=0}^{n-1} \lVert \textbf{y}_i - \hat{\textbf{y}}_i \rVert_2^2

.. note:: Use the LinearRegressionOutput as the final output layer of a net.

The storage type of $label$ can be $default$ or $csr$

  • LinearRegressionOutput(default, default) = default
  • LinearRegressionOutput(default, csr) = default

By default, gradients of this loss function are scaled by a factor of 1/m, where m is the number of regression outputs of a training example. The parameter grad_scale can be used to change this scale to grad_scale/m.
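The gradient scaling can be sketched for a single example as follows. The sketch assumes that the unscaled per-output gradient is `pred - label` (the derivative of half the squared error); `linear_regression_grad` is a hypothetical helper, not MXNet's backward kernel.

```python
def linear_regression_grad(pred, label, grad_scale=1.0):
    """Gradient w.r.t. `pred` for one example with m = len(pred) outputs.

    Assumed form: (pred - label) scaled by grad_scale / m.
    """
    m = len(pred)
    return [(p - y) * grad_scale / m for p, y in zip(pred, label)]


print(linear_regression_grad([1.0, 2.0, 3.0, 4.0], [0.0, 2.0, 5.0, 4.0], grad_scale=2.0))
# [0.5, 0.0, -1.0, 0.0]
```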

Defined in src/operator/regression_output.cc:L92

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the function.
  • label::NDArray-or-SymbolicNode: Input label to the function.
  • grad_scale::float, optional, default=1: Scale the gradient by a float factor

source

# MXNet.mx._sparse_LogisticRegressionOutputMethod.

_sparse_LogisticRegressionOutput(data, label, grad_scale)

_sparse_LogisticRegressionOutput is an alias of LogisticRegressionOutput.

Applies a logistic function to the input.

The logistic function, also known as the sigmoid function, is computed as 1 / (1 + exp(-x)).

Commonly, the sigmoid is used to squash the real-valued output of a linear model w^T x + b into the [0,1] range so that it can be interpreted as a probability. It is suitable for binary classification or probability prediction tasks.

.. note:: Use the LogisticRegressionOutput as the final output layer of a net.

The storage type of $label$ can be $default$ or $csr$

  • LogisticRegressionOutput(default, default) = default
  • LogisticRegressionOutput(default, csr) = default

The loss function used is the Binary Cross Entropy Loss:

.. math:: -(y \log(p) + (1 - y) \log(1 - p))

Here y is the ground-truth probability of a positive outcome for a given example, and p is the probability predicted by the model. By default, gradients of this loss function are scaled by a factor of 1/m, where m is the number of regression outputs of a training example. The parameter grad_scale can be used to change this scale to grad_scale/m.
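Differentiating the binary cross entropy through the sigmoid gives the familiar gradient p - y with respect to the input. The plain-Python sketch below assumes the operator's backward pass is this gradient scaled by grad_scale/m; the helper names are hypothetical, and a finite-difference check confirms the formula.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function 1 / (1 + exp(-x))."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bce(x, y):
    """Binary cross entropy of label y for the pre-sigmoid input x."""
    p = sigmoid(x)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def logistic_output(data, label, grad_scale=1.0):
    """Forward: p = sigmoid(data). Assumed backward: (p - y) * grad_scale / m."""
    p = [sigmoid(x) for x in data]
    m = len(data)
    grad = [(pi - yi) * grad_scale / m for pi, yi in zip(p, label)]
    return p, grad


p, grad = logistic_output([0.0, 2.0], [1.0, 0.0])
# Central finite difference of bce at x = 2.0, divided by m = 2 like the analytic gradient.
h = 1e-6
numeric = (bce(2.0 + h, 0.0) - bce(2.0 - h, 0.0)) / (2 * h) / 2
print(grad[1], numeric)
```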

Defined in src/operator/regression_output.cc:L152

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the function.
  • label::NDArray-or-SymbolicNode: Input label to the function.
  • grad_scale::float, optional, default=1: Scale the gradient by a float factor

source

# MXNet.mx._sparse_MAERegressionOutputMethod.

_sparse_MAERegressionOutput(data, label, grad_scale)

_sparse_MAERegressionOutput is an alias of MAERegressionOutput.

Computes mean absolute error of the input.

MAE is a risk metric corresponding to the expected value of the absolute error.

If ŷ_i is the predicted value of the i-th sample, and y_i is the corresponding target value, then the mean absolute error (MAE) estimated over n samples is defined as

.. math:: \text{MAE}(\textbf{Y}, \hat{\textbf{Y}} ) = \frac{1}{n} \sum_{i=0}^{n-1} \lVert \textbf{y}_i - \hat{\textbf{y}}_i \rVert_1

.. note:: Use the MAERegressionOutput as the final output layer of a net.

The storage type of $label$ can be $default$ or $csr$

  • MAERegressionOutput(default, default) = default
  • MAERegressionOutput(default, csr) = default

By default, gradients of this loss function are scaled by a factor of 1/m, where m is the number of regression outputs of a training example. The parameter grad_scale can be used to change this scale to grad_scale/m.
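A matching sketch for MAE, assuming the unscaled per-output gradient is sign(pred - label), the subgradient of |pred - label| that takes the value 0 at a tie. Both helper names are hypothetical.

```python
def mae(pred, label):
    """Mean absolute error over the m outputs of one example."""
    return sum(abs(p - y) for p, y in zip(pred, label)) / len(pred)


def mae_grad(pred, label, grad_scale=1.0):
    """Assumed backward: sign(pred - label) * grad_scale / m."""
    m = len(pred)
    # (p > y) - (p < y) evaluates to +1, 0 or -1
    return [((p > y) - (p < y)) * grad_scale / m for p, y in zip(pred, label)]


print(mae([1.0, 2.0, 3.0], [2.0, 2.0, 1.0]))       # 1.0
print(mae_grad([1.0, 2.0, 3.0], [2.0, 2.0, 1.0]))  # about [-0.333, 0.0, 0.333]
```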

Defined in src/operator/regression_output.cc:L120

Arguments

  • data::NDArray-or-SymbolicNode: Input data to the function.
  • label::NDArray-or-SymbolicNode: Input label to the function.
  • grad_scale::float, optional, default=1: Scale the gradient by a float factor

source

# MXNet.mx._sparse__contrib_round_steMethod.

_sparse__contrib_round_ste(data)

_sparse__contrib_round_ste is an alias of _contrib_round_ste.

Straight-through-estimator of round().

In the forward pass, returns the input rounded element-wise to the nearest integer (same as round()).

In the backward pass, returns gradients of $1$ everywhere (instead of $0$ everywhere as in round()): d/dx round_ste(x) = 1 vs. d/dx round(x) = 0. This is useful for quantized training.

Reference: Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation.

Example::

x = round_ste([-1.5, 1.5, -1.9, 1.9, 2.7])
x.backward()
x = [-2., 2., -2., 2., 3.]
x.grad() = [1., 1., 1., 1., 1.]
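A plain-Python sketch of the straight-through estimator. It assumes the forward pass rounds halves away from zero, as C's `round()` does; Python's built-in `round()` rounds halves to even, so it is avoided. The helper names are hypothetical, and there is no autograd here: the backward function simply passes the upstream gradient through unchanged.

```python
import math


def round_half_away(x):
    """Round to the nearest integer, with halves going away from zero."""
    a = abs(x)
    r = math.floor(a)
    if a - r >= 0.5:
        r += 1
    return math.copysign(r, x)


def round_ste_forward(xs):
    """Forward pass: ordinary rounding."""
    return [round_half_away(x) for x in xs]


def round_ste_backward(upstream):
    """Backward pass: treat d/dx round_ste(x) as 1, so gradients pass through."""
    return [g * 1.0 for g in upstream]


x = [-1.5, 1.5, -1.9, 1.9, 2.7]
print(round_ste_forward(x))                 # [-2.0, 2.0, -2.0, 2.0, 3.0]
print(round_ste_backward([1.0] * len(x)))   # [1.0, 1.0, 1.0, 1.0, 1.0]
```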

The storage type of $round_ste$ output depends upon the input storage type:

  • round_ste(default) = default
  • round_ste(row_sparse) = row_sparse
  • round_ste(csr) = csr

Defined in src/operator/contrib/stes_op.cc:L54

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse__contrib_sign_steMethod.

_sparse__contrib_sign_ste(data)

_sparse__contrib_sign_ste is an alias of _contrib_sign_ste.

Straight-through-estimator of sign().

In the forward pass, returns the element-wise sign of the input (same as sign()).

In the backward pass, returns gradients of $1$ everywhere (instead of $0$ everywhere as in $sign()$): d/dx sign_ste(x) = 1 vs. d/dx sign(x) = 0. This is useful for quantized training.

Reference: Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation.

Example::

x = sign_ste([-2, 0, 3])
x.backward()
x = [-1., 0., 1.]
x.grad() = [1., 1., 1.]

The storage type of $sign_ste$ output depends upon the input storage type:

  • sign_ste(default) = default
  • sign_ste(row_sparse) = row_sparse
  • sign_ste(csr) = csr

Defined in src/operator/contrib/stes_op.cc:L79

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_absMethod.

_sparse_abs(data)

_sparse_abs is an alias of abs.

Returns element-wise absolute value of the input.

Example::

abs([-2, 0, 3]) = [2, 0, 3]

The storage type of $abs$ output depends upon the input storage type:

  • abs(default) = default
  • abs(row_sparse) = row_sparse
  • abs(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L720

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_adagrad_updateMethod.

_sparse_adagrad_update(weight, grad, history, lr, epsilon, wd, rescale_grad, clip_gradient)

Update function for AdaGrad optimizer.

Referenced from Adaptive Subgradient Methods for Online Learning and Stochastic Optimization, and available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.

Updates are applied by::

rescaled_grad = clip(grad * rescale_grad, clip_gradient)
history = history + square(rescaled_grad)
w = w - learning_rate * rescaled_grad / sqrt(history + epsilon)

Note that non-zero values for the weight decay option are not supported.
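A plain-Python sketch of one update step for 1-D weights, following the three update lines above literally. Weight decay is omitted because non-zero values are not supported, and `adagrad_step` is a hypothetical helper name.

```python
import math


def adagrad_step(weight, grad, history, lr, epsilon=1e-7,
                 rescale_grad=1.0, clip_gradient=-1.0):
    """One AdaGrad step on 1-D lists; returns (new_weight, new_history)."""
    new_w, new_h = [], []
    for w, g, h in zip(weight, grad, history):
        g = g * rescale_grad
        if clip_gradient > 0:  # clip_gradient <= 0 disables clipping
            g = max(min(g, clip_gradient), -clip_gradient)
        h = h + g * g
        new_w.append(w - lr * g / math.sqrt(h + epsilon))
        new_h.append(h)
    return new_w, new_h


w, hist = adagrad_step([1.0], [2.0], [0.0], lr=0.1, epsilon=0.0)
print(w, hist)  # roughly [0.9] [4.0]
```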

Defined in src/operator/optimizer_op.cc:L908

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • history::NDArray-or-SymbolicNode: History
  • lr::float, required: Learning rate
  • epsilon::float, optional, default=1.00000001e-07: epsilon
  • wd::float, optional, default=0: weight decay
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).

source

# MXNet.mx._sparse_adam_updateMethod.

_sparse_adam_update(weight, grad, mean, var, lr, beta1, beta2, epsilon, wd, rescale_grad, clip_gradient, lazy_update)

_sparse_adam_update is an alias of adam_update.

Update function for Adam optimizer. Adam is seen as a generalization of AdaGrad.

Adam update consists of the following steps, where g represents gradient and m, v are 1st and 2nd order moment estimates (mean and variance).

.. math::

g_t = \nabla J(W_{t-1})\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon }

It updates the weights using::

m = beta1*m + (1-beta1)*grad
v = beta2*v + (1-beta2)*(grad**2)
w += - learning_rate * m / (sqrt(v) + epsilon)

However, if grad's storage type is $row_sparse$, $lazy_update$ is True and the storage type of weight is the same as those of m and v, only the row slices whose indices appear in grad.indices are updated (for w, m and v)::

for row in grad.indices:
    m[row] = beta1*m[row] + (1-beta1)*grad[row]
    v[row] = beta2*v[row] + (1-beta2)*(grad[row]**2)
    w[row] += - learning_rate * m[row] / (sqrt(v[row]) + epsilon)
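The difference between lazy and standard updates can be sketched in plain Python. Weights and moments are lists of rows, and a dict from row index to gradient row stands in for a row_sparse gradient, its keys playing the role of grad.indices. `adam_step_rows` is hypothetical, and rescaling, clipping and weight decay are left out. With lazy updates, rows absent from the gradient keep their old m and v; a standard update still decays those moments and moves the weights.

```python
import math


def adam_step_rows(w, m, v, grad_rows, lr, beta1=0.9, beta2=0.999,
                   epsilon=1e-8, lazy_update=True):
    """In-place Adam step on 2-D weights stored as lists of rows.

    grad_rows: dict mapping row index -> gradient row (a row_sparse stand-in).
    lazy_update=True touches only those rows; otherwise every row is updated
    and missing rows are treated as all-zero gradients.
    """
    rows = sorted(grad_rows) if lazy_update else range(len(w))
    for r in rows:
        g_row = grad_rows.get(r, [0.0] * len(w[r]))
        for j, g in enumerate(g_row):
            m[r][j] = beta1 * m[r][j] + (1 - beta1) * g
            v[r][j] = beta2 * v[r][j] + (1 - beta2) * g * g
            w[r][j] -= lr * m[r][j] / (math.sqrt(v[r][j]) + epsilon)


w = [[1.0, 1.0], [1.0, 1.0]]
m = [[0.5, 0.5], [0.0, 0.0]]
v = [[0.25, 0.25], [0.0, 0.0]]
adam_step_rows(w, m, v, {1: [1.0, -1.0]}, lr=0.1)
print(w)  # row 0 is untouched by the lazy update
```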

Defined in src/operator/optimizer_op.cc:L687

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • mean::NDArray-or-SymbolicNode: Moving mean
  • var::NDArray-or-SymbolicNode: Moving variance
  • lr::float, required: Learning rate
  • beta1::float, optional, default=0.899999976: The decay rate for the 1st moment estimates.
  • beta2::float, optional, default=0.999000013: The decay rate for the 2nd moment estimates.
  • epsilon::float, optional, default=9.99999994e-09: A small constant for numerical stability.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • lazy_update::boolean, optional, default=1: If true, lazy updates are applied if gradient's stype is row_sparse and all of w, m and v have the same stype

source

# MXNet.mx._sparse_add_nMethod.

_sparse_add_n(args)

_sparse_add_n is an alias of add_n.

Note: _sparse_add_n takes a variable number of positional inputs. So instead of calling it as _sparse_add_n([x, y, z], num_args=3), call it as _sparse_add_n(x, y, z); num_args will be determined automatically.

Adds all input arguments element-wise.

.. math:: add_n(a1, a2, ..., an) = a1 + a2 + ... + an

$add_n$ is potentially more efficient than calling $add$ n times.

The storage type of $add_n$ output depends on storage types of inputs

  • add_n(row_sparse, row_sparse, ..) = row_sparse
  • add_n(default, csr, default) = default
  • add_n(any input combinations longer than 4 (>4) with at least one default type) = default
  • otherwise, $add_n$ falls all inputs back to default storage and generates default storage

Defined in src/operator/tensor/elemwise_sum.cc:L155

Arguments

  • args::NDArray-or-SymbolicNode[]: Positional input arguments

source

# MXNet.mx._sparse_arccosMethod.

_sparse_arccos(data)

_sparse_arccos is an alias of arccos.

Returns element-wise inverse cosine of the input array.

The input should be in the range [-1, 1]. The output is in the closed interval [0, π].

.. math:: arccos([-1, -.707, 0, .707, 1]) = [\pi, 3\pi/4, \pi/2, \pi/4, 0]

The storage type of $arccos$ output is always dense
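One way to see why arccos always produces dense output, while arcsin or arctan can keep sparse storage: a sparse array does not store its zeros, so an element-wise function can keep the sparse layout only if it maps 0 to 0, and arccos(0) = π/2. The plain-Python check below illustrates that necessary condition; it is an explanation, not MXNet's storage-dispatch logic.

```python
import math


def maps_zero_to_zero(f):
    """True if f(0) == 0, so implicit (unstored) zeros can stay implicit."""
    try:
        return f(0.0) == 0.0
    except ValueError:  # 0 is outside the domain (e.g. arccosh), so f(0) is not 0
        return False


for name, f in [("arcsin", math.asin), ("arctan", math.atan),
                ("arcsinh", math.asinh), ("arctanh", math.atanh),
                ("arccos", math.acos), ("arccosh", math.acosh)]:
    print(name, maps_zero_to_zero(f))
```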

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L233

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_arccoshMethod.

_sparse_arccosh(data)

_sparse_arccosh is an alias of arccosh.

Returns the element-wise inverse hyperbolic cosine of the input array.

The storage type of $arccosh$ output is always dense

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L535

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_arcsinMethod.

_sparse_arcsin(data)

_sparse_arcsin is an alias of arcsin.

Returns element-wise inverse sine of the input array.

The input should be in the range [-1, 1]. The output is in the closed interval [-π/2, π/2].

.. math:: arcsin([-1, -.707, 0, .707, 1]) = [-\pi/2, -\pi/4, 0, \pi/4, \pi/2]

The storage type of $arcsin$ output depends upon the input storage type:

  • arcsin(default) = default
  • arcsin(row_sparse) = row_sparse
  • arcsin(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L187

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_arcsinhMethod.

_sparse_arcsinh(data)

_sparse_arcsinh is an alias of arcsinh.

Returns the element-wise inverse hyperbolic sine of the input array.

The storage type of $arcsinh$ output depends upon the input storage type:

  • arcsinh(default) = default
  • arcsinh(row_sparse) = row_sparse
  • arcsinh(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L494

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_arctanMethod.

_sparse_arctan(data)

_sparse_arctan is an alias of arctan.

Returns element-wise inverse tangent of the input array.

The output is in the open interval (-π/2, π/2).

.. math:: arctan([-1, 0, 1]) = [-\pi/4, 0, \pi/4]

The storage type of $arctan$ output depends upon the input storage type:

  • arctan(default) = default
  • arctan(row_sparse) = row_sparse
  • arctan(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L282

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_arctanhMethod.

_sparse_arctanh(data)

_sparse_arctanh is an alias of arctanh.

Returns the element-wise inverse hyperbolic tangent of the input array.

The storage type of $arctanh$ output depends upon the input storage type:

  • arctanh(default) = default
  • arctanh(row_sparse) = row_sparse
  • arctanh(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L579

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_broadcast_addMethod.

_sparse_broadcast_add(lhs, rhs)

_sparse_broadcast_add is an alias of broadcast_add.

Returns element-wise sum of the input arrays with broadcasting.

broadcast_plus is an alias to the function broadcast_add.

Example::

x = [[ 1., 1., 1.], [ 1., 1., 1.]]

y = [[ 0.], [ 1.]]

broadcast_add(x, y) = [[ 1., 1., 1.], [ 2., 2., 2.]]

broadcast_plus(x, y) = [[ 1., 1., 1.], [ 2., 2., 2.]]
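The broadcasting rule behind these operators (along each axis the sizes either match or one of them is 1, and a size-1 axis is repeated) can be sketched for 2-D nested lists in plain Python. `broadcast_binary` is a hypothetical helper; the same function covers the add, sub, mul and div variants by changing `op`.

```python
import operator


def shape2d(a):
    """(rows, cols) of a rectangular 2-D nested list."""
    return len(a), len(a[0])


def broadcast_binary(op, x, y):
    """Apply `op` element-wise to two 2-D nested lists with broadcasting."""
    (xr, xc), (yr, yc) = shape2d(x), shape2d(y)
    for a, b in ((xr, yr), (xc, yc)):
        if a != b and 1 not in (a, b):
            raise ValueError(f"shapes {(xr, xc)} and {(yr, yc)} do not broadcast")
    rows, cols = max(xr, yr), max(xc, yc)
    # A size-1 axis is always read at index 0, which repeats it along that axis.
    return [[op(x[i if xr > 1 else 0][j if xc > 1 else 0],
                y[i if yr > 1 else 0][j if yc > 1 else 0])
             for j in range(cols)]
            for i in range(rows)]


x = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
y = [[0.0], [1.0]]
print(broadcast_binary(operator.add, x, y))  # [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]
print(broadcast_binary(operator.sub, x, y))  # [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
```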

Supported sparse operations:

broadcast_add(csr, dense(1D)) = dense
broadcast_add(dense(1D), csr) = dense

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L57

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._sparse_broadcast_divMethod.

_sparse_broadcast_div(lhs, rhs)

_sparse_broadcast_div is an alias of broadcast_div.

Returns element-wise division of the input arrays with broadcasting.

Example::

x = [[ 6., 6., 6.], [ 6., 6., 6.]]

y = [[ 2.], [ 3.]]

broadcast_div(x, y) = [[ 3., 3., 3.], [ 2., 2., 2.]]

Supported sparse operations:

broadcast_div(csr, dense(1D)) = csr

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L186

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._sparse_broadcast_minusMethod.

_sparse_broadcast_minus(lhs, rhs)

_sparse_broadcast_minus is an alias of broadcast_sub.

Returns element-wise difference of the input arrays with broadcasting.

broadcast_minus is an alias to the function broadcast_sub.

Example::

x = [[ 1., 1., 1.], [ 1., 1., 1.]]

y = [[ 0.], [ 1.]]

broadcast_sub(x, y) = [[ 1., 1., 1.], [ 0., 0., 0.]]

broadcast_minus(x, y) = [[ 1., 1., 1.], [ 0., 0., 0.]]

Supported sparse operations:

broadcast_sub/minus(csr, dense(1D)) = dense
broadcast_sub/minus(dense(1D), csr) = dense

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L105

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._sparse_broadcast_mulMethod.

_sparse_broadcast_mul(lhs, rhs)

_sparse_broadcast_mul is an alias of broadcast_mul.

Returns element-wise product of the input arrays with broadcasting.

Example::

x = [[ 1., 1., 1.], [ 1., 1., 1.]]

y = [[ 0.], [ 1.]]

broadcast_mul(x, y) = [[ 0., 0., 0.], [ 1., 1., 1.]]

Supported sparse operations:

broadcast_mul(csr, dense(1D)) = csr

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L145

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._sparse_broadcast_plusMethod.

_sparse_broadcast_plus(lhs, rhs)

_sparse_broadcast_plus is an alias of broadcast_add.

Returns element-wise sum of the input arrays with broadcasting.

broadcast_plus is an alias to the function broadcast_add.

Example::

x = [[ 1., 1., 1.], [ 1., 1., 1.]]

y = [[ 0.], [ 1.]]

broadcast_add(x, y) = [[ 1., 1., 1.], [ 2., 2., 2.]]

broadcast_plus(x, y) = [[ 1., 1., 1.], [ 2., 2., 2.]]

Supported sparse operations:

broadcast_add(csr, dense(1D)) = dense
broadcast_add(dense(1D), csr) = dense

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L57

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._sparse_broadcast_subMethod.

_sparse_broadcast_sub(lhs, rhs)

_sparse_broadcast_sub is an alias of broadcast_sub.

Returns element-wise difference of the input arrays with broadcasting.

broadcast_minus is an alias to the function broadcast_sub.

Example::

x = [[ 1., 1., 1.], [ 1., 1., 1.]]

y = [[ 0.], [ 1.]]

broadcast_sub(x, y) = [[ 1., 1., 1.], [ 0., 0., 0.]]

broadcast_minus(x, y) = [[ 1., 1., 1.], [ 0., 0., 0.]]

Supported sparse operations:

broadcast_sub/minus(csr, dense(1D)) = dense
broadcast_sub/minus(dense(1D), csr) = dense

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L105

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx._sparse_cast_storageMethod.

_sparse_cast_storage(data, stype)

_sparse_cast_storage is an alias of cast_storage.

Casts tensor storage type to the new type.

When an NDArray with default storage type is cast to csr or row_sparse storage, the result is compact, which means:

  • for csr, zero values will not be retained
  • for row_sparse, row slices of all zeros will not be retained

The storage type of $cast_storage$ output depends on stype parameter:

  • cast_storage(csr, 'default') = default
  • cast_storage(row_sparse, 'default') = default
  • cast_storage(default, 'csr') = csr
  • cast_storage(default, 'row_sparse') = row_sparse
  • cast_storage(csr, 'csr') = csr
  • cast_storage(row_sparse, 'row_sparse') = row_sparse

Example::

dense = [[ 0.,  1.,  0.],
         [ 2.,  0.,  3.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.]]

# cast to row_sparse storage type
rsp = cast_storage(dense, 'row_sparse')
rsp.indices = [0, 1]
rsp.values = [[ 0.,  1.,  0.],
              [ 2.,  0.,  3.]]

# cast to csr storage type
csr = cast_storage(dense, 'csr')
csr.indices = [1, 0, 2]
csr.values = [ 1.,  2.,  3.]
csr.indptr = [0, 1, 3, 3, 3]
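The compaction rules can be reproduced in plain Python. The sketch below shows how the csr fields (indices, values, indptr) and the row_sparse fields (indices, values) in the example are derived. to_row_sparse and to_csr are hypothetical helpers, not MXNet functions.

```python
def to_row_sparse(dense):
    """Keep only rows containing at least one non-zero value."""
    indices = [i for i, row in enumerate(dense) if any(v != 0 for v in row)]
    values = [list(dense[i]) for i in indices]
    return indices, values


def to_csr(dense):
    """Compressed sparse row: column indices, non-zero values, row pointers."""
    indices, values, indptr = [], [], [0]
    for row in dense:
        for j, v in enumerate(row):
            if v != 0:  # zero values are not retained
                indices.append(j)
                values.append(v)
        indptr.append(len(values))  # end offset of this row's entries
    return indices, values, indptr


dense = [[0., 1., 0.],
         [2., 0., 3.],
         [0., 0., 0.],
         [0., 0., 0.]]
print(to_row_sparse(dense))  # ([0, 1], [[0.0, 1.0, 0.0], [2.0, 0.0, 3.0]])
print(to_csr(dense))         # ([1, 0, 2], [1.0, 2.0, 3.0], [0, 1, 3, 3, 3])
```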

Defined in src/operator/tensor/cast_storage.cc:L71

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • stype::{'csr', 'default', 'row_sparse'}, required: Output storage type.

source

# MXNet.mx._sparse_cbrtMethod.

_sparse_cbrt(data)

_sparse_cbrt is an alias of cbrt.

Returns element-wise cube-root value of the input.

.. math:: cbrt(x) = \sqrt[3]{x}

Example::

cbrt([1, 8, -125]) = [1, 2, -5]

The storage type of cbrt output depends upon the input storage type:

  • cbrt(default) = default
  • cbrt(row_sparse) = row_sparse
  • cbrt(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_pow.cc:L270

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_ceilMethod.

_sparse_ceil(data)

_sparse_ceil is an alias of ceil.

Returns element-wise ceiling of the input.

The ceil of the scalar x is the smallest integer i, such that i >= x.

Example::

ceil([-2.1, -1.9, 1.5, 1.9, 2.1]) = [-2., -1., 2., 2., 3.]

The storage type of ceil output depends upon the input storage type:

  • ceil(default) = default
  • ceil(row_sparse) = row_sparse
  • ceil(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L817

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_clipMethod.

_sparse_clip(data, a_min, a_max)

_sparse_clip is an alias of clip.

Clips (limits) the values in an array. Given an interval, values outside the interval are clipped to the interval edges. Clipping x between a_min and a_max would be:

.. math:: clip(x, a_{min}, a_{max}) = \max(\min(x, a_{max}), a_{min})

Example::

x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

clip(x,1,8) = [ 1., 1., 2., 3., 4., 5., 6., 7., 8., 8.]

The storage type of clip output depends on the storage types of the inputs and on the a_min, a_max parameter values:

  • clip(default) = default
  • clip(row_sparse, a_min <= 0, a_max >= 0) = row_sparse
  • clip(csr, a_min <= 0, a_max >= 0) = csr
  • clip(row_sparse, a_min < 0, a_max < 0) = default
  • clip(row_sparse, a_min > 0, a_max > 0) = default
  • clip(csr, a_min < 0, a_max < 0) = csr
  • clip(csr, a_min > 0, a_max > 0) = csr
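The row_sparse rows of the table follow from the clip formula. Implicit zeros can only stay implicit when clip(0) = 0, which holds exactly when a_min <= 0 <= a_max. A plain-Python sketch of that reasoning is below. clip and clip_keeps_zeros are hypothetical helpers, and the sketch does not model the csr rows of the table.

```python
def clip(xs, a_min, a_max):
    """max(min(x, a_max), a_min) for every element."""
    return [max(min(x, a_max), a_min) for x in xs]


def clip_keeps_zeros(a_min, a_max):
    # clip(0) == 0 exactly when 0 lies inside [a_min, a_max]; only then can
    # the implicit zeros of a row_sparse input stay implicit in the output.
    return clip([0.0], a_min, a_max)[0] == 0.0


print(clip(list(range(10)), 1, 8))  # [1, 1, 2, 3, 4, 5, 6, 7, 8, 8]
print(clip_keeps_zeros(-1, 1), clip_keeps_zeros(1, 8))  # True False
```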

Defined in src/operator/tensor/matrix_op.cc:L676

Arguments

  • data::NDArray-or-SymbolicNode: Input array.
  • a_min::float, required: Minimum value
  • a_max::float, required: Maximum value

source

# MXNet.mx._sparse_concatMethod.

_sparse_concat(data, num_args, dim)

_sparse_concat is an alias of Concat.

Note: _sparse_concat takes a variable number of positional inputs. So instead of calling it as _sparse_concat([x, y, z], num_args=3), call it as _sparse_concat(x, y, z), and num_args will be determined automatically.

Joins input arrays along a given axis.

.. note:: Concat is deprecated. Use concat instead.

The dimensions of the input arrays should be the same except the axis along which they will be concatenated. The dimension of the output array along the concatenated axis will be equal to the sum of the corresponding dimensions of the input arrays.

The storage type of concat output depends on the storage types of the inputs:

  • concat(csr, csr, ..., csr, dim=0) = csr
  • otherwise, concat generates output with default storage

Example::

x = [[1,1],[2,2]]
y = [[3,3],[4,4],[5,5]]
z = [[6,6],[7,7],[8,8]]

concat(x,y,z,dim=0) = [[ 1., 1.], [ 2., 2.], [ 3., 3.], [ 4., 4.], [ 5., 5.], [ 6., 6.], [ 7., 7.], [ 8., 8.]]

Note that you cannot concat x,y,z along dimension 1 since dimension 0 is not the same for all the input arrays.

concat(y,z,dim=1) = [[ 3., 3., 6., 6.], [ 4., 4., 7., 7.], [ 5., 5., 8., 8.]]
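A plain-Python sketch of the shape rule for 2-D inputs follows. It uses the row-major layout of the example above, and concat here is a hypothetical helper rather than the MXNet operator. It shows why x, y, z can be joined along dimension 0 but not along dimension 1.

```python
def concat(*arrays, dim=1):
    """Join 2-D nested lists along dim 0 (stack rows) or dim 1 (extend rows)."""
    if dim == 0:
        # All inputs must agree on the number of columns.
        if len({len(a[0]) for a in arrays}) != 1:
            raise ValueError("inputs differ along dimension 1")
        return [list(row) for a in arrays for row in a]
    if dim == 1:
        # All inputs must agree on the number of rows.
        if len({len(a) for a in arrays}) != 1:
            raise ValueError("inputs differ along dimension 0")
        return [[v for a in arrays for v in a[i]] for i in range(len(arrays[0]))]
    raise ValueError("this sketch only handles 2-D inputs")


x = [[1, 1], [2, 2]]
y = [[3, 3], [4, 4], [5, 5]]
z = [[6, 6], [7, 7], [8, 8]]
print(concat(x, y, z, dim=0))  # [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8]]
print(concat(y, z, dim=1))     # [[3, 3, 6, 6], [4, 4, 7, 7], [5, 5, 8, 8]]
```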

Defined in src/operator/nn/concat.cc:L384

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to concatenate
  • num_args::int, required: Number of inputs to be concatenated.
  • dim::int, optional, default='1': The dimension along which to concatenate.

source

# MXNet.mx._sparse_cosMethod.

_sparse_cos(data)

_sparse_cos is an alias of cos.

Computes the element-wise cosine of the input array.

The input should be in radians (2π rad equals 360 degrees).

.. math:: cos([0, \pi/4, \pi/2]) = [1, 0.707, 0]

The storage type of cos output is always dense.
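The storage rules in this section share a pattern. Each unary operator here whose table keeps row_sparse or csr output (for example sin, sinh, expm1, log1p, sqrt, degrees) maps 0 to 0. Those whose output is always dense (cos, cosh, exp, sigmoid) do not. The plain-Python sketch below checks f(0) with the standard math module. keeps_sparsity is a hypothetical helper, and this is an explanatory heuristic rather than MXNet's storage dispatch logic.

```python
import math


def keeps_sparsity(f):
    # Implicit zeros can stay implicit only if the operator maps 0 to 0.
    return f(0.0) == 0.0


ops = {
    "sin": math.sin, "sinh": math.sinh, "expm1": math.expm1,
    "log1p": math.log1p, "sqrt": math.sqrt, "degrees": math.degrees,
    "cos": math.cos, "cosh": math.cosh, "exp": math.exp,
    "sigmoid": lambda v: 1.0 / (1.0 + math.exp(-v)),
}
for name, f in ops.items():
    print(name, keeps_sparsity(f))
```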

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L90

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_coshMethod.

_sparse_cosh(data)

_sparse_cosh is an alias of cosh.

Returns the hyperbolic cosine of the input array, computed element-wise.

.. math:: cosh(x) = 0.5\times(exp(x) + exp(-x))

The storage type of cosh output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L409

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_degreesMethod.

_sparse_degrees(data)

_sparse_degrees is an alias of degrees.

Converts each element of the input array from radians to degrees.

.. math:: degrees([0, \pi/2, \pi, 3\pi/2, 2\pi]) = [0, 90, 180, 270, 360]

The storage type of degrees output depends upon the input storage type:

  • degrees(default) = default
  • degrees(row_sparse) = row_sparse
  • degrees(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L332

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_dotMethod.

_sparse_dot(lhs, rhs, transpose_a, transpose_b, forward_stype)

_sparse_dot is an alias of dot.

Dot product of two arrays.

dot's behavior depends on the input array dimensions:

  • 1-D arrays: inner product of vectors
  • 2-D arrays: matrix multiplication
  • N-D arrays: a sum product over the last axis of the first input and the first axis of the second input

    For example, given 3-D x with shape (n,m,k) and y with shape (k,r,s), the result array will have shape (n,m,r,s). It is computed by::

    dot(x,y)[i,j,a,b] = sum(x[i,j,:]*y[:,a,b])

    Example::

    x = reshape([0,1,2,3,4,5,6,7], shape=(2,2,2))
    y = reshape([7,6,5,4,3,2,1,0], shape=(2,2,2))
    dot(x,y)[0,0,1,1] = 0
    sum(x[0,0,:]*y[:,1,1]) = 0
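The N-D rule dot(x,y)[i,j,a,b] = sum(x[i,j,:]*y[:,a,b]) can be checked with a plain-Python sketch on nested lists. reshape_3d and dot_3d are hypothetical helpers that assume row-major order, as in the example above.

```python
def reshape_3d(flat, shape):
    """Row-major reshape of a flat list into a 3-D nested list."""
    a, b, c = shape
    return [[[flat[(i * b + j) * c + k] for k in range(c)] for j in range(b)]
            for i in range(a)]


def dot_3d(x, y):
    """dot(x, y)[i][j][a][b] = sum over t of x[i][j][t] * y[t][a][b]."""
    n, m, k = len(x), len(x[0]), len(x[0][0])
    if len(y) != k:
        raise ValueError("last axis of x must match first axis of y")
    r, s = len(y[0]), len(y[0][0])
    return [[[[sum(x[i][j][t] * y[t][a][b] for t in range(k))
               for b in range(s)] for a in range(r)]
             for j in range(m)] for i in range(n)]


x = reshape_3d([0, 1, 2, 3, 4, 5, 6, 7], (2, 2, 2))
y = reshape_3d([7, 6, 5, 4, 3, 2, 1, 0], (2, 2, 2))
out = dot_3d(x, y)          # shape (n, m, r, s) = (2, 2, 2, 2)
print(out[0][0][1][1])      # 0
print(out[1][1][0][0])      # 63
```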

The storage type of $dot$ output depends on storage types of inputs, transpose option and forward_stype option for output storage type. Implemented sparse operations include:

  • dot(default, default, transpose_a=True/False, transpose_b=True/False) = default
  • dot(csr, default, transpose_a=True) = default
  • dot(csr, default, transpose_a=True) = row_sparse
  • dot(csr, default) = default
  • dot(csr, row_sparse) = default
  • dot(default, csr) = csr (CPU only)
  • dot(default, csr, forward_stype='default') = default
  • dot(default, csr, transpose_b=True, forward_stype='default') = default

If the combination of input storage types and forward_stype does not match any of the above patterns, dot will fall back and generate output with default storage.

.. Note::

If the storage type of the lhs is "csr", the storage type of gradient w.r.t rhs will be
"row_sparse". Only a subset of optimizers support sparse gradients, including SGD, AdaGrad
and Adam. Note that by default lazy update is turned on, which may perform differently
from standard updates. For more details, please check the Optimization API at:
https://mxnet.incubator.apache.org/api/python/optimization/optimization.html

Defined in src/operator/tensor/dot.cc:L77

Arguments

  • lhs::NDArray-or-SymbolicNode: The first input
  • rhs::NDArray-or-SymbolicNode: The second input
  • transpose_a::boolean, optional, default=0: If true then transpose the first input before dot.
  • transpose_b::boolean, optional, default=0: If true then transpose the second input before dot.
  • forward_stype::{None, 'csr', 'default', 'row_sparse'},optional, default='None': The desired storage type of the forward output given by the user. If the combination of input storage types and this hint does not match any implemented ones, the dot operator will perform a fallback operation and still produce an output of the desired storage type.

source

# MXNet.mx._sparse_elemwise_addMethod.

_sparse_elemwise_add(lhs, rhs)

_sparse_elemwise_add is an alias of elemwise_add.

Adds arguments element-wise.

The storage type of elemwise_add output depends on the storage types of the inputs:

  • elemwise_add(row_sparse, row_sparse) = row_sparse
  • elemwise_add(csr, csr) = csr
  • elemwise_add(default, csr) = default
  • elemwise_add(csr, default) = default
  • elemwise_add(default, rsp) = default
  • elemwise_add(rsp, default) = default
  • otherwise, elemwise_add generates output with default storage
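To see why elemwise_add(row_sparse, row_sparse) can stay row_sparse, here is a plain-Python sketch that stores a row_sparse array as a hypothetical (indices, rows) pair. The output keeps the sorted union of both index sets, and rows absent from both inputs remain implicit zeros. It illustrates the storage rule only and is not MXNet's kernel.

```python
def add_row_sparse(lhs, rhs):
    """Add two row_sparse arrays given as (indices, rows) pairs."""
    merged = {}
    for indices, rows in (lhs, rhs):
        for i, row in zip(indices, rows):
            if i in merged:
                # Row stored in both inputs: add element-wise.
                merged[i] = [a + b for a, b in zip(merged[i], row)]
            else:
                merged[i] = list(row)
    out_indices = sorted(merged)
    return out_indices, [merged[i] for i in out_indices]


a = ([0, 2], [[1, 1], [2, 2]])
b = ([1, 2], [[3, 3], [4, 4]])
print(add_row_sparse(a, b))  # ([0, 1, 2], [[1, 1], [3, 3], [6, 6]])
```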

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._sparse_elemwise_divMethod.

_sparse_elemwise_div(lhs, rhs)

_sparse_elemwise_div is an alias of elemwise_div.

Divides arguments element-wise.

The storage type of elemwise_div output is always dense.

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._sparse_elemwise_mulMethod.

_sparse_elemwise_mul(lhs, rhs)

_sparse_elemwise_mul is an alias of elemwise_mul.

Multiplies arguments element-wise.

The storage type of elemwise_mul output depends on the storage types of the inputs:

  • elemwise_mul(default, default) = default
  • elemwise_mul(row_sparse, row_sparse) = row_sparse
  • elemwise_mul(default, row_sparse) = row_sparse
  • elemwise_mul(row_sparse, default) = row_sparse
  • elemwise_mul(csr, csr) = csr
  • otherwise, elemwise_mul generates output with default storage

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._sparse_elemwise_subMethod.

_sparse_elemwise_sub(lhs, rhs)

_sparse_elemwise_sub is an alias of elemwise_sub.

Subtracts arguments element-wise.

The storage type of elemwise_sub output depends on the storage types of the inputs:

  • elemwise_sub(row_sparse, row_sparse) = row_sparse
  • elemwise_sub(csr, csr) = csr
  • elemwise_sub(default, csr) = default
  • elemwise_sub(csr, default) = default
  • elemwise_sub(default, rsp) = default
  • elemwise_sub(rsp, default) = default
  • otherwise, elemwise_sub generates output with default storage

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx._sparse_expMethod.

_sparse_exp(data)

_sparse_exp is an alias of exp.

Returns element-wise exponential value of the input.

.. math:: exp(x) = e^x \approx 2.718^x

Example::

exp([0, 1, 2]) = [1., 2.71828175, 7.38905621]

The storage type of exp output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_logexp.cc:L64

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_expm1Method.

_sparse_expm1(data)

_sparse_expm1 is an alias of expm1.

Returns exp(x) - 1 computed element-wise on the input.

This function provides greater precision than exp(x) - 1 for small values of x.
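The precision claim can be checked with Python's standard math module, whose math.expm1 serves the same purpose. This illustrates the numerical issue and says nothing about MXNet's kernel.

```python
import math

x = 1e-10
naive = math.exp(x) - 1.0  # subtracting two nearly equal numbers loses digits
accurate = math.expm1(x)   # computed to full precision for small x
print(naive, accurate)
print(abs(naive - accurate) / accurate)  # relative error of the naive form
```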

The storage type of expm1 output depends upon the input storage type:

  • expm1(default) = default
  • expm1(row_sparse) = row_sparse
  • expm1(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_logexp.cc:L244

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_fixMethod.

_sparse_fix(data)

_sparse_fix is an alias of fix.

Rounds each element of the input to the nearest integer towards zero.

Example::

fix([-2.1, -1.9, 1.9, 2.1]) = [-2., -1., 1., 2.]
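fix truncates toward zero. It therefore agrees with floor for non-negative inputs and with ceil for negative ones. Python's math.trunc, math.floor, and math.ceil follow the same definitions, so the three are easy to compare on the example input. The Python results are ints rather than floats.

```python
import math

xs = [-2.1, -1.9, 1.9, 2.1]
print([math.trunc(x) for x in xs])  # fix:   [-2, -1, 1, 2]
print([math.floor(x) for x in xs])  # floor: [-3, -2, 1, 2]
print([math.ceil(x) for x in xs])   # ceil:  [-2, -1, 2, 3]
```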

The storage type of fix output depends upon the input storage type:

  • fix(default) = default
  • fix(row_sparse) = row_sparse
  • fix(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L874

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_floorMethod.

_sparse_floor(data)

_sparse_floor is an alias of floor.

Returns element-wise floor of the input.

The floor of the scalar x is the largest integer i, such that i <= x.

Example::

floor([-2.1, -1.9, 1.5, 1.9, 2.1]) = [-3., -2., 1., 1., 2.]

The storage type of floor output depends upon the input storage type:

  • floor(default) = default
  • floor(row_sparse) = row_sparse
  • floor(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L836

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_ftrl_updateMethod.

_sparse_ftrl_update(weight, grad, z, n, lr, lamda1, beta, wd, rescale_grad, clip_gradient)

_sparse_ftrl_update is an alias of ftrl_update.

Update function for Ftrl optimizer. Referenced from Ad Click Prediction: a View from the Trenches, available at http://dl.acm.org/citation.cfm?id=2488200.

It updates the weights using::

rescaled_grad = clip(grad * rescale_grad, clip_gradient)
z += rescaled_grad - (sqrt(n + rescaled_grad**2) - sqrt(n)) * weight / learning_rate
n += rescaled_grad**2
w = (sign(z) * lamda1 - z) / ((beta + sqrt(n)) / learning_rate + wd) * (abs(z) > lamda1)

If w, z and n are all of row_sparse storage type, only the row slices whose indices appear in grad.indices are updated (for w, z and n)::

for row in grad.indices:
    rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient)
    z[row] += rescaled_grad[row] - (sqrt(n[row] + rescaled_grad[row]**2) - sqrt(n[row])) * weight[row] / learning_rate
    n[row] += rescaled_grad[row]**2
    w[row] = (sign(z[row]) * lamda1 - z[row]) / ((beta + sqrt(n[row])) / learning_rate + wd) * (abs(z[row]) > lamda1)
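The FTRL rule above is transcribed here into a dense plain-Python sketch for flat lists of floats. ftrl_update is a hypothetical helper that mirrors the argument names listed below. The row_sparse variant would run the same loop body only for the rows in grad.indices.

```python
import math


def ftrl_update(weight, grad, z, n, lr, lamda1=0.01, beta=1.0, wd=0.0,
                rescale_grad=1.0, clip_gradient=-1.0):
    """One dense FTRL step; z and n are updated in place, new weights returned."""
    new_weight = []
    for i in range(len(weight)):
        g = grad[i] * rescale_grad
        if clip_gradient > 0:  # clipping is off when clip_gradient <= 0
            g = max(min(g, clip_gradient), -clip_gradient)
        z[i] += g - (math.sqrt(n[i] + g * g) - math.sqrt(n[i])) * weight[i] / lr
        n[i] += g * g
        if abs(z[i]) > lamda1:
            sign_z = math.copysign(1.0, z[i])
            new_weight.append((sign_z * lamda1 - z[i]) / ((beta + math.sqrt(n[i])) / lr + wd))
        else:
            new_weight.append(0.0)  # L1 threshold keeps this weight at zero
    return new_weight


weight, z, n = [0.0], [0.0], [0.0]
weight = ftrl_update(weight, [1.0], z, n, lr=0.1)
print(weight, z, n)
```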

Defined in src/operator/optimizer_op.cc:L875

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • z::NDArray-or-SymbolicNode: z
  • n::NDArray-or-SymbolicNode: Square of grad
  • lr::float, required: Learning rate
  • lamda1::float, optional, default=0.00999999978: The L1 regularization coefficient.
  • beta::float, optional, default=1: Per-Coordinate Learning Rate beta.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).

source

# MXNet.mx._sparse_gammaMethod.

_sparse_gamma(data)

_sparse_gamma is an alias of gamma.

Returns the gamma function (extension of the factorial function to the reals), computed element-wise on the input array.

The storage type of gamma output is always dense.

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_gammalnMethod.

_sparse_gammaln(data)

_sparse_gammaln is an alias of gammaln.

Returns element-wise log of the absolute value of the gamma function of the input.

The storage type of gammaln output is always dense.

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_logMethod.

_sparse_log(data)

_sparse_log is an alias of log.

Returns element-wise natural logarithmic value of the input.

The natural logarithm is the logarithm in base e, so that log(exp(x)) = x.

The storage type of log output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_logexp.cc:L77

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_log10Method.

_sparse_log10(data)

_sparse_log10 is an alias of log10.

Returns element-wise base-10 logarithmic value of the input.

10**log10(x) = x

The storage type of log10 output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_logexp.cc:L94

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_log1pMethod.

_sparse_log1p(data)

_sparse_log1p is an alias of log1p.

Returns element-wise log(1 + x) value of the input.

This function is more accurate than log(1 + x) for small x such that 1 + x ≈ 1.

The storage type of log1p output depends upon the input storage type:

  • log1p(default) = default
  • log1p(row_sparse) = row_sparse
  • log1p(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_logexp.cc:L199

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_log2Method.

_sparse_log2(data)

_sparse_log2 is an alias of log2.

Returns element-wise base-2 logarithmic value of the input.

2**log2(x) = x

The storage type of log2 output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_logexp.cc:L106

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_make_lossMethod.

_sparse_make_loss(data)

_sparse_make_loss is an alias of make_loss.

Make your own loss function in network construction.

This operator accepts a customized loss function symbol as a terminal loss and the symbol should be an operator with no backward dependency. The output of this function is the gradient of loss with respect to the input data.

For example, suppose you are making a binary cross entropy loss function. Assume out is the predicted output and label is the true label. Then the cross entropy can be defined as::

cross_entropy = -(label * log(out) + (1 - label) * log(1 - out))
loss = make_loss(cross_entropy)

Use make_loss when you are creating your own loss function or want to combine multiple loss functions. You may also want to stop some variables' gradients from backpropagating; see BlockGrad or stop_gradient for details.

The storage type of make_loss output depends upon the input storage type:

  • make_loss(default) = default
  • make_loss(row_sparse) = row_sparse

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L358

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_meanMethod.

_sparse_mean(data, axis, keepdims, exclude)

_sparse_mean is an alias of mean.

Computes the mean of array elements over given axes.

Defined in src/operator/tensor/broadcast_reduce_op.h:L83

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: The axis or axes along which to perform the reduction.

    The default, axis=(), will compute over all elements into a scalar array with shape (1,).

    If axis is int, a reduction is performed on a particular axis.

    If axis is a tuple of ints, a reduction is performed on all the axes specified in the tuple.

    If exclude is true, reduction will be performed on the axes that are NOT in axis instead.

    Negative values mean indexing from right to left.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimension with size one.
  • exclude::boolean, optional, default=0: Whether to perform reduction on axis that are NOT in axis instead.

source

# MXNet.mx._sparse_negativeMethod.

_sparse_negative(data)

_sparse_negative is an alias of negative.

Numerical negative of the argument, element-wise.

The storage type of negative output depends upon the input storage type:

  • negative(default) = default
  • negative(row_sparse) = row_sparse
  • negative(csr) = csr

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_normMethod.

_sparse_norm(data, ord, axis, out_dtype, keepdims)

_sparse_norm is an alias of norm.

Computes the norm on an NDArray.

This operator computes the norm on an NDArray with the specified axis, depending on the value of the ord parameter. By default, it computes the L2 norm on the entire array. Currently only ord=2 supports sparse ndarrays.

Examples::

x = [[[1, 2], [3, 4]], [[2, 2], [5, 6]]]

norm(x, ord=2, axis=1) = [[3.1622777 4.472136 ] [5.3851647 6.3245554]]

norm(x, ord=1, axis=1) = [[4., 6.], [7., 8.]]

y = [[1, 2], [3, 4]]

rsp = y.cast_storage('row_sparse')

norm(rsp) = [5.47722578]

csr = y.cast_storage('csr')

norm(csr) = [5.47722578]
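A plain-Python sketch follows of the dense ord=1 and ord=2 reductions over axis 1 of a 3-D array, plus an L2 norm over all elements. norm_axis1 and norm_all are hypothetical helpers, and the sketch does not model sparse storage.

```python
import math


def norm_axis1(x, ord=2):
    """Norm over axis 1 of a 3-D nested list: result[i][k] uses x[i][:][k]."""
    if ord not in (1, 2):
        raise ValueError("only ord=1 and ord=2 are sketched here")
    out = []
    for block in x:               # block has shape (J, K)
        cols = list(zip(*block))  # cols[k] == (x[i][0][k], x[i][1][k], ...)
        if ord == 1:
            out.append([sum(abs(v) for v in col) for col in cols])
        else:
            out.append([math.sqrt(sum(v * v for v in col)) for col in cols])
    return out


def norm_all(x):
    """L2 norm over every element of an arbitrarily nested list."""
    def flat(a):
        for v in a:
            if isinstance(v, list):
                yield from flat(v)
            else:
                yield v
    return math.sqrt(sum(v * v for v in flat(x)))


x = [[[1, 2], [3, 4]], [[2, 2], [5, 6]]]
print(norm_axis1(x, ord=2))
print(norm_axis1(x, ord=1))        # [[4, 6], [7, 8]]
print(norm_all([[1, 2], [3, 4]]))  # equals math.sqrt(30)
```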

Defined in src/operator/tensor/broadcast_reduce_norm_value.cc:L88

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • ord::int, optional, default='2': Order of the norm. Currently ord=1 and ord=2 are supported.
  • axis::Shape or None, optional, default=None: The axis or axes along which to perform the reduction. The default, axis=(), will compute over all elements into a scalar array with shape (1,). If axis is int, a reduction is performed on a particular axis. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed.
  • out_dtype::{None, 'float16', 'float32', 'float64', 'int32', 'int64', 'int8'},optional, default='None': The data type of the output.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axis is left in the result as dimension with size one.

source

# MXNet.mx._sparse_radiansMethod.

_sparse_radians(data)

_sparse_radians is an alias of radians.

Converts each element of the input array from degrees to radians.

.. math:: radians([0, 90, 180, 270, 360]) = [0, \pi/2, \pi, 3\pi/2, 2\pi]

The storage type of radians output depends upon the input storage type:

  • radians(default) = default
  • radians(row_sparse) = row_sparse
  • radians(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L351

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_reluMethod.

_sparse_relu(data)

_sparse_relu is an alias of relu.

Computes rectified linear activation.

.. math:: max(features, 0)

The storage type of relu output depends upon the input storage type:

  • relu(default) = default
  • relu(row_sparse) = row_sparse
  • relu(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L85

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_retainMethod.

_sparse_retain(data, indices)

Picks the rows specified by a user-provided index array from a row_sparse matrix and saves them in the output sparse matrix.

Example::

data = [[1, 2], [3, 4], [5, 6]]
indices = [0, 1, 3]
shape = (4, 2)
rsp_in = row_sparse_array(data, indices)
to_retain = [0, 3]
rsp_out = retain(rsp_in, to_retain)
rsp_out.data = [[1, 2], [5, 6]]
rsp_out.indices = [0, 3]
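The retain example can be reproduced with a plain-Python sketch that represents a row_sparse array as stored row ids plus their rows. retain here is a hypothetical helper, not the MXNet operator. Dropping requested ids that the input does not store is a choice made by this sketch, and MXNet's handling of that case is not described here.

```python
def retain(rsp_indices, rsp_data, to_retain):
    """Filter a row_sparse array given as (indices, data) to the ids in to_retain.

    Stored rows whose ids are not requested are dropped; requested ids that
    the input does not store produce no output row in this sketch.
    """
    keep = set(to_retain)
    pairs = [(i, row) for i, row in zip(rsp_indices, rsp_data) if i in keep]
    return [i for i, _ in pairs], [row for _, row in pairs]


rsp_out_indices, rsp_out_data = retain([0, 1, 3], [[1, 2], [3, 4], [5, 6]], [0, 3])
print(rsp_out_data, rsp_out_indices)  # [[1, 2], [5, 6]] [0, 3]
```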

The storage type of retain output depends on the storage types of the inputs:

  • retain(row_sparse, default) = row_sparse
  • otherwise, retain is not supported

Defined in src/operator/tensor/sparse_retain.cc:L53

Arguments

  • data::NDArray-or-SymbolicNode: The input array for sparse_retain operator.
  • indices::NDArray-or-SymbolicNode: The index array of row ids that will be retained.

source

# MXNet.mx._sparse_rintMethod.

_sparse_rint(data)

_sparse_rint is an alias of rint.

Returns element-wise rounded value to the nearest integer of the input.

.. note::

  • For input n.5, rint returns n while round returns n+1.
  • For input -n.5, both rint and round return -n-1.

Example::

rint([-1.5, 1.5, -1.9, 1.9, 2.1]) = [-2., 1., -2., 2., 2.]
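The note above implies two different tie-breaking rules: rint sends n.5 to its lower neighbour (floor), and round sends it away from zero. Below is a plain-Python sketch of both. The rules are inferred from the note and examples rather than taken from MXNet's source, and mx_rint and mx_round are hypothetical names. Python's built-in round() uses round-half-to-even, so it matches neither rule in general (round(2.5) == 2 but round(3.5) == 4).

```python
import math


def mx_rint(a):
    # Halfway cases go to the lower neighbour: 1.5 -> 1, -1.5 -> -2.
    lo, hi = math.floor(a), math.ceil(a)
    return float(lo if a - lo <= hi - a else hi)


def mx_round(a):
    # Halfway cases go away from zero: 1.5 -> 2, -1.5 -> -2.
    lo = math.floor(abs(a))
    r = lo + 1 if abs(a) - lo >= 0.5 else lo
    return math.copysign(r, a)


xs = [-1.5, 1.5, -1.9, 1.9, 2.1]
print([mx_rint(a) for a in xs])   # [-2.0, 1.0, -2.0, 2.0, 2.0]
print([mx_round(a) for a in xs])  # [-2.0, 2.0, -2.0, 2.0, 2.0]
print(mx_rint(3.5), mx_round(3.5), round(3.5))  # 3.0 4.0 4
```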

The storage type of rint output depends upon the input storage type:

  • rint(default) = default
  • rint(row_sparse) = row_sparse
  • rint(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L798

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_roundMethod.

_sparse_round(data)

_sparse_round is an alias of round.

Returns element-wise rounded value to the nearest integer of the input.

Example::

round([-1.5, 1.5, -1.9, 1.9, 2.1]) = [-2., 2., -2., 2., 2.]

The storage type of round output depends upon the input storage type:

  • round(default) = default
  • round(row_sparse) = row_sparse
  • round(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L777

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_rsqrtMethod.

_sparse_rsqrt(data)

_sparse_rsqrt is an alias of rsqrt.

Returns element-wise inverse square-root value of the input.

.. math:: rsqrt(x) = 1/\sqrt{x}

Example::

rsqrt([4,9,16]) = [0.5, 0.33333334, 0.25]

The storage type of rsqrt output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_pow.cc:L221

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_sgd_mom_updateMethod.

_sparse_sgd_mom_update(weight, grad, mom, lr, momentum, wd, rescale_grad, clip_gradient, lazy_update)

_sparse_sgd_mom_update is an alias of sgd_mom_update.

Momentum update function for Stochastic Gradient Descent (SGD) optimizer.

Momentum update has better convergence rates on neural networks. Mathematically it looks like below:

.. math::

v_1 = -\alpha \nabla J(W_0)\\
v_t = \gamma v_{t-1} - \alpha \nabla J(W_{t-1})\\
W_t = W_{t-1} + v_t

It updates the weights using::

v = momentum * v - learning_rate * gradient
weight += v

Here the parameter momentum is the decay rate of momentum estimates at each epoch.

However, if grad's storage type is row_sparse, lazy_update is True and weight's storage type is the same as momentum's storage type, only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::

for row in gradient.indices:
    v[row] = momentum * v[row] - learning_rate * gradient[row]
    weight[row] += v[row]
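Below is a dense plain-Python sketch of the momentum update rule above, with rescale_grad and clip_gradient applied as described in the argument list. The update formula does not show wd. Adding wd * weight to the gradient, as in the sgd_update formula, is an assumption here. sgd_mom_update is a hypothetical helper that mirrors the argument names, not the MXNet API.

```python
def sgd_mom_update(weight, grad, mom, lr, momentum=0.0, wd=0.0,
                   rescale_grad=1.0, clip_gradient=-1.0):
    """One dense momentum-SGD step on flat lists; weight and mom change in place."""
    for i in range(len(weight)):
        g = grad[i] * rescale_grad
        if clip_gradient > 0:  # clipping is off when clip_gradient <= 0
            g = max(min(g, clip_gradient), -clip_gradient)
        # Assumption: weight decay is folded into the gradient, as in sgd_update.
        mom[i] = momentum * mom[i] - lr * (g + wd * weight[i])
        weight[i] += mom[i]


w, m = [1.0], [0.0]
for _ in range(2):
    sgd_mom_update(w, [0.5], m, lr=0.1, momentum=0.9)
print(w, m)
```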

Defined in src/operator/optimizer_op.cc:L564

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • mom::NDArray-or-SymbolicNode: Momentum
  • lr::float, required: Learning rate
  • momentum::float, optional, default=0: The decay rate of momentum estimates at each epoch.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • lazy_update::boolean, optional, default=1: If true, lazy updates are applied if gradient's stype is row_sparse and both weight and momentum have the same stype

source

# MXNet.mx._sparse_sgd_updateMethod.

_sparse_sgd_update(weight, grad, lr, wd, rescale_grad, clip_gradient, lazy_update)

_sparse_sgd_update is an alias of sgd_update.

Update function for Stochastic Gradient Descent (SGD) optimizer.

It updates the weights using::

weight = weight - learning_rate * (gradient + wd * weight)

However, if gradient is of row_sparse storage type and lazy_update is True, only the row slices whose indices appear in grad.indices are updated::

for row in gradient.indices:
    weight[row] = weight[row] - learning_rate * (gradient[row] + wd * weight[row])
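The lazy row_sparse update can be sketched in plain Python. The sketch assumes a dense weight stored as a list of rows and a gradient given as hypothetical (grad_indices, grad_rows) lists, and sgd_update_lazy is a hypothetical helper. Rows absent from the gradient are left untouched, so with wd > 0 they are not decayed. This is one way lazy updates can differ from the standard dense update.

```python
def sgd_update_lazy(weight, grad_indices, grad_rows, lr, wd=0.0):
    """Update only the rows of weight listed in grad_indices (in place)."""
    for row_id, g_row in zip(grad_indices, grad_rows):
        w_row = weight[row_id]
        weight[row_id] = [w - lr * (g + wd * w) for w, g in zip(w_row, g_row)]


weight = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
sgd_update_lazy(weight, [1], [[0.5, 0.5]], lr=0.1, wd=0.1)
print(weight)  # rows 0 and 2 keep their values; a dense update would decay them
```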

Defined in src/operator/optimizer_op.cc:L523

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • lr::float, required: Learning rate
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • lazy_update::boolean, optional, default=1: If true, lazy updates are applied if gradient's stype is row_sparse.

source

# MXNet.mx._sparse_sigmoidMethod.

_sparse_sigmoid(data)

_sparse_sigmoid is an alias of sigmoid.

Computes sigmoid of x element-wise.

.. math:: y = 1 / (1 + exp(-x))

The storage type of sigmoid output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L119

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_signMethod.

_sparse_sign(data)

_sparse_sign is an alias of sign.

Returns element-wise sign of the input.

Example::

sign([-2, 0, 3]) = [-1, 0, 1]

The storage type of sign output depends upon the input storage type:

  • sign(default) = default
  • sign(row_sparse) = row_sparse
  • sign(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L758

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_sinMethod.

_sparse_sin(data)

_sparse_sin is an alias of sin.

Computes the element-wise sine of the input array.

The input should be in radians (2π rad equals 360 degrees).

.. math:: sin([0, \pi/4, \pi/2]) = [0, 0.707, 1]

The storage type of sin output depends upon the input storage type:

  • sin(default) = default
  • sin(row_sparse) = row_sparse
  • sin(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L47

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_sinhMethod.

_sparse_sinh(data)

_sparse_sinh is an alias of sinh.

Returns the hyperbolic sine of the input array, computed element-wise.

.. math:: sinh(x) = 0.5\times(exp(x) - exp(-x))

The storage type of sinh output depends upon the input storage type:

  • sinh(default) = default
  • sinh(row_sparse) = row_sparse
  • sinh(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L371

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_sliceMethod.

_sparse_slice(data, begin, end, step)

_sparse_slice is an alias of slice.

Slices a region of the array.

.. note:: crop is deprecated. Use slice instead.

This function returns a sliced array between the indices given by begin and end with the corresponding step.

For an input array of shape=(d_0, d_1, ..., d_n-1), slice operation with begin=(b_0, b_1...b_m-1), end=(e_0, e_1, ..., e_m-1), and step=(s_0, s_1, ..., s_m-1), where m <= n, results in an array with the shape (ceil(|e_0-b_0|/|s_0|), ..., ceil(|e_m-1-b_m-1|/|s_m-1|), d_m, ..., d_n-1).

The resulting array's k-th dimension contains elements from the k-th dimension of the input array starting from index b_k (inclusive) with step s_k until reaching e_k (exclusive).

If the k-th elements are None in the sequence of begin, end, and step, the following rule will be used to set default values. If s_k is None, set s_k=1. If s_k > 0, set b_k=0, e_k=d_k; else, set b_k=d_k-1, e_k=-1 (one position before index 0).

The storage type of slice output depends on the storage types of the inputs:

  • slice(csr) = csr
  • otherwise, slice generates output with default storage

.. note:: When input data storage type is csr, it only supports step=(), or step=(None,), or step=(1,) to generate a csr output. For other step parameter values, it falls back to slicing a dense tensor.

Example::

x = [[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 9., 10., 11., 12.]]

slice(x, begin=(0,1), end=(2,4)) = [[ 2., 3., 4.], [ 6., 7., 8.]]

slice(x, begin=(None, 0), end=(None, 3), step=(-1, 2)) = [[9., 11.], [5., 7.], [1., 3.]]
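The default rules for slice coincide with Python's own slice semantics, so a plain-Python sketch for 2-D nested lists can reproduce both examples. slice_2d is a hypothetical helper. In Python an explicit end of -1 means "the last element", so the negative-step default of one position before index 0 corresponds to None in this sketch.

```python
def slice_2d(x, begin, end, step=(None, None)):
    """Slice a 2-D nested list; None in begin/end/step selects the defaults."""
    # Python's slice() applies the same None defaults as the rule above,
    # including walking from the last index down to index 0 when step < 0.
    rows = x[slice(begin[0], end[0], step[0])]
    return [row[slice(begin[1], end[1], step[1])] for row in rows]


x = [[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]]
print(slice_2d(x, begin=(0, 1), end=(2, 4)))  # [[2.0, 3.0, 4.0], [6.0, 7.0, 8.0]]
print(slice_2d(x, begin=(None, 0), end=(None, 3), step=(-1, 2)))  # [[9.0, 11.0], [5.0, 7.0], [1.0, 3.0]]
```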

Defined in src/operator/tensor/matrix_op.cc:L481

Arguments

  • data::NDArray-or-SymbolicNode: Source input
  • begin::Shape(tuple), required: starting indices for the slice operation, supports negative indices.
  • end::Shape(tuple), required: ending indices for the slice operation, supports negative indices.
  • step::Shape(tuple), optional, default=[]: step for the slice operation, supports negative values.

source

# MXNet.mx._sparse_sqrtMethod.

_sparse_sqrt(data)

_sparse_sqrt is an alias of sqrt.

Returns element-wise square-root value of the input.

.. math:: \textrm{sqrt}(x) = \sqrt{x}

Example::

sqrt([4, 9, 16]) = [2, 3, 4]

The storage type of sqrt output depends upon the input storage type:

  • sqrt(default) = default
  • sqrt(row_sparse) = row_sparse
  • sqrt(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_pow.cc:L170

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_squareMethod.

_sparse_square(data)

_sparse_square is an alias of square.

Returns element-wise squared value of the input.

.. math:: square(x) = x^2

Example::

square([2, 3, 4]) = [4, 9, 16]

The storage type of $square$ output depends upon the input storage type:

  • square(default) = default
  • square(rowsparse) = rowsparse
  • square(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_pow.cc:L119

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_stop_gradientMethod.

_sparse_stop_gradient(data)

_sparse_stop_gradient is an alias of BlockGrad.

Stops gradient computation.

Stops the accumulated gradient of the inputs from flowing through this operator in the backward direction. In other words, this operator prevents the contribution of its inputs to be taken into account for computing gradients.

Example::

v1 = [1, 2]
v2 = [0, 1]
a = Variable('a')
b = Variable('b')
b_stop_grad = stop_gradient(3 * b)
loss = MakeLoss(b_stop_grad + a)

executor = loss.simple_bind(ctx=cpu(), a=(1,2), b=(1,2))
executor.forward(is_train=True, a=v1, b=v2)
executor.outputs
[ 1.  5.]

executor.backward()
executor.grad_arrays
[ 0.  0.]
[ 1.  1.]
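The outputs and gradients in the example can be reproduced with a toy forward-mode dual number, where stopping the gradient keeps a value but zeroes its derivative. This is a substitute technique used only for illustration; MXNet itself computes gradients in reverse mode. `Dual`, `stop_gradient` and `loss_and_grads` are hypothetical local names and are unrelated to any MXNet.jl function.

```julia
# Toy forward-mode dual number: `val` is the value, `der` the derivative
# with respect to whichever input was seeded with der = 1.
struct Dual
    val::Float64
    der::Float64
end
Base.:+(x::Dual, y::Dual) = Dual(x.val + y.val, x.der + y.der)
Base.:*(c::Real, x::Dual) = Dual(c * x.val, c * x.der)

# Stopping the gradient keeps the value but blocks the derivative.
stop_gradient(x::Dual) = Dual(x.val, 0.0)

# loss = stop_gradient(3b) + a, evaluated element-wise as in the example.
loss(a::Dual, b::Dual) = stop_gradient(3 * b) + a

function loss_and_grads(a::Vector{Float64}, b::Vector{Float64})
    out = [loss(Dual(ai, 0.0), Dual(bi, 0.0)).val for (ai, bi) in zip(a, b)]
    da  = [loss(Dual(ai, 1.0), Dual(bi, 0.0)).der for (ai, bi) in zip(a, b)]  # seed a
    db  = [loss(Dual(ai, 0.0), Dual(bi, 1.0)).der for (ai, bi) in zip(a, b)]  # seed b
    return out, da, db
end
```

Calling `loss_and_grads([1.0, 2.0], [0.0, 1.0])` gives the outputs [1, 5], a gradient of [1, 1] for a and [0, 0] for b.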

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L325

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_sumMethod.

_sparse_sum(data, axis, keepdims, exclude)

_sparse_sum is an alias of sum.

Computes the sum of array elements over given axes.

.. Note::

sum and sum_axis are equivalent. For ndarray of csr storage type summation along axis 0 and axis 1 is supported. Setting keepdims or exclude to True will cause a fallback to dense operator.

Example::

data = [[[1, 2], [2, 3], [1, 3]], [[1, 4], [4, 3], [5, 2]], [[7, 1], [7, 2], [7, 3]]]

sum(data, axis=1) = [[  4.   8.]
                     [ 10.   9.]
                     [ 21.   6.]]

sum(data, axis=[1,2]) = [ 12.  19.  27.]

data = [[1, 2, 0], [3, 0, 1], [4, 1, 0]]

csr = cast_storage(data, 'csr')

sum(csr, axis=0) = [ 8.  3.  1.]

sum(csr, axis=1) = [ 3.  4.  5.]

Defined in src/operator/tensor/broadcast_reduce_sum_value.cc:L66

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: The axis or axes along which to perform the reduction.

    The default, axis=(), will compute over all elements into a scalar array with shape (1,).

    If axis is int, a reduction is performed on a particular axis.

    If axis is a tuple of ints, a reduction is performed on all the axes specified in the tuple.

    If exclude is true, reduction will be performed on the axes that are NOT in axis instead.

    Negative values mean indexing from right to left.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimensions with size one.
  • exclude::boolean, optional, default=0: Whether to perform reduction on the axes that are NOT in axis instead.
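A plain-Julia sketch of the axis, keepdims and exclude rules can be built on `Base.sum` and `dropdims`. It assumes 0-based, possibly negative, axes as in the docs, and it assumes that doc axis k corresponds to Julia dimension k+1 when the nested example is copied index by index. `sum_sketch` is a hypothetical helper, not the MXNet.jl API.

```julia
# Hypothetical helper mirroring the documented reduction parameters.
function sum_sketch(x::AbstractArray, axis = nothing; keepdims::Bool = false, exclude::Bool = false)
    N = ndims(x)
    # Normalise 0-based (possibly negative) axes to 1-based Julia dimensions;
    # `axis` may be a single Int or a tuple of Ints.
    ax = axis === nothing ? collect(1:N) : [a < 0 ? a + N + 1 : a + 1 for a in axis]
    dims = exclude ? setdiff(1:N, ax) : ax     # exclude=true reduces the other axes
    r = sum(x; dims = Tuple(dims))             # reduced axes are kept with size 1
    return keepdims ? r : dropdims(r; dims = Tuple(dims))
end
```

With `exclude = true`, `sum_sketch(A, 0; exclude = true)` reduces axes 1 and 2, which is the same as `sum_sketch(A, (1, 2))`.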

source

# MXNet.mx._sparse_tanMethod.

_sparse_tan(data)

_sparse_tan is an alias of tan.

Computes the element-wise tangent of the input array.

The input should be in radians (2π rad equals 360 degrees).

.. math:: tan([0, \pi/4, \pi/2]) = [0, 1, -inf]

The storage type of $tan$ output depends upon the input storage type:

  • tan(default) = default
  • tan(rowsparse) = rowsparse
  • tan(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L140

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_tanhMethod.

_sparse_tanh(data)

_sparse_tanh is an alias of tanh.

Returns the hyperbolic tangent of the input array, computed element-wise.

.. math:: tanh(x) = sinh(x) / cosh(x)

The storage type of $tanh$ output depends upon the input storage type:

  • tanh(default) = default
  • tanh(rowsparse) = rowsparse
  • tanh(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L451

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_truncMethod.

_sparse_trunc(data)

_sparse_trunc is an alias of trunc.

Returns the element-wise truncated value of the input.

The truncated value of the scalar x is the nearest integer i which is closer to zero than x is. In short, the fractional part of the signed number x is discarded.

Example::

trunc([-2.1, -1.9, 1.5, 1.9, 2.1]) = [-2., -1., 1., 1., 2.]

The storage type of $trunc$ output depends upon the input storage type:

  • trunc(default) = default
  • trunc(rowsparse) = rowsparse
  • trunc(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L856

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx._sparse_whereMethod.

_sparse_where(condition, x, y)

_sparse_where is an alias of where.

Returns the elements, either from x or y, depending on the condition.

Given three ndarrays, condition, x, and y, return an ndarray with the elements from x or y, depending on whether the corresponding elements of condition are true or false. x and y must have the same shape. If condition has the same shape as x, each element in the output array is from x if the corresponding element in the condition is true, and from y if false.

If condition does not have the same shape as x, it must be a 1D array whose size is the same as x's first dimension size. Each row of the output array is from x's row if the corresponding element from condition is true, and from y's row if false.

Note that all non-zero values are interpreted as $True$ in condition.

Examples::

x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
cond = [[0, 1], [-1, 0]]

where(cond, x, y) = [[5, 2], [3, 8]]

csr_cond = cast_storage(cond, 'csr')

where(csr_cond, x, y) = [[5, 2], [3, 8]]
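Both selection modes described above fit in a short plain-Julia sketch. `where_sketch` is a hypothetical helper, not the MXNet.jl API. It treats every non-zero condition value as true and handles the same-shape case and the 1-D row-selection case.

```julia
# Hypothetical sketch of `where`: non-zero condition values count as true.
function where_sketch(cond::AbstractArray, x::AbstractArray, y::AbstractArray)
    size(x) == size(y) || throw(DimensionMismatch("x and y must have the same shape"))
    if size(cond) == size(x)
        return ifelse.(cond .!= 0, x, y)                  # element-wise choice
    elseif ndims(cond) == 1 && length(cond) == size(x, 1)
        out = copy(y)
        for i in eachindex(cond)
            if cond[i] != 0
                selectdim(out, 1, i) .= selectdim(x, 1, i)  # take the whole row from x
            end
        end
        return out
    else
        throw(DimensionMismatch("cond must match x's shape or be 1-D of length size(x, 1)"))
    end
end
```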

Defined in src/operator/tensor/control_flow_op.cc:L56

Arguments

  • condition::NDArray-or-SymbolicNode: condition array
  • x::NDArray-or-SymbolicNode:
  • y::NDArray-or-SymbolicNode:

source

# MXNet.mx._sparse_zeros_likeMethod.

_sparse_zeros_like(data)

_sparse_zeros_like is an alias of zeros_like.

Return an array of zeros with the same shape, type and storage type as the input array.

The storage type of $zeros_like$ output depends on the storage type of the input:

  • zeros_like(row_sparse) = row_sparse
  • zeros_like(csr) = csr
  • zeros_like(default) = default

Examples::

x = [[ 1., 1., 1.], [ 1., 1., 1.]]

zeros_like(x) = [[ 0., 0., 0.], [ 0., 0., 0.]]

Arguments

  • data::NDArray-or-SymbolicNode: The input

source

# MXNet.mx._split_v2Method.

_split_v2(data, indices, axis, squeeze_axis, sections)

Splits an array along a particular axis into multiple sub-arrays.

Example::

x = [[[ 1.]
      [ 2.]]
     [[ 3.]
      [ 4.]]
     [[ 5.]
      [ 6.]]]
x.shape = (3, 2, 1)

y = split_v2(x, axis=1, indices_or_sections=2) // a list of 2 arrays with shape (3, 1, 1)
y = [[[ 1.]]
     [[ 3.]]
     [[ 5.]]]

    [[[ 2.]]
     [[ 4.]]
     [[ 6.]]]
y[0].shape = (3, 1, 1)

z = split_v2(x, axis=0, indices_or_sections=3) // a list of 3 arrays with shape (1, 2, 1)
z = [[[ 1.]
      [ 2.]]]

    [[[ 3.]
      [ 4.]]]

    [[[ 5.]
      [ 6.]]]
z[0].shape = (1, 2, 1)

w = split_v2(x, axis=0, indices_or_sections=(1,)) // a list of 2 arrays with shape [(1, 2, 1), (2, 2, 1)]
w = [[[ 1.]
      [ 2.]]]

    [[[ 3.]
      [ 4.]]
     [[ 5.]
      [ 6.]]]
w[0].shape = (1, 2, 1)
w[1].shape = (2, 2, 1)

squeeze_axis=True removes the axis with length 1 from the shapes of the output arrays. Note that setting squeeze_axis to 1 removes the axis with length 1 only along the axis on which the array is split. Also, squeeze_axis can be set to true only if input.shape[axis] == indices_or_sections.

Example::

z = split_v2(x, axis=0, indices_or_sections=3, squeeze_axis=1) // a list of 3 arrays with shape (2, 1)
z = [[ 1.]
     [ 2.]]

    [[ 3.]
     [ 4.]]

    [[ 5.]
     [ 6.]]
z[0].shape = (2, 1)

Defined in src/operator/tensor/matrix_op.cc:L1087

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • indices::Shape(tuple), required: Indices of splits. The elements denote the boundaries at which the split is performed along the axis.
  • axis::int, optional, default='1': Axis along which to split.
  • squeeze_axis::boolean, optional, default=0: If true, removes the axis with length 1 from the shapes of the output arrays. Note that setting squeeze_axis to $true$ removes the axis with length 1 only along the axis on which the array is split. Also, squeeze_axis can be set to $true$ only if $input.shape[axis] == num_outputs$.
  • sections::int, optional, default='0': Number of sections if split equally. Defaults to 0, which means split by indices.
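The interaction of indices, sections and squeeze_axis can be sketched in plain Julia with `selectdim`. `split_sketch` and its keyword names are hypothetical and are not the MXNet.jl API. Axes are 0-based as in the docs, `indices` lists interior split points only, and the example array is built so that `x[i, j, 1]` equals the printed value.

```julia
# Hypothetical sketch of split_v2: split `x` along a 0-based `axis`, either into
# `sections` equal parts or at the interior boundaries given in `indices`.
function split_sketch(x::AbstractArray, axis::Int; sections::Int = 0,
                      indices = Int[], squeeze_axis::Bool = false)
    d = axis + 1                                   # 0-based axis -> Julia dimension
    n = size(x, d)
    if sections > 0
        n % sections == 0 || throw(ArgumentError("axis length must be divisible by sections"))
        bounds = collect(0:(n ÷ sections):n)       # equal-width boundaries
    else
        bounds = vcat(0, collect(indices), n)      # user split points plus both ends
    end
    parts = [collect(selectdim(x, d, (bounds[k] + 1):bounds[k + 1])) for k in 1:length(bounds) - 1]
    if squeeze_axis
        all(p -> size(p, d) == 1, parts) || throw(ArgumentError("squeeze_axis needs pieces of length 1"))
        return [dropdims(p; dims = d) for p in parts]
    end
    return parts
end
```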

source

# MXNet.mx._split_v2_backwardMethod.

_split_v2_backward()

Arguments

source

# MXNet.mx._square_sumMethod.

_square_sum(data, axis, keepdims, exclude)

Computes the square sum of array elements over a given axis for a row-sparse matrix. This is a temporary solution that fuses the square and sum ops for row-sparse matrices to save the memory used for storing gradients. It will be deprecated once general operator fusion is implemented.

Example::

dns = mx.nd.array([[0, 0], [1, 2], [0, 0], [3, 4], [0, 0]])
rsp = dns.tostype('row_sparse')
sum = mx.nd._internal._square_sum(rsp, axis=1)
sum = [0, 5, 0, 25, 0]
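The memory argument for fusing square and sum can be illustrated with a toy row-sparse container. The computation reads only the stored rows and never builds a squared copy of the matrix. `RowSparseSketch`, `row_sparse_from_dense` and `square_sum_axis1` are hypothetical names, not MXNet.jl types or functions. The result is returned as a dense vector, matching the printed example.

```julia
# Hypothetical row-sparse container: only rows listed in `indices` (0-based) are stored.
struct RowSparseSketch
    indices::Vector{Int}
    values::Matrix{Float64}   # one stored row per entry of `indices`
    nrows::Int
end

function row_sparse_from_dense(dense::AbstractMatrix)
    keep = [i for i in 1:size(dense, 1) if any(!iszero, view(dense, i, :))]
    return RowSparseSketch(keep .- 1, Float64.(dense[keep, :]), size(dense, 1))
end

# Sum of squares along axis 1 (per row), touching only the stored rows;
# rows that are not stored contribute 0.
function square_sum_axis1(rsp::RowSparseSketch)
    out = zeros(rsp.nrows)
    for (k, i) in enumerate(rsp.indices)
        out[i + 1] = sum(abs2, view(rsp.values, k, :))
    end
    return out
end
```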

Defined in src/operator/tensor/square_sum.cc:L63

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: The axis or axes along which to perform the reduction.

    The default, axis=(), will compute over all elements into a scalar array with shape (1,).

    If axis is int, a reduction is performed on a particular axis.

    If axis is a tuple of ints, a reduction is performed on all the axes specified in the tuple.

    If exclude is true, reduction will be performed on the axes that are NOT in axis instead.

    Negative values mean indexing from right to left.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimensions with size one.
  • exclude::boolean, optional, default=0: Whether to perform reduction on the axes that are NOT in axis instead.

source

# MXNet.mx._unravel_indexMethod.

_unravel_index(data, shape)

Converts an array of flat indices into a batch of index arrays. The operator follows numpy conventions so a single multi index is given by a column of the output matrix. The leading dimension may be left unspecified by using -1 as placeholder.

Examples::

A = [22,41,37]
unravel(A, shape=(7,6)) = [[3,6,6],[4,5,1]]
unravel(A, shape=(-1,6)) = [[3,6,6],[4,5,1]]
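The column-per-multi-index layout can be reproduced with row-major (C-order) arithmetic that peels off the last dimension first. `unravel_sketch` is a hypothetical plain-Julia helper. Julia's own `CartesianIndices` are column-major, so they would give different numbers for the same flat indices. Because only the trailing dimensions enter the arithmetic, a leading -1 placeholder needs no special handling.

```julia
# Row-major unravel of 0-based flat indices; column j of the result is the
# multi-index of flat[j]. A leading -1 in `shape` is never read.
function unravel_sketch(flat::AbstractVector{<:Integer}, shape::NTuple{N,Int}) where {N}
    out = zeros(Int, N, length(flat))
    for (j, f) in enumerate(flat)
        r = f
        for k in N:-1:2               # last dimension varies fastest
            out[k, j] = r % shape[k]
            r ÷= shape[k]
        end
        out[1, j] = r                 # the remainder indexes the leading dimension
    end
    return out
end
```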

Defined in src/operator/tensor/ravel.cc:L67

Arguments

  • data::NDArray-or-SymbolicNode: Array of flat indices
  • shape::Shape(tuple), optional, default=None: Shape of the array into which the multi-indices apply.

source

# MXNet.mx._while_loopMethod.

_while_loop(cond, func, data, num_args, num_outputs, num_out_data, max_iterations, cond_input_locs, func_input_locs, func_var_locs)

Note: _while_loop takes a variable number of positional inputs. So instead of calling as _while_loop([x, y, z], num_args=3), one should call via _while_loop(x, y, z), and num_args will be determined automatically.

Runs a while loop with a user-defined condition and computation.

From: src/operator/control_flow.cc:1151

Arguments

  • cond::SymbolicNode: Input graph for the loop condition.
  • func::SymbolicNode: Input graph for the loop body.
  • data::NDArray-or-SymbolicNode[]: The input arrays that include data arrays and states.
  • num_args::int, required: Number of input arguments, including cond and func as two symbol inputs.
  • num_outputs::int, required: The number of outputs of the subgraph.
  • num_out_data::int, required: The number of outputs from the function body.
  • max_iterations::int, required: Maximum number of iterations.
  • cond_input_locs::tuple of <long>, required: The locations of cond's inputs in the given inputs.
  • func_input_locs::tuple of <long>, required: The locations of func's inputs in the given inputs.
  • func_var_locs::tuple of <long>, required: The locations of loop_vars among func's inputs.
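The control flow that this operator encodes can be sketched with ordinary Julia closures. The loop runs `func` while `cond` holds, for at most max_iterations steps, collecting one output per step and threading the loop variables through. `while_loop_sketch` is hypothetical. It does not model how the real operator arranges its outputs, and it does not model the *_locs bookkeeping arguments.

```julia
# Hypothetical sketch: `cond` and `func` take the loop variables as positional
# arguments; `func` returns (step_output, new_loop_vars).
function while_loop_sketch(cond, func, loop_vars::Tuple, max_iterations::Int)
    outputs = Any[]
    steps = 0
    while steps < max_iterations && cond(loop_vars...)
        out, loop_vars = func(loop_vars...)
        push!(outputs, out)
        steps += 1
    end
    return outputs, loop_vars
end
```

For example, with `cond = (i, s) -> i < 5` and `func = (i, s) -> (s, (i + 1, s + i))`, starting from `(0, 0)`, the loop stops after five steps unless max_iterations is smaller.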

source

# MXNet.mx._zeros_without_dtypeMethod.

_zeros_without_dtype(shape, ctx, dtype)

Fills the target with zeros without a default dtype.

Arguments

  • shape::Shape(tuple), optional, default=None: The shape of the output
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::int, optional, default='-1': Target data type.

source

# MXNet.mx.adam_updateMethod.

adam_update(weight, grad, mean, var, lr, beta1, beta2, epsilon, wd, rescale_grad, clip_gradient, lazy_update)

Update function for the Adam optimizer. Adam is seen as a generalization of AdaGrad.

Adam update consists of the following steps, where g represents gradient and m, v are 1st and 2nd order moment estimates (mean and variance).

.. math::

g_t = \nabla J(W_{t-1})\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon }

It updates the weights using::

m = beta1*m + (1-beta1)*grad
v = beta2*v + (1-beta2)*(grad**2)
w += - learning_rate * m / (sqrt(v) + epsilon)

However, if grad's storage type is $row_sparse$, $lazy_update$ is True and the storage type of weight is the same as those of m and v, only the row slices whose indices appear in grad.indices are updated (for w, m and v)::

for row in grad.indices:
    m[row] = beta1*m[row] + (1-beta1)*grad[row]
    v[row] = beta2*v[row] + (1-beta2)*(grad[row]**2)
    w[row] += - learning_rate * m[row] / (sqrt(v[row]) + epsilon)
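The dense and lazy update rules written above can be followed literally in a plain-Julia sketch. It applies no bias correction. Weight decay is left out because its exact placement is not described here. Rescaling is assumed to happen before clipping, and clipping is turned off when clip_gradient <= 0, as the argument list states. `adam_step!` and `lazy_adam_step!` are hypothetical helpers, not MXNet.jl functions. Rows in the lazy version are 1-based Julia row numbers.

```julia
# Dense Adam step following the documented update; arrays are modified in place.
function adam_step!(w, g, m, v; lr, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8,
                    rescale_grad = 1.0, clip_gradient = -1.0)
    gr = rescale_grad .* g
    if clip_gradient > 0                      # clipping is off when clip_gradient <= 0
        gr = clamp.(gr, -clip_gradient, clip_gradient)
    end
    @. m = beta1 * m + (1 - beta1) * gr
    @. v = beta2 * v + (1 - beta2) * gr^2
    @. w -= lr * m / (sqrt(v) + epsilon)
    return w
end

# Lazy variant: only the rows that appear in the sparse gradient are updated;
# `grad_rows[k, :]` is the gradient for row `rows[k]`.
function lazy_adam_step!(w, rows, grad_rows, m, v; kwargs...)
    for (k, r) in enumerate(rows)
        adam_step!(view(w, r, :), view(grad_rows, k, :), view(m, r, :), view(v, r, :); kwargs...)
    end
    return w
end
```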

Defined in src/operator/optimizer_op.cc:L687

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • mean::NDArray-or-SymbolicNode: Moving mean
  • var::NDArray-or-SymbolicNode: Moving variance
  • lr::float, required: Learning rate
  • beta1::float, optional, default=0.899999976: The decay rate for the 1st moment estimates.
  • beta2::float, optional, default=0.999000013: The decay rate for the 2nd moment estimates.
  • epsilon::float, optional, default=9.99999994e-09: A small constant for numerical stability.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • lazy_update::boolean, optional, default=1: If true, lazy updates are applied if the gradient's stype is row_sparse and all of w, m and v have the same stype.

source

# MXNet.mx.add_nMethod.

add_n(args)

Note: add_n takes a variable number of positional inputs. So instead of calling as add_n([x, y, z], num_args=3), one should call via add_n(x, y, z), and num_args will be determined automatically.

Adds all input arguments element-wise.

.. math:: add_n(a1, a2, ..., an) = a1 + a2 + ... + an

$add_n$ is potentially more efficient than calling $add$ n times.

The storage type of $add_n$ output depends on storage types of inputs:

  • add_n(row_sparse, row_sparse, ..) = row_sparse
  • add_n(default, csr, default) = default
  • add_n(any input combinations longer than 4 (>4) with at least one default type) = default
  • otherwise, $add_n$ falls back to default storage for all inputs and generates output with default storage
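One reason a single n-ary sum can beat repeated pairwise additions is that it fills one output buffer instead of allocating an intermediate array per call. The plain-Julia sketch below shows that accumulation pattern. `add_n_sketch` is hypothetical and says nothing about how the MXNet kernel is actually implemented.

```julia
# Hypothetical sketch: accumulate all inputs into a single output buffer.
function add_n_sketch(first::AbstractArray, rest::AbstractArray...)
    out = copy(first)                  # the only allocation
    for x in rest
        size(x) == size(out) || throw(DimensionMismatch("all inputs must have the same shape"))
        out .+= x                      # in-place accumulation, no temporaries
    end
    return out
end
```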

Defined in src/operator/tensor/elemwise_sum.cc:L155

Arguments

  • args::NDArray-or-SymbolicNode[]: Positional input arguments

source

# MXNet.mx.all_finiteMethod.

all_finite(data, init_output)

Checks whether all the floating-point numbers in the array are finite (used by AMP).

Defined in src/operator/contrib/all_finite.cc:L100

Arguments

  • data::NDArray: Array
  • init_output::boolean, optional, default=1: Initialize output to 1.

source

# MXNet.mx.amp_castMethod.

amp_cast(data, dtype)

Cast function between low precision float/FP32 used by AMP.

It casts only between low precision float/FP32 and does not do anything for other types.

Defined in src/operator/tensor/amp_cast.cc:L125

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, required: Output data type.

source

# MXNet.mx.amp_multicastMethod.

amp_multicast(data, num_outputs, cast_narrow)

Cast function used by AMP, that casts its inputs to the common widest type.

It casts only between low precision float/FP32 and does not do anything for other types.

Defined in src/operator/tensor/amp_cast.cc:L169

Arguments

  • data::NDArray-or-SymbolicNode[]: Weights
  • num_outputs::int, required: Number of input/output pairs to be cast to the widest type.
  • cast_narrow::boolean, optional, default=0: Whether to cast to the narrowest type

source

# MXNet.mx.argmax_channelMethod.

argmax_channel(data)

Returns argmax indices of each channel from the input array.

The result will be an NDArray of shape (num_channel,).

In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.

Examples::

x = [[ 0., 1., 2.], [ 3., 4., 5.]]

argmax_channel(x) = [ 2., 2.]
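Row-wise, this is an ordinary argmax shifted to 0-based indices. Julia's `argmax` also returns the first maximal position, which matches the tie rule stated above. `argmax_channel_sketch` is a hypothetical one-liner, not the MXNet.jl API.

```julia
# 0-based argmax of each row; ties resolve to the first occurrence.
argmax_channel_sketch(x::AbstractMatrix) = [argmax(view(x, i, :)) - 1 for i in 1:size(x, 1)]
```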

Defined in src/operator/tensor/broadcast_reduce_op_index.cc:L96

Arguments

  • data::NDArray-or-SymbolicNode: The input array

source

# MXNet.mx.argsortMethod.

argsort(data, axis, is_ascend, dtype)

Returns the indices that would sort an input array along the given axis.

This function performs sorting along the given axis and returns an array of indices, with the same shape as the input array, that index the data in sorted order.

Examples::

x = [[ 0.3, 0.2, 0.4], [ 0.1, 0.3, 0.2]]

// sort along axis -1
argsort(x) = [[ 1.,  0.,  2.],
              [ 0.,  2.,  1.]]

// sort along axis 0
argsort(x, axis=0) = [[ 1.,  0.,  1.],
                      [ 0.,  1.,  0.]]

// flatten and then sort
argsort(x, axis=None) = [ 3.,  1.,  5.,  0.,  4.,  2.]
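The three argsort examples can be reproduced for matrices with `sortperm` and `mapslices`. `sortperm` is documented to be stable, so tied values keep their original order, which is what the flattened example shows. `argsort_sketch` is hypothetical. It handles 2-D input only, returns 0-based indices as integers rather than floats, and flattens row by row to match the printed example.

```julia
# Hypothetical sketch of argsort for matrices; `axis` is 0-based, may be
# negative, or `nothing` to sort the row-major flattened array.
function argsort_sketch(x::AbstractMatrix, axis::Union{Int,Nothing} = -1; is_ascend::Bool = true)
    if axis === nothing
        flat = vec(permutedims(x))             # row-major flattening
        return sortperm(flat; rev = !is_ascend) .- 1
    end
    d = axis < 0 ? axis + 3 : axis + 1         # map axis to a Julia dimension (ndims == 2)
    return mapslices(s -> sortperm(s; rev = !is_ascend) .- 1, x; dims = d)
end
```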

Defined in src/operator/tensor/ordering_op.cc:L184

Arguments

  • data::NDArray-or-SymbolicNode: The input array
  • axis::int or None, optional, default='-1': Axis along which to sort the input tensor. If not given, the flattened array is used. Default is -1.
  • is_ascend::boolean, optional, default=1: Whether to sort in ascending or descending order.
  • dtype::{'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'},optional, default='float32': DType of the output indices. It is only valid when ret_typ is "indices" or "both". An error will be raised if the selected data type cannot precisely represent the indices.

source

# MXNet.mx.batch_dotMethod.

batch_dot(lhs, rhs, transpose_a, transpose_b, forward_stype)

Batchwise dot product.

$batch_dot$ is used to compute the dot product of $x$ and $y$ when $x$ and $y$ are batches of data, namely N-D (N >= 3) arrays of shape (B_0, ..., B_i, :, :).

For example, given $x$ with shape (B_0, ..., B_i, N, M) and $y$ with shape (B_0, ..., B_i, M, K), the result array will have shape (B_0, ..., B_i, N, K), which is computed by::

batch_dot(x,y)[b_0, ..., b_i, :, :] = dot(x[b_0, ..., b_i, :, :], y[b_0, ..., b_i, :, :])
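For a single leading batch axis, the defining formula is a loop of ordinary matrix products. `batch_dot_sketch` is a hypothetical plain-Julia helper. It assumes arrays are indexed as (batch, row, column) like the formula, and it handles 3-D input only.

```julia
# Hypothetical sketch of batch_dot for 3-D arrays indexed (batch, row, column).
function batch_dot_sketch(x::AbstractArray{<:Number,3}, y::AbstractArray{<:Number,3};
                          transpose_a::Bool = false, transpose_b::Bool = false)
    B = size(x, 1)
    size(y, 1) == B || throw(DimensionMismatch("batch sizes differ"))
    slices = map(1:B) do b
        a = transpose_a ? permutedims(x[b, :, :]) : x[b, :, :]
        c = transpose_b ? permutedims(y[b, :, :]) : y[b, :, :]
        a * c                                   # ordinary matrix product per batch entry
    end
    out = similar(first(slices), B, size(first(slices))...)
    for b in 1:B
        out[b, :, :] = slices[b]
    end
    return out
end
```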

Defined in src/operator/tensor/dot.cc:L127

Arguments

  • lhs::NDArray-or-SymbolicNode: The first input
  • rhs::NDArray-or-SymbolicNode: The second input
  • transpose_a::boolean, optional, default=0: If true then transpose the first input before dot.
  • transpose_b::boolean, optional, default=0: If true then transpose the second input before dot.
  • forward_stype::{None, 'csr', 'default', 'row_sparse'},optional, default='None': The desired storage type of the forward output given by the user. If the combination of input storage types and this hint does not match any implemented ones, the dot operator will perform a fallback operation and still produce an output of the desired storage type.

source

# MXNet.mx.batch_takeMethod.

batch_take(a, indices)

Takes elements from a data batch.

.. note:: batch_take is deprecated. Use pick instead.

Given an input array of shape $(d0, d1)$ and indices of shape $(i0,)$, the result will be an output array of shape $(i0,)$ with::

output[i] = input[i, indices[i]]

Examples::

x = [[ 1., 2.], [ 3., 4.], [ 5., 6.]]

// takes elements with specified indices
batch_take(x, [0,1,0]) = [ 1.  4.  5.]

Defined in src/operator/tensor/indexing_op.cc:L835

Arguments

  • a::NDArray-or-SymbolicNode: The input array
  • indices::NDArray-or-SymbolicNode: The index array

source

# MXNet.mx.broadcast_likeMethod.

broadcast_like(lhs, rhs, lhs_axes, rhs_axes)

Broadcasts lhs to have the same shape as rhs.

Broadcasting is a mechanism that allows NDArrays to perform arithmetic operations with arrays of different shapes efficiently without creating multiple copies of arrays. See also Broadcasting (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) for more explanation.

Broadcasting is allowed on axes with size 1, such as from (2,1,3,1) to (2,8,3,9). Elements will be duplicated on the broadcasted axes.

For example::

broadcast_like([[1,2,3]], [[5,6,7],[7,8,9]]) = [[ 1., 2., 3.], [ 1., 2., 3.]]

broadcast_like([9], [1,2,3,4,5], lhs_axes=(0,), rhs_axes=(-1,)) = [9,9,9,9,9]
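For inputs with the same number of dimensions, Julia's dot-broadcasting repeats size-1 axes in the same way, so assigning lhs into a buffer shaped like rhs gives the first example. `broadcast_like_sketch` is hypothetical and does not model lhs_axes or rhs_axes. Julia also aligns dimensions from the left rather than from the right, which differs from numpy when the ranks differ.

```julia
# Hypothetical sketch: output takes rhs's shape and lhs's element type.
function broadcast_like_sketch(lhs::AbstractArray, rhs::AbstractArray)
    out = similar(rhs, eltype(lhs))
    out .= lhs                      # size-1 axes of lhs are repeated to fill `out`
    return out
end
```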

Defined in src/operator/tensor/broadcast_reduce_op_value.cc:L178

Arguments

  • lhs::NDArray-or-SymbolicNode: First input.
  • rhs::NDArray-or-SymbolicNode: Second input.
  • lhs_axes::Shape or None, optional, default=None: Axes to perform broadcast on in the first input array
  • rhs_axes::Shape or None, optional, default=None: Axes to copy from the second input array

source

# MXNet.mx.broadcast_logical_andMethod.

broadcast_logical_and(lhs, rhs)

Returns the result of element-wise logical and with broadcasting.

Example::

x = [[ 1., 1., 1.], [ 1., 1., 1.]]

y = [[ 0.], [ 1.]]

broadcast_logical_and(x, y) = [[ 0., 0., 0.], [ 1., 1., 1.]]

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L153

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx.broadcast_logical_orMethod.

broadcast_logical_or(lhs, rhs)

Returns the result of element-wise logical or with broadcasting.

Example::

x = [[ 1., 1., 0.], [ 1., 1., 0.]]

y = [[ 1.], [ 0.]]

broadcast_logical_or(x, y) = [[ 1., 1., 1.], [ 1., 1., 0.]]

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L171

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx.broadcast_logical_xorMethod.

broadcast_logical_xor(lhs, rhs)

Returns the result of element-wise logical xor with broadcasting.

Example::

x = [[ 1., 1., 0.], [ 1., 1., 0.]]

y = [[ 1.], [ 0.]]

broadcast_logical_xor(x, y) = [[ 0., 0., 1.], [ 1., 1., 0.]]

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L189
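The three broadcast_logical_* operators in this section treat non-zero values as true, broadcast the operands, and return 0/1 floats. This can be written with Julia's dot-broadcasting as below. The function names are hypothetical plain-Julia helpers, not MXNet.jl functions. The shapes are chosen to match the examples, where a 2×1 column broadcasts across a 2×3 matrix.

```julia
# Hypothetical sketches: non-zero means true; results are 0.0/1.0 floats.
logical_and_sketch(x, y) = Float64.((x .!= 0) .& (y .!= 0))
logical_or_sketch(x, y)  = Float64.((x .!= 0) .| (y .!= 0))
logical_xor_sketch(x, y) = Float64.(xor.(x .!= 0, y .!= 0))
```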

Arguments

  • lhs::NDArray-or-SymbolicNode: First input to the function
  • rhs::NDArray-or-SymbolicNode: Second input to the function

source

# MXNet.mx.castMethod.

cast(data, dtype)

cast is an alias of Cast.

Casts all elements of the input to a new type.

.. note:: $Cast$ is deprecated. Use $cast$ instead.

Example::

cast([0.9, 1.3], dtype='int32') = [0, 1]
cast([1e20, 11.1], dtype='float16') = [inf, 11.09375]
cast([300, 11.1, 10.9, -1, -3], dtype='uint8') = [44, 11, 10, 255, 253]

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L664

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • dtype::{'bfloat16', 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, required: Output data type.

source

# MXNet.mx.cast_storageMethod.

cast_storage(data, stype)

Casts tensor storage type to the new type.

When an NDArray with default storage type is cast to csr or row_sparse storage, the result is compact, which means:

  • for csr, zero values will not be retained
  • for row_sparse, row slices of all zeros will not be retained

The storage type of $cast_storage$ output depends on stype parameter:

  • cast_storage(csr, 'default') = default
  • cast_storage(row_sparse, 'default') = default
  • cast_storage(default, 'csr') = csr
  • cast_storage(default, 'row_sparse') = row_sparse
  • cast_storage(csr, 'csr') = csr
  • cast_storage(row_sparse, 'row_sparse') = row_sparse

Example::

dense = [[ 0.,  1.,  0.],
         [ 2.,  0.,  3.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.]]

# cast to row_sparse storage type
rsp = cast_storage(dense, 'row_sparse')
rsp.indices = [0, 1]
rsp.values = [[ 0.,  1.,  0.],
              [ 2.,  0.,  3.]]

# cast to csr storage type
csr = cast_storage(dense, 'csr')
csr.indices = [1, 0, 2]
csr.values = [ 1.,  2.,  3.]
csr.indptr = [0, 1, 3, 3, 3]
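The compact encodings in the example can be derived by hand with a short plain-Julia sketch. CSR scans the rows and keeps only non-zero entries, with `indptr` marking where each row ends. Row-sparse keeps only rows that contain a non-zero. `csr_sketch` and `row_sparse_sketch` are hypothetical helpers that return named tuples with 0-based indices. They are not MXNet.jl storage types.

```julia
# Compact CSR encoding (0-based indices, like the printed example).
function csr_sketch(dense::AbstractMatrix)
    indptr, indices, values = [0], Int[], eltype(dense)[]
    for i in 1:size(dense, 1)
        for j in 1:size(dense, 2)
            if !iszero(dense[i, j])        # zero values are not retained
                push!(indices, j - 1)
                push!(values, dense[i, j])
            end
        end
        push!(indptr, length(values))      # end offset of row i
    end
    return (indptr = indptr, indices = indices, values = values)
end

# Row-sparse encoding: rows that are entirely zero are not retained.
function row_sparse_sketch(dense::AbstractMatrix)
    keep = [i for i in 1:size(dense, 1) if any(!iszero, view(dense, i, :))]
    return (indices = keep .- 1, values = dense[keep, :])
end
```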

Defined in src/operator/tensor/cast_storage.cc:L71

Arguments

  • data::NDArray-or-SymbolicNode: The input.
  • stype::{'csr', 'default', 'row_sparse'}, required: Output storage type.

source

# MXNet.mx.choose_element_0indexMethod.

choose_element_0index(data, index, axis, keepdims, mode)

choose_element_0index is an alias of pick.

Picks elements from an input array according to the input indices along the given axis.

Given an input array of shape $(d0, d1)$ and indices of shape $(i0,)$, the result will be an output array of shape $(i0,)$ with::

output[i] = input[i, indices[i]]

By default, if any index mentioned is too large, it is replaced by the index that addresses the last element along an axis (the clip mode).

This function supports n-dimensional input and (n-1)-dimensional indices arrays.

Examples::

x = [[ 1., 2.], [ 3., 4.], [ 5., 6.]]

// picks elements with specified indices along axis 0
pick(x, y=[0,1], 0) = [ 1., 4.]

// picks elements with specified indices along axis 1
pick(x, y=[0,1,0], 1) = [ 1., 4., 5.]

// picks elements with specified indices along axis 1 using 'wrap' mode
// to place indices that would normally be out of bounds
pick(x, y=[2,-1,-2], 1, mode='wrap') = [ 1., 4., 5.]

y = [[ 1.], [ 0.], [ 2.]]

// picks elements with specified indices along axis 1 and dims are maintained
pick(x, y, 1, keepdims=True) = [[ 2.], [ 3.], [ 6.]]
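The clip and wrap rules can be checked with a plain-Julia sketch for 2-D input. `pick_sketch` is hypothetical. It supports axis 0 or 1 only and does not implement keepdims. Julia's `mod` returns a result with the sign of the divisor, which gives the wrap behaviour for negative indices such as -1 and -2.

```julia
# Hypothetical sketch of pick for matrices; indices are 0-based.
function pick_sketch(x::AbstractMatrix, index::AbstractVector{<:Integer}, axis::Int;
                     mode::Symbol = :clip)
    d = axis + 1                                   # 0-based axis -> Julia dimension
    n = size(x, d)                                 # length of the axis picked from
    fix(i) = mode === :wrap ? mod(i, n) : clamp(i, 0, n - 1)
    if d == 2
        return [x[r, fix(index[r]) + 1] for r in 1:size(x, 1)]   # one pick per row
    else
        return [x[fix(index[c]) + 1, c] for c in 1:size(x, 2)]   # one pick per column
    end
end
```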

Defined in src/operator/tensor/broadcast_reduce_op_index.cc:L150

Arguments

  • data::NDArray-or-SymbolicNode: The input array
  • index::NDArray-or-SymbolicNode: The index array
  • axis::int or None, optional, default='-1': int or None. The axis along which to pick the elements. Negative values mean indexing from right to left. If None, the elements in the index w.r.t. the flattened input will be picked.
  • keepdims::boolean, optional, default=0: If true, the axis where we pick the elements is left in the result as dimension with size one.
  • mode::{'clip', 'wrap'},optional, default='clip': Specify how out-of-bound indices behave. Default is "clip". "clip" means clip to the range. So, if all indices mentioned are too large, they are replaced by the index that addresses the last element along an axis. "wrap" means to wrap around.

source

# MXNet.mx.col2imMethod.

col2im(data, output_size, kernel, stride, dilate, pad)

Combines the output column matrix of im2col back into an image array.

Like im2col, this operator is also used in the vanilla convolution implementation. Despite the name, col2im is not the reverse operation of im2col. Since there may be overlaps between neighbouring sliding blocks, the column elements cannot be directly put back into the image. Instead, they are accumulated (i.e., summed) in the input image just like the gradient computation, so col2im is the gradient of im2col and vice versa.

Using the notation in im2col, given an input column array of shape (N, C × ∏(kernel), W), this operator accumulates the column elements into an output array of shape (N, C, output_size[0], output_size[1], ...). Only 1-D, 2-D, and 3-D spatial dimensions are supported by this operator.
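The accumulation, and the claim that col2im and im2col are each other's gradient, can be seen in a 1-D, single-channel, single-image sketch with no dilation or padding. `im2col_1d` and `col2im_1d` are hypothetical plain-Julia helpers, not the MXNet operators. With overlapping blocks, col2im(im2col(x)) multiplies each element by the number of blocks covering it. The adjoint identity ⟨im2col(x), C⟩ = ⟨x, col2im(C)⟩ holds for any column array C.

```julia
# cols[k, l] holds kernel position k (1-based) of sliding block l.
function im2col_1d(x::AbstractVector, kernel::Int, stride::Int = 1)
    nblocks = (length(x) - kernel) ÷ stride + 1
    return [x[(l - 1) * stride + k] for k in 1:kernel, l in 1:nblocks]
end

function col2im_1d(cols::AbstractMatrix, output_size::Int, stride::Int = 1)
    kernel, nblocks = size(cols)
    out = zeros(eltype(cols), output_size)
    for l in 1:nblocks, k in 1:kernel
        out[(l - 1) * stride + k] += cols[k, l]   # overlapping positions are summed
    end
    return out
end
```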

Defined in src/operator/nn/im2col.cc:L181

Arguments

  • data::NDArray-or-SymbolicNode: Input array to combine sliding blocks.
  • output_size::Shape(tuple), required: The spatial dimension of image array: (w,), (h, w) or (d, h, w).
  • kernel::Shape(tuple), required: Sliding kernel size: (w,), (h, w) or (d, h, w).
  • stride::Shape(tuple), optional, default=[]: The stride between adjacent sliding blocks in spatial dimension: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • dilate::Shape(tuple), optional, default=[]: The spacing between adjacent kernel points: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • pad::Shape(tuple), optional, default=[]: The zero-value padding size on both sides of spatial dimension: (w,), (h, w) or (d, h, w). Defaults to no padding.

source

# MXNet.mx.concatMethod.

concat(data, num_args, dim)

concat is an alias of Concat.

Note: concat takes a variable number of positional inputs. So instead of calling as concat([x, y, z], num_args=3), one should call via concat(x, y, z), and num_args will be determined automatically.

Joins input arrays along a given axis.

.. note:: Concat is deprecated. Use concat instead.

The dimensions of the input arrays should be the same except the axis along which they will be concatenated. The dimension of the output array along the concatenated axis will be equal to the sum of the corresponding dimensions of the input arrays.

The storage type of $concat$ output depends on storage types of inputs:

  • concat(csr, csr, ..., csr, dim=0) = csr
  • otherwise, $concat$ generates output with default storage

Example::

x = [[1,1],[2,2]] y = [[3,3],[4,4],[5,5]] z = [[6,6], [7,7],[8,8]]

concat(x,y,z,dim=0) = [[ 1., 1.], [ 2., 2.], [ 3., 3.], [ 4., 4.], [ 5., 5.], [ 6., 6.], [ 7., 7.], [ 8., 8.]]

Note that you cannot concat x,y,z along dimension 1 since dimension 0 is not the same for all the input arrays.

concat(y,z,dim=1) = [[ 3., 3., 6., 6.], [ 4., 4., 7., 7.], [ 5., 5., 8., 8.]]
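Julia's `Base.cat` follows the same shape rule. A one-line sketch that maps the 0-based `dim` of the docs to Julia's 1-based `dims` reproduces both examples and the failing case. `concat_sketch` is a hypothetical name, not the MXNet.jl API.

```julia
# Hypothetical sketch: 0-based `dim` (default 1, as documented) -> Julia `dims`.
concat_sketch(xs::AbstractArray...; dim::Int = 1) = cat(xs...; dims = dim + 1)
```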

Defined in src/operator/nn/concat.cc:L384

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to concatenate
  • num_args::int, required: Number of inputs to be concatenated.
  • dim::int, optional, default='1': The dimension along which to concatenate.

source

# MXNet.mx.cropMethod.

crop(data, begin, end, step)

crop is an alias of slice.

Slices a region of the array.

.. note:: $crop$ is deprecated. Use $slice$ instead.

This function returns a sliced array between the indices given by begin and end with the corresponding step.

For an input array of $shape=(d_0, d_1, ..., d_{n-1})$, a slice operation with $begin=(b_0, b_1, ..., b_{m-1})$, $end=(e_0, e_1, ..., e_{m-1})$, and $step=(s_0, s_1, ..., s_{m-1})$, where m <= n, results in an array with the shape $(ceil(|e_0-b_0|/|s_0|), ..., ceil(|e_{m-1}-b_{m-1}|/|s_{m-1}|), d_m, ..., d_{n-1})$.

The resulting array's k-th dimension contains elements from the k-th dimension of the input array starting from index $b_k$ (inclusive) with step $s_k$ until reaching $e_k$ (exclusive).

If the k-th elements are None in the sequence of begin, end, and step, the following rule is used to set default values: if s_k is None, set s_k=1; then, if s_k > 0, set b_k=0 and e_k=d_k; otherwise, set b_k=d_k-1 and e_k=-1.

The storage type of $slice$ output depends on storage types of inputs:

  • slice(csr) = csr
  • otherwise, $slice$ generates output with default storage

.. note:: When the input data storage type is csr, only step=(), step=(None,), or step=(1,) generate a csr output. For other step values, it falls back to slicing a dense tensor.

Example::

x = [[  1.,   2.,   3.,   4.],
     [  5.,   6.,   7.,   8.],
     [  9.,  10.,  11.,  12.]]

slice(x, begin=(0,1), end=(2,4)) = [[ 2.,  3.,  4.],
                                    [ 6.,  7.,  8.]]

slice(x, begin=(None, 0), end=(None, 3), step=(-1, 2)) = [[ 9., 11.],
                                                         [ 5.,  7.],
                                                         [ 1.,  3.]]

Defined in src/operator/tensor/matrix_op.cc:L481

Arguments

  • data::NDArray-or-SymbolicNode: Source input
  • begin::Shape(tuple), required: starting indices for the slice operation, supports negative indices.
  • end::Shape(tuple), required: ending indices for the slice operation, supports negative indices.
  • step::Shape(tuple), optional, default=[]: step for the slice operation, supports negative values.

source

# MXNet.mx.ctc_lossMethod.

ctc_loss(data, label, data_lengths, label_lengths, use_data_lengths, use_label_lengths, blank_label)

ctc_loss is an alias of CTCLoss.

Connectionist Temporal Classification Loss.

.. note:: The existing alias $contrib_CTCLoss$ is deprecated.

The shapes of the inputs and outputs:

  • data: (sequence_length, batch_size, alphabet_size)
  • label: (batch_size, label_sequence_length)
  • out: (batch_size)

The data tensor consists of sequences of activation vectors (without applying softmax), with the i-th channel in the last dimension corresponding to the i-th label for i between 0 and alphabet_size-1 (i.e., always 0-indexed). The alphabet size should include one additional value reserved for the blank label. When blank_label is "first", the 0-th channel is reserved for the activation of the blank label; if it is "last", the (alphabet_size-1)-th channel is reserved for the blank label.

$label$ is an index matrix of integers. When blank_label is $"first"$, the value 0 is then reserved for blank label, and should not be passed in this matrix. Otherwise, when blank_label is $"last"$, the value (alphabet_size-1) is reserved for blank label.

If a sequence of labels is shorter than label_sequence_length, pad the end of the sequence with the special padding value so that it has the correct length. The padding value is 0 when blank_label is $"first"$, and -1 otherwise.

For example, suppose the vocabulary is [a, b, c], and in one batch we have three sequences 'ba', 'cbb', and 'abac'. When blank_label is $"first"$, we can index the labels as {'a': 1, 'b': 2, 'c': 3}, and we reserve the 0-th channel for blank label in data tensor. The resulting label tensor should be padded to be::

[[2, 1, 0, 0], [3, 2, 2, 0], [1, 2, 1, 3]]

When blank_label is $"last"$, we can index the labels as {'a': 0, 'b': 1, 'c': 2}, and we reserve the channel index 3 for blank label in data tensor. The resulting label tensor should be padded to be::

[[1, 0, -1, -1], [2, 1, 1, -1], [0, 1, 0, 2]]
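A minimal plain-Julia sketch of the label indexing and padding convention described above. `pad_ctc_labels` is a hypothetical helper, not part of MXNet.jl; it lays out one sequence per row to match the example. With three tokens, the data tensor would need an alphabet size of 4 (three tokens plus blank).

```julia
# Hypothetical helper (not part of MXNet.jl): build the padded label matrix
# described above. Each row is one sequence, as in the doc's example.
function pad_ctc_labels(seqs::Vector{String}, vocab::Vector{Char};
                        blank_label::Symbol = :first)
    if blank_label === :first
        offset, pad = 1, 0     # channel 0 is blank; tokens start at 1; pad with 0
    elseif blank_label === :last
        offset, pad = 0, -1    # last channel is blank; tokens start at 0; pad with -1
    else
        throw(ArgumentError("blank_label must be :first or :last"))
    end
    index = Dict(c => i - 1 + offset for (i, c) in enumerate(vocab))
    labels = fill(pad, length(seqs), maximum(length, seqs))
    for (r, s) in enumerate(seqs), (j, c) in enumerate(s)
        labels[r, j] = index[c]
    end
    return labels
end

seqs, vocab = ["ba", "cbb", "abac"], ['a', 'b', 'c']
pad_ctc_labels(seqs, vocab)                       # [2 1 0 0; 3 2 2 0; 1 2 1 3]
pad_ctc_labels(seqs, vocab; blank_label = :last)  # [1 0 -1 -1; 2 1 1 -1; 0 1 0 2]
```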

$out$ is a list of CTC loss values, one per example in the batch.

See Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks, A. Graves et al. for more information on the definition and the algorithm.

Defined in src/operator/nn/ctc_loss.cc:L100

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • label::NDArray-or-SymbolicNode: Ground-truth labels for the loss.
  • data_lengths::NDArray-or-SymbolicNode: Lengths of data for each of the samples. Only required when use_data_lengths is true.
  • label_lengths::NDArray-or-SymbolicNode: Lengths of labels for each of the samples. Only required when use_label_lengths is true.
  • use_data_lengths::boolean, optional, default=0: Whether the data lengths are decided by data_lengths. If false, the lengths are equal to the max sequence length.
  • use_label_lengths::boolean, optional, default=0: Whether the label lengths are decided by label_lengths, or derived from padding_mask. If false, the lengths are derived from the first occurrence of the value of padding_mask. The value of padding_mask is $0$ when the first CTC label is reserved for blank, and $-1$ when the last label is reserved for blank. See blank_label.
  • blank_label::{'first', 'last'}, optional, default='first': Set the label that is reserved for the blank label. If "first", the 0-th label is reserved, label values for tokens in the vocabulary are between $1$ and $alphabet_size-1$, and the padding mask is $0$. If "last", the last label value $alphabet_size-1$ is reserved for the blank label instead, label values for tokens in the vocabulary are between $0$ and $alphabet_size-2$, and the padding mask is $-1$.

source

# MXNet.mx.degreesMethod.

degrees(data)

Converts each element of the input array from radians to degrees.

.. math:: degrees([0, \pi/2, \pi, 3\pi/2, 2\pi]) = [0, 90, 180, 270, 360]

The storage type of $degrees$ output depends upon the input storage type:

  • degrees(default) = default
  • degrees(row_sparse) = row_sparse
  • degrees(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L332

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.depth_to_spaceMethod.

depth_to_space(data, block_size)

Rearranges (permutes) data from depth into blocks of spatial data. Similar to the ONNX DepthToSpace operator: https://github.com/onnx/onnx/blob/master/docs/Operators.md#DepthToSpace. The output is a new tensor where the values from the depth dimension are moved in spatial blocks to the height and width dimensions. The reverse of this operation is $space_to_depth$.

.. math::

\begin{gather*}
x' = reshape(x, [N, block\_size, block\_size, C / (block\_size ^ 2), H, W]) \\
x'' = transpose(x', [0, 3, 4, 1, 5, 2]) \\
y = reshape(x'', [N, C / (block\_size ^ 2), H * block\_size, W * block\_size])
\end{gather*}

where x is an input tensor with default layout [N, C, H, W]: [batch, channels, height, width], and y is the output tensor of layout [N, C / (block_size ^ 2), H * block_size, W * block_size].

Example::

x = [[[[0, 1, 2],
       [3, 4, 5]],
      [[6, 7, 8],
       [9, 10, 11]],
      [[12, 13, 14],
       [15, 16, 17]],
      [[18, 19, 20],
       [21, 22, 23]]]]

depth_to_space(x, 2) = [[[[0, 6, 1, 7, 2, 8],
                          [12, 18, 13, 19, 14, 20],
                          [3, 9, 4, 10, 5, 11],
                          [15, 21, 16, 22, 17, 23]]]]

Defined in src/operator/tensor/matrix_op.cc:L971

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • block_size::int, required: Blocks of [block_size, block_size] are moved
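The rearrangement in the example follows the index mapping y[n, c, h*bs + i, w*bs + j] = x[n, (i*bs + j)*(C/bs^2) + c, h, w] (0-based, bs = block_size), which matches ONNX DepthToSpace in its default DCR mode. The plain-Julia reference sketch below, `depth_to_space_ref`, is a hypothetical helper on an ordinary 4-D array indexed [n, c, h, w] to mirror the NCHW notation; it says nothing about how MXNet.jl stores dimensions.

```julia
# Hypothetical reference sketch of depth_to_space on a plain 4-D Julia array
# indexed [n, c, h, w] (the doc's NCHW notation), using 0-based math inside:
#   y[n, c, h*bs + i, w*bs + j] = x[n, (i*bs + j)*Cout + c, h, w]
function depth_to_space_ref(x::AbstractArray{T,4}, bs::Int) where {T}
    N, C, H, W = size(x)
    C % bs^2 == 0 || throw(ArgumentError("channels must be divisible by block_size^2"))
    Cout = C ÷ bs^2
    y = similar(x, N, Cout, H * bs, W * bs)
    for n in 1:N, c in 0:Cout-1, h in 0:H-1, w in 0:W-1, i in 0:bs-1, j in 0:bs-1
        y[n, c+1, h*bs+i+1, w*bs+j+1] = x[n, (i*bs+j)*Cout+c+1, h+1, w+1]
    end
    return y
end

# The example input: x[0, c, h, w] = 6c + 3h + w (0-based), shape (1, 4, 2, 3).
x = [6c + 3h + w for n in 1:1, c in 0:3, h in 0:1, w in 0:2]
depth_to_space_ref(x, 2)[1, 1, :, :]
# [0 6 1 7 2 8; 12 18 13 19 14 20; 3 9 4 10 5 11; 15 21 16 22 17 23]
```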

source

# MXNet.mx.elemwise_addMethod.

elemwise_add(lhs, rhs)

Adds arguments element-wise.

The storage type of $elemwise_add$ output depends on the storage types of the inputs:

  • elemwise_add(row_sparse, row_sparse) = row_sparse
  • elemwise_add(csr, csr) = csr
  • elemwise_add(default, csr) = default
  • elemwise_add(csr, default) = default
  • elemwise_add(default, rsp) = default
  • elemwise_add(rsp, default) = default
  • otherwise, $elemwise_add$ generates output with default storage

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx.elemwise_divMethod.

elemwise_div(lhs, rhs)

Divides arguments element-wise.

The storage type of $elemwise_div$ output is always dense

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx.elemwise_mulMethod.

elemwise_mul(lhs, rhs)

Multiplies arguments element-wise.

The storage type of $elemwise_mul$ output depends on the storage types of the inputs:

  • elemwise_mul(default, default) = default
  • elemwise_mul(row_sparse, row_sparse) = row_sparse
  • elemwise_mul(default, row_sparse) = row_sparse
  • elemwise_mul(row_sparse, default) = row_sparse
  • elemwise_mul(csr, csr) = csr
  • otherwise, $elemwise_mul$ generates output with default storage

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx.elemwise_subMethod.

elemwise_sub(lhs, rhs)

Subtracts arguments element-wise.

The storage type of $elemwise_sub$ output depends on the storage types of the inputs:

  • elemwise_sub(row_sparse, row_sparse) = row_sparse
  • elemwise_sub(csr, csr) = csr
  • elemwise_sub(default, csr) = default
  • elemwise_sub(csr, default) = default
  • elemwise_sub(default, rsp) = default
  • elemwise_sub(rsp, default) = default
  • otherwise, $elemwise_sub$ generates output with default storage

Arguments

  • lhs::NDArray-or-SymbolicNode: first input
  • rhs::NDArray-or-SymbolicNode: second input

source

# MXNet.mx.erfMethod.

erf(data)

Returns the element-wise Gauss error function of the input.

Example::

erf([0, -1., 10.]) = [0., -0.8427, 1.]

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L886

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.erfinvMethod.

erfinv(data)

Returns the element-wise inverse Gauss error function of the input.

Example::

erfinv([0, 0.5, -1.]) = [0., 0.4769, -inf]

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L908

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.fill_element_0indexMethod.

fill_element_0index(lhs, mhs, rhs)

Fills one element of each line (row for Python, column for R/Julia) in lhs, at the index given by rhs, with the value given by mhs. This function assumes rhs uses 0-based indices.

Arguments

  • lhs::NDArray: Left operand to the function.
  • mhs::NDArray: Middle operand to the function.
  • rhs::NDArray: Right operand to the function.
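A minimal sketch of the column-wise (R/Julia) reading of this description on plain Julia arrays. `fill_element_0index_ref` is hypothetical, and it assumes the operator returns a filled copy rather than modifying lhs; rhs stays 0-based, so 1 is added before indexing.

```julia
# Hypothetical reference sketch: for each column j, write mhs[j] into lhs at
# the 0-based row index rhs[j]. Works on a copy, leaving lhs untouched.
function fill_element_0index_ref(lhs::AbstractMatrix, mhs::AbstractVector,
                                 rhs::AbstractVector{<:Integer})
    size(lhs, 2) == length(mhs) == length(rhs) ||
        throw(DimensionMismatch("need one value and one index per column"))
    out = copy(lhs)
    for j in axes(out, 2)
        out[rhs[j]+1, j] = mhs[j]   # +1 converts the 0-based index
    end
    return out
end

fill_element_0index_ref(zeros(3, 2), [5.0, 7.0], [2, 0])   # [0.0 7.0; 0.0 0.0; 5.0 0.0]
```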

source

# MXNet.mx.fixMethod.

fix(data)

Rounds each element of the input to the nearest integer towards zero.

Example::

fix([-2.1, -1.9, 1.9, 2.1]) = [-2., -1., 1., 2.]

The storage type of $fix$ output depends upon the input storage type:

  • fix(default) = default
  • fix(row_sparse) = row_sparse
  • fix(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L874

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.flattenMethod.

flatten(data)

flatten is an alias of Flatten.

Flattens the input array into a 2-D array by collapsing the higher dimensions.

.. note:: Flatten is deprecated. Use flatten instead.

For an input array with shape $(d1, d2, ..., dk)$, the flatten operation reshapes the input array into an output array of shape $(d1, d2*...*dk)$.

Note that the behavior of this function is different from numpy.ndarray.flatten, which behaves similarly to mxnet.ndarray.reshape((-1,)).

Example::

x = [[[1,2,3], [4,5,6], [7,8,9]],
     [[1,2,3], [4,5,6], [7,8,9]]]

flatten(x) = [[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
              [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.]]

Defined in src/operator/tensor/matrix_op.cc:L249

Arguments

  • data::NDArray-or-SymbolicNode: Input array.

source

# MXNet.mx.flipMethod.

flip(data, axis)

flip is an alias of reverse.

Reverses the order of elements along the given axis while preserving the array shape.

Note: reverse and flip are equivalent. We use reverse in the following examples.

Examples::

x = [[ 0.,  1.,  2.,  3.,  4.],
     [ 5.,  6.,  7.,  8.,  9.]]

reverse(x, axis=0) = [[ 5.,  6.,  7.,  8.,  9.],
                      [ 0.,  1.,  2.,  3.,  4.]]

reverse(x, axis=1) = [[ 4.,  3.,  2.,  1.,  0.],
                      [ 9.,  8.,  7.,  6.,  5.]]

Defined in src/operator/tensor/matrix_op.cc:L831

Arguments

  • data::NDArray-or-SymbolicNode: Input data array
  • axis::Shape(tuple), required: The axis along which to reverse elements.
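On ordinary Julia arrays, Base's `reverse(A; dims)` performs the same reversal. The sketch writes the example matrix row by row, so the doc's axis=0 corresponds to `dims = 1` here; this is a comparison with Base Julia, not a statement about how MXNet.jl maps `axis`.

```julia
x = [0.0 1 2 3 4; 5 6 7 8 9]   # the example matrix, written row by row
reverse(x; dims = 1)           # rows swapped:      [5 6 7 8 9; 0 1 2 3 4]
reverse(x; dims = 2)           # columns reversed:  [4 3 2 1 0; 9 8 7 6 5]
```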

source

# MXNet.mx.ftml_updateMethod.

ftml_update(weight, grad, d, v, z, lr, beta1, beta2, epsilon, t, wd, rescale_grad, clip_grad)

The FTML optimizer described in FTML - Follow the Moving Leader in Deep Learning, available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.

.. math::

g_t = \nabla J(W_{t-1})\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
d_t = \frac{ 1 - \beta_1^t }{ \eta_t } \left(\sqrt{ \frac{ v_t }{ 1 - \beta_2^t } } + \epsilon\right)\\
\sigma_t = d_t - \beta_1 d_{t-1}\\
z_t = \beta_1 z_{t-1} + (1 - \beta_1^t) g_t - \sigma_t W_{t-1}\\
W_t = - \frac{ z_t }{ d_t }

Defined in src/operator/optimizer_op.cc:L639

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • d::NDArray-or-SymbolicNode: Internal state $d_t$
  • v::NDArray-or-SymbolicNode: Internal state $v_t$
  • z::NDArray-or-SymbolicNode: Internal state $z_t$
  • lr::float, required: Learning rate.
  • beta1::float, optional, default=0.600000024: Generally close to 0.5.
  • beta2::float, optional, default=0.999000013: Generally close to 1.
  • epsilon::double, optional, default=9.9999999392252903e-09: Epsilon to prevent division by zero.
  • t::int, required: Number of updates.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_grad::float, optional, default=-1: Clip gradient to the range of [-clip_grad, clip_grad]. If clip_grad <= 0, gradient clipping is turned off. grad = max(min(grad, clip_grad), -clip_grad).

source

# MXNet.mx.ftrl_updateMethod.

ftrl_update(weight, grad, z, n, lr, lamda1, beta, wd, rescale_grad, clip_gradient)

Update function for the Ftrl optimizer. Based on Ad Click Prediction: a View from the Trenches, available at http://dl.acm.org/citation.cfm?id=2488200.

It updates the weights using::

rescaled_grad = clip(grad * rescale_grad, clip_gradient)
z += rescaled_grad - (sqrt(n + rescaled_grad**2) - sqrt(n)) * weight / learning_rate
n += rescaled_grad**2
w = (sign(z) * lamda1 - z) / ((beta + sqrt(n)) / learning_rate + wd) * (abs(z) > lamda1)

If w, z and n are all of $row_sparse$ storage type, only the row slices whose indices appear in grad.indices are updated (for w, z and n)::

for row in grad.indices:
    rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient)
    z[row] += rescaled_grad[row] - (sqrt(n[row] + rescaled_grad[row]**2) - sqrt(n[row])) * weight[row] / learning_rate
    n[row] += rescaled_grad[row]**2
    w[row] = (sign(z[row]) * lamda1 - z[row]) / ((beta + sqrt(n[row])) / learning_rate + wd) * (abs(z[row]) > lamda1)

Defined in src/operator/optimizer_op.cc:L875

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • z::NDArray-or-SymbolicNode: z
  • n::NDArray-or-SymbolicNode: Square of grad
  • lr::float, required: Learning rate
  • lamda1::float, optional, default=0.00999999978: The L1 regularization coefficient.
  • beta::float, optional, default=1: Per-Coordinate Learning Rate beta.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
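A dense, single-step sketch of the update rule above on plain Julia arrays. `ftrl_step!` is a hypothetical helper, not the MXNet.jl operator; it treats clip_gradient <= 0 as "no clipping" per the argument description and does not cover the row_sparse path. Statement order matters: z is updated with the old n and old weight before n changes.

```julia
# Hypothetical reference sketch: one dense FTRL step on plain Julia arrays,
# following the update rule above. Updates w, z and n in place.
function ftrl_step!(w, z, n, grad; lr, lamda1 = 0.01, beta = 1.0, wd = 0.0,
                    rescale_grad = 1.0, clip_gradient = -1.0)
    g = grad .* rescale_grad
    if clip_gradient > 0                         # <= 0 disables clipping
        g = clamp.(g, -clip_gradient, clip_gradient)
    end
    @. z += g - (sqrt(n + g^2) - sqrt(n)) * w / lr   # uses the old n and w
    @. n += g^2
    @. w = (sign(z) * lamda1 - z) / ((beta + sqrt(n)) / lr + wd) * (abs(z) > lamda1)
    return w
end

w, z, n = [0.0], [0.0], [0.0]
ftrl_step!(w, z, n, [1.0]; lr = 0.1)   # z = [1.0], n = [1.0], w ≈ [-0.0495]
```

The final factor (abs(z) > lamda1) is what produces exact zeros: any coordinate whose accumulated z stays within lamda1 of zero gets weight 0.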

source

# MXNet.mx.gammaMethod.

gamma(data)

Returns the gamma function (extension of the factorial function to the reals), computed element-wise on the input array.

The storage type of $gamma$ output is always dense

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.gammalnMethod.

gammaln(data)

Returns element-wise log of the absolute value of the gamma function of the input.

The storage type of $gammaln$ output is always dense

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.gather_ndMethod.

gather_nd(data, indices)

Gather elements or slices from data and store to a tensor whose shape is defined by indices.

Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with shape (M, Y_0, ..., Y_{K-1}), the output will have shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N, output shape will simply be (Y_0, ..., Y_{K-1}).

The elements in the output are defined as follows::

output[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] = data[indices[0, y_0, ..., y_{K-1}], ..., indices[M-1, y_0, ..., y_{K-1}], x_M, ..., x_{N-1}]

Examples::

data = [[0, 1], [2, 3]] indices = [[1, 1, 0], [0, 1, 0]] gather_nd(data, indices) = [2, 3, 0]

data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] indices = [[0, 1], [1, 0]] gather_nd(data, indices) = [[3, 4], [5, 6]]

Arguments

  • data::NDArray-or-SymbolicNode: data
  • indices::NDArray-or-SymbolicNode: indices
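The output formula can be transcribed directly into a plain-Julia reference sketch. `gather_nd_ref` is a hypothetical helper, not the MXNet.jl call: index values stay 0-based as in the examples (1 is added before indexing), and the arrays are written row by row to match the nested-list examples.

```julia
# Hypothetical reference sketch of gather_nd on plain Julia arrays.
# indices has shape (M, Y...); index values are 0-based as in the docs.
function gather_nd_ref(data::AbstractArray, indices::AbstractArray{<:Integer})
    M = size(indices, 1)
    M <= ndims(data) || throw(ArgumentError("size(indices, 1) must be <= ndims(data)"))
    ys = CartesianIndices(Base.tail(size(indices)))   # positions y_0, ..., y_{K-1}
    xs = CartesianIndices(size(data)[M+1:end])        # trailing x_M, ..., x_{N-1}
    out = Array{eltype(data)}(undef, (size(ys)..., size(xs)...))
    for y in ys, x in xs
        lead = ntuple(m -> indices[m, Tuple(y)...] + 1, M)   # 0-based -> 1-based
        out[Tuple(y)..., Tuple(x)...] = data[lead..., Tuple(x)...]
    end
    return out
end

gather_nd_ref([0 1; 2 3], [1 1 0; 0 1 0])        # [2, 3, 0]
data = [4i + 2j + k + 1 for i in 0:1, j in 0:1, k in 0:1]   # [[[1,2],[3,4]],[[5,6],[7,8]]]
gather_nd_ref(data, [0 1; 1 0])                  # [3 4; 5 6]
```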

source

# MXNet.mx.hard_sigmoidMethod.

hard_sigmoid(data, alpha, beta)

Computes hard sigmoid of x element-wise.

.. math:: y = max(0, min(1, alpha * x + beta))

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L161

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
  • alpha::float, optional, default=0.200000003: Slope of hard sigmoid
  • beta::float, optional, default=0.5: Bias of hard sigmoid.
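The formula maps to a one-line broadcast in plain Julia. `hard_sigmoid_ref` is a hypothetical helper using the default alpha and beta above; with those defaults the output saturates at 0 for x <= -2.5 and at 1 for x >= 2.5.

```julia
# Hypothetical reference sketch: y = max(0, min(1, alpha * x + beta)), elementwise.
hard_sigmoid_ref(x; alpha = 0.2, beta = 0.5) = clamp.(alpha .* x .+ beta, 0, 1)

hard_sigmoid_ref([-5.0, 0.0, 1.0, 5.0])   # ≈ [0.0, 0.5, 0.7, 1.0]
```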

source

# MXNet.mx.im2colMethod.

im2col(data, kernel, stride, dilate, pad)

Extract sliding blocks from input array.

This operator is used in the vanilla convolution implementation to transform the sliding blocks of an image into a column matrix, so that the convolution can be computed as a matrix multiplication between the column matrix and the convolution weight. Because of the close relation between im2col and convolution, the concepts of kernel, stride, dilate, and pad in this operator are inherited from the convolution operation.

Given input data of shape (N, C, *), where N is the batch size, C is the channel size, and * is the arbitrary spatial dimension, the output column array always has shape (N, C * prod(kernel), W), where C * prod(kernel) is the block size and W is the block number, which is the spatial size of the convolution output with the same input parameters. Only 1-D, 2-D, and 3-D spatial dimensions are supported by this operator.

Defined in src/operator/nn/im2col.cc:L99

Arguments

  • data::NDArray-or-SymbolicNode: Input array to extract sliding blocks.
  • kernel::Shape(tuple), required: Sliding kernel size: (w,), (h, w) or (d, h, w).
  • stride::Shape(tuple), optional, default=[]: The stride between adjacent sliding blocks in spatial dimension: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • dilate::Shape(tuple), optional, default=[]: The spacing between adjacent kernel points: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
  • pad::Shape(tuple), optional, default=[]: The zero-value padding size on both sides of spatial dimension: (w,), (h, w) or (d, h, w). Defaults to no padding.
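The text defines W as the spatial size of the matching convolution output. Assuming the standard convolution output-size formula, the output shape can be computed with the hypothetical helper below (plain Julia, shapes written as (N, C, spatial...) as in the description; not an MXNet.jl function).

```julia
# Hypothetical helper: output shape of im2col for an input shaped
# (N, C, spatial...), using the usual convolution output-size formula
#   out = (in + 2*pad - dilate*(kernel - 1) - 1) ÷ stride + 1   (per spatial dim)
function im2col_shape(in_shape::Tuple, kernel::Tuple;
                      stride = map(_ -> 1, kernel),
                      dilate = map(_ -> 1, kernel),
                      pad = map(_ -> 0, kernel))
    N, C = in_shape[1], in_shape[2]
    spatial = in_shape[3:end]
    length(spatial) == length(kernel) ||
        throw(ArgumentError("kernel must have one entry per spatial dimension"))
    out_sp = map(spatial, kernel, stride, dilate, pad) do x, k, s, d, p
        (x + 2p - d * (k - 1) - 1) ÷ s + 1
    end
    return (N, C * prod(kernel), prod(out_sp))   # (N, block size, block count W)
end

im2col_shape((1, 3, 5, 5), (3, 3))                                 # (1, 27, 9)
im2col_shape((1, 3, 5, 5), (3, 3); stride = (2, 2), pad = (1, 1))  # (1, 27, 9)
```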

source

# MXNet.mx.khatri_raoMethod.

khatri_rao(args)

Note: khatri_rao takes a variable number of positional inputs. So instead of calling it as khatri_rao([x, y, z], num_args=3), call it as khatri_rao(x, y, z); num_args will be determined automatically.

Computes the Khatri-Rao product of the input matrices.

Given a collection of n input matrices,

.. math:: A_1 \in \mathbb{R}^{M_1 \times N}, \ldots, A_n \in \mathbb{R}^{M_n \times N},

the (column-wise) Khatri-Rao product is defined as the matrix,

.. math:: X = A_1 \otimes \cdots \otimes A_n \in \mathbb{R}^{(M_1 \cdots M_n) \times N},

where the k-th column is equal to the column-wise outer product {A_1}_k \otimes \cdots \otimes {A_n}_k, where {A_i}_k is the k-th column of the i-th matrix.

Example::

A = mx.nd.array([[1, -1], [2, -3]])
B = mx.nd.array([[1, 4], [2, 5], [3, 6]])
C = mx.nd.khatri_rao(A, B)
print(C.asnumpy())

[[  1.  -4.]
 [  2.  -5.]
 [  3.  -6.]
 [  2. -12.]
 [  4. -15.]
 [  6. -18.]]

Defined in src/operator/contrib/krprod.cc:L108

Arguments

  • args::NDArray-or-SymbolicNode[]: Positional input matrices
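Since column k of the result is the Kronecker product of the k-th columns, a plain-Julia reference sketch only needs `kron` from the LinearAlgebra standard library. `khatri_rao_ref` is hypothetical (not the MXNet.jl call) and reproduces the example above.

```julia
using LinearAlgebra  # for kron

# Hypothetical reference sketch: column k of the result is the Kronecker
# product of the k-th columns of all inputs.
function khatri_rao_ref(mats::AbstractMatrix...)
    ncols = size(first(mats), 2)
    all(m -> size(m, 2) == ncols, mats) ||
        throw(DimensionMismatch("all inputs must have the same number of columns"))
    return reduce(hcat, [reduce(kron, [m[:, k] for m in mats]) for k in 1:ncols])
end

A = [1 -1; 2 -3]
B = [1 4; 2 5; 3 6]
khatri_rao_ref(A, B)   # [1 -4; 2 -5; 3 -6; 2 -12; 4 -15; 6 -18]
```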

source

# MXNet.mx.lamb_update_phase1Method.

lamb_update_phase1(weight, grad, mean, var, beta1, beta2, epsilon, t, bias_correction, wd, rescale_grad, clip_gradient)

Phase I of the LAMB update. It performs the following operations and returns g:

Link to paper: https://arxiv.org/pdf/1904.00962.pdf

.. math:: \begin{gather*}
grad = grad * rescale_grad
if (grad < -clip_gradient) then grad = -clip_gradient
if (grad > clip_gradient) then grad = clip_gradient

mean = beta1 * mean + (1 - beta1) * grad;
var = beta2 * var + (1. - beta2) * grad ^ 2;

if (bias_correction)
then
     mean_hat = mean / (1. - beta1^t);
     var_hat = var / (1 - beta2^t);
     g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight;
else
     g = mean / (var^(1/2) + epsilon) + wd * weight;
\end{gather*}

Defined in src/operator/optimizer_op.cc:L952

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • mean::NDArray-or-SymbolicNode: Moving mean
  • var::NDArray-or-SymbolicNode: Moving variance
  • beta1::float, optional, default=0.899999976: The decay rate for the 1st moment estimates.
  • beta2::float, optional, default=0.999000013: The decay rate for the 2nd moment estimates.
  • epsilon::float, optional, default=9.99999997e-07: A small constant for numerical stability.
  • t::int, required: Index update count.
  • bias_correction::boolean, optional, default=1: Whether to use bias correction.
  • wd::float, required: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
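A plain-Julia sketch of phase I. `lamb_phase1!` is a hypothetical helper, not the MXNet.jl operator; it updates mean and var in place, returns g, and skips clipping when clip_gradient <= 0 as the argument description states.

```julia
# Hypothetical reference sketch of LAMB phase I on plain Julia arrays.
# Updates `mean` and `var` in place and returns g.
function lamb_phase1!(mean, var, weight, grad; t::Int, wd,
                      beta1 = 0.9, beta2 = 0.999, epsilon = 1e-6,
                      bias_correction = true, rescale_grad = 1.0, clip_gradient = -1.0)
    g = grad .* rescale_grad
    if clip_gradient > 0                       # <= 0 disables clipping
        g = clamp.(g, -clip_gradient, clip_gradient)
    end
    @. mean = beta1 * mean + (1 - beta1) * g
    @. var = beta2 * var + (1 - beta2) * g^2
    if bias_correction
        mean_hat = mean ./ (1 - beta1^t)
        var_hat = var ./ (1 - beta2^t)
        return @. mean_hat / (sqrt(var_hat) + epsilon) + wd * weight
    else
        return @. mean / (sqrt(var) + epsilon) + wd * weight
    end
end

mean, var = [0.0], [0.0]
g = lamb_phase1!(mean, var, [1.0], [2.0]; t = 1, wd = 0.0)
# mean ≈ [0.2], var ≈ [0.004]; after bias correction mean_hat = 2, var_hat = 4, so g ≈ [1.0]
```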

source

# MXNet.mx.lamb_update_phase2Method.

lamb_update_phase2(weight, g, r1, r2, lr, lower_bound, upper_bound)

Phase II of the LAMB update. It performs the following operations and updates the weight.

Link to paper: https://arxiv.org/pdf/1904.00962.pdf

.. math:: \begin{gather*}
if (lower_bound >= 0) then r1 = max(r1, lower_bound)
if (upper_bound >= 0) then r1 = min(r1, upper_bound)

if (r1 == 0 or r2 == 0)
then
     lr = lr
else
     lr = lr * (r1/r2)
weight = weight - lr * g
\end{gather*}

Defined in src/operator/optimizer_op.cc:L991

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • g::NDArray-or-SymbolicNode: Output of lamb_update_phase1
  • r1::NDArray-or-SymbolicNode: r1
  • r2::NDArray-or-SymbolicNode: r2
  • lr::float, required: Learning rate
  • lower_bound::float, optional, default=-1: Lower limit of the norm of weight. If lower_bound <= 0, the lower limit is not set.
  • upper_bound::float, optional, default=-1: Upper limit of the norm of weight. If upper_bound <= 0, the upper limit is not set.
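A matching plain-Julia sketch of phase II. `lamb_phase2!` is hypothetical; it clamps r1 from below by lower_bound and from above by upper_bound (each only when >= 0), then scales the learning rate by the trust ratio r1/r2. Passing the norms of weight and g as r1 and r2 follows the LAMB paper and is an assumption here, since the operator takes them as inputs.

```julia
using LinearAlgebra: norm

# Hypothetical reference sketch of LAMB phase II on plain Julia arrays.
# Updates `weight` in place using the trust ratio r1 / r2.
function lamb_phase2!(weight, g, r1::Real, r2::Real; lr,
                      lower_bound = -1.0, upper_bound = -1.0)
    lower_bound >= 0 && (r1 = max(r1, lower_bound))
    upper_bound >= 0 && (r1 = min(r1, upper_bound))
    ratio = (r1 == 0 || r2 == 0) ? 1 : r1 / r2   # fall back to plain lr
    @. weight -= lr * ratio * g
    return weight
end

w, g = [3.0, 4.0], [0.6, 0.8]
lamb_phase2!(w, g, norm(w), norm(g); lr = 0.1)   # ratio ≈ 5 / 1, w ≈ [2.7, 3.6]
```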

source

# MXNet.mx.linalg_detMethod.

linalg_det(A)

linalg_det is an alias of _linalg_det.

Compute the determinant of a matrix. Input is a tensor A of dimension n >= 2.

If n=2, A is a square matrix. We compute:

out = det(A)

If n>2, det is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

.. note:: No gradient is propagated backward when A is non-invertible (equivalently, det(A) = 0), because an exact zero is rarely hit in floating-point computation and Jacobi's formula for the determinant gradient is not computationally efficient when A is non-invertible.

Examples::

Single matrix determinant A = [[1., 4.], [2., 3.]] det(A) = [-5.]

Batch matrix determinant A = [[[1., 4.], [2., 3.]], [[2., 3.], [1., 4.]]] det(A) = [-5., 5.]

Defined in src/operator/tensor/la_op.cc:L974

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrix
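For cross-checking results, the LinearAlgebra standard library's `det` gives the same values on the examples above. The batched helper is hypothetical and stacks matrices along Julia's third dimension (`A[:, :, b]`); this is not a statement about MXNet.jl's dimension order.

```julia
using LinearAlgebra

# Hypothetical helper: determinant of each matrix A[:, :, b] in a batch.
batched_det(A::AbstractArray{<:Real,3}) = [det(A[:, :, b]) for b in axes(A, 3)]

A = cat([1.0 4.0; 2.0 3.0], [2.0 3.0; 1.0 4.0]; dims = 3)
batched_det(A)   # ≈ [-5.0, 5.0]
```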

source

# MXNet.mx.linalg_extractdiagMethod.

linalg_extractdiag(A, offset)

linalg_extractdiag is an alias of _linalg_extractdiag.

Extracts the diagonal entries of a square matrix. Input is a tensor A of dimension n >= 2.

If n=2, then A represents a single square matrix whose diagonal elements are extracted as a 1-dimensional tensor.

If n>2, then A represents a batch of square matrices on the trailing two dimensions. The extracted diagonals are returned as an n-1-dimensional tensor.

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix diagonal extraction
A = [[1.0, 2.0],
     [3.0, 4.0]]

extractdiag(A) = [1.0, 4.0]

extractdiag(A, 1) = [2.0]

Batch matrix diagonal extraction
A = [[[1.0, 2.0],
      [3.0, 4.0]],
     [[5.0, 6.0],
      [7.0, 8.0]]]

extractdiag(A) = [[1.0, 4.0],
                  [5.0, 8.0]]

Defined in src/operator/tensor/la_op.cc:L494

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrices
  • offset::int, optional, default='0': Offset of the diagonal versus the main diagonal. 0 corresponds to the main diagonal, a negative/positive value to diagonals below/above the main diagonal.

source

# MXNet.mx.linalg_extracttrianMethod.

linalg_extracttrian(A, offset, lower)

linalg_extracttrian is an alias of _linalg_extracttrian.

Extracts a triangular sub-matrix from a square matrix. Input is a tensor A of dimension n >= 2.

If n=2, then A represents a single square matrix from which a triangular sub-matrix is extracted as a 1-dimensional tensor.

If n>2, then A represents a batch of square matrices on the trailing two dimensions. The extracted triangular sub-matrices are returned as an n-1-dimensional tensor.

The offset and lower parameters determine the triangle to be extracted:

  • When offset = 0 either the lower or upper triangle with respect to the main diagonal is extracted depending on the value of parameter lower.
  • When offset = k > 0 the upper triangle with respect to the k-th diagonal above the main diagonal is extracted.
  • When offset = k < 0 the lower triangle with respect to the k-th diagonal below the main diagonal is extracted.

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single triangular extraction
A = [[1.0, 2.0],
     [3.0, 4.0]]

extracttrian(A) = [1.0, 3.0, 4.0]
extracttrian(A, lower=False) = [1.0, 2.0, 4.0]
extracttrian(A, 1) = [2.0]
extracttrian(A, -1) = [3.0]

Batch triangular extraction
A = [[[1.0, 2.0],
      [3.0, 4.0]],
     [[5.0, 6.0],
      [7.0, 8.0]]]

extracttrian(A) = [[1.0, 3.0, 4.0],
                   [5.0, 7.0, 8.0]]

Defined in src/operator/tensor/la_op.cc:L604

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrices
  • offset::int, optional, default='0': Offset of the diagonal versus the main diagonal. 0 corresponds to the main diagonal, a negative/positive value to diagonals below/above the main diagonal.
  • lower::boolean, optional, default=1: Refer to the lower triangular matrix if lower=true, refer to the upper otherwise. Only relevant when offset=0

source

# MXNet.mx.linalg_gelqfMethod.

linalg_gelqf(A)

linalg_gelqf is an alias of _linalg_gelqf.

LQ factorization for general matrix. Input is a tensor A of dimension n >= 2.

If n=2, we compute the LQ factorization (LAPACK gelqf, followed by orglq). A must have shape (x, y) with x <= y, and must have full rank x. The LQ factorization consists of L with shape (x, x) and Q with shape (x, y), so that:

A = L * Q

Here, L is lower triangular (upper triangle equal to zero) with nonzero diagonal, and Q is row-orthonormal, meaning that

Q * Q^T

is equal to the identity matrix of shape (x, x).

If n>2, gelqf is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single LQ factorization A = [[1., 2., 3.], [4., 5., 6.]] Q, L = gelqf(A) Q = [[-0.26726124, -0.53452248, -0.80178373], [0.87287156, 0.21821789, -0.43643578]] L = [[-3.74165739, 0.], [-8.55235974, 1.96396101]]

Batch LQ factorization A = [[[1., 2., 3.], [4., 5., 6.]], [[7., 8., 9.], [10., 11., 12.]]] Q, L = gelqf(A) Q = [[[-0.26726124, -0.53452248, -0.80178373], [0.87287156, 0.21821789, -0.43643578]], [[-0.50257071, -0.57436653, -0.64616234], [0.7620735, 0.05862104, -0.64483142]]] L = [[[-3.74165739, 0.], [-8.55235974, 1.96396101]], [[-13.92838828, 0.], [-19.09768702, 0.52758934]]]

Defined in src/operator/tensor/la_op.cc:L797

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of input matrices to be factorized

source

# MXNet.mx.linalg_gemmMethod.

linalg_gemm(A, B, C, transpose_a, transpose_b, alpha, beta, axis)

linalg_gemm is an alias of _linalg_gemm.

Performs general matrix multiplication and accumulation. Inputs are tensors A, B, C, each of dimension n >= 2 and having the same shape on the leading n-2 dimensions.

If n=2, the BLAS3 function gemm is performed:

out = alpha * op(A) * op(B) + beta * C

Here, alpha and beta are scalar parameters, and op() is either the identity or matrix transposition (depending on transpose_a, transpose_b).

If n>2, gemm is performed separately for a batch of matrices. The column indices of the matrices are given by the last dimensions of the tensors, the row indices by the axis specified with the axis parameter. By default, the trailing two dimensions will be used for matrix encoding.

For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm/swapaxes calls. For example, let A, B, C be 5-dimensional tensors. Then gemm(A, B, C, axis=1) is equivalent to the following, without the overhead of the additional swapaxes operations::

A1 = swapaxes(A, dim1=1, dim2=3)
B1 = swapaxes(B, dim1=1, dim2=3)
C = swapaxes(C, dim1=1, dim2=3)
C = gemm(A1, B1, C)
C = swapaxes(C, dim1=1, dim2=3)

When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use pseudo-float16 precision (float32 math with float16 I/O) in order to use Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix multiply-add A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0) = [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]]

Batch matrix multiply-add A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] C = [[[10.0]], [[0.01]]] gemm(A, B, C, transpose_b=True, alpha=2.0 , beta=10.0) = [[[104.0]], [[0.14]]]

Defined in src/operator/tensor/la_op.cc:L88

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of input matrices
  • B::NDArray-or-SymbolicNode: Tensor of input matrices
  • C::NDArray-or-SymbolicNode: Tensor of input matrices
  • transpose_a::boolean, optional, default=0: Multiply with transposed of first input (A).
  • transpose_b::boolean, optional, default=0: Multiply with transposed of second input (B).
  • alpha::double, optional, default=1: Scalar factor multiplied with A*B.
  • beta::double, optional, default=1: Scalar factor multiplied with C.
  • axis::int, optional, default='-2': Axis corresponding to the matrix rows.
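The n = 2 case is a single expression in plain Julia. `gemm_ref` is a hypothetical helper, not the MXNet.jl call, and it reproduces the single matrix example; dropping the beta * C term gives the linalg_gemm2 formula.

```julia
using LinearAlgebra  # for transpose on matrices

# Hypothetical reference sketch of the n = 2 case:
#   out = alpha * op(A) * op(B) + beta * C
function gemm_ref(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix;
                  transpose_a = false, transpose_b = false, alpha = 1.0, beta = 1.0)
    opA = transpose_a ? transpose(A) : A
    opB = transpose_b ? transpose(B) : B
    return alpha .* (opA * opB) .+ beta .* C
end

A, B, C = ones(2, 2), ones(3, 2), ones(2, 3)
gemm_ref(A, B, C; transpose_b = true, alpha = 2.0, beta = 10.0)   # fill(14.0, 2, 3)
```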

source

# MXNet.mx.linalg_gemm2Method.

linalg_gemm2(A, B, transpose_a, transpose_b, alpha, axis)

linalg_gemm2 is an alias of _linalg_gemm2.

Performs general matrix multiplication. Inputs are tensors A, B, each of dimension n >= 2 and having the same shape on the leading n-2 dimensions.

If n=2, the BLAS3 function gemm is performed:

out = alpha * op(A) * op(B)

Here alpha is a scalar parameter and op() is either the identity or the matrix transposition (depending on transpose_a, transpose_b).

If n>2, gemm is performed separately for a batch of matrices. The column indices of the matrices are given by the last dimensions of the tensors, the row indices by the axis specified with the axis parameter. By default, the trailing two dimensions will be used for matrix encoding.

For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm/swapaxes calls. For example, let A, B be 5-dimensional tensors. Then gemm(A, B, axis=1) is equivalent to the following, without the overhead of the additional swapaxes operations::

A1 = swapaxes(A, dim1=1, dim2=3)
B1 = swapaxes(B, dim1=1, dim2=3)
C = gemm2(A1, B1)
C = swapaxes(C, dim1=1, dim2=3)

When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use pseudo-float16 precision (float32 math with float16 I/O) in order to use Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix multiply A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] gemm2(A, B, transpose_b=True, alpha=2.0) = [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]]

Batch matrix multiply A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] gemm2(A, B, transpose_b=True, alpha=2.0) = [[[4.0]], [[0.04 ]]]

Defined in src/operator/tensor/la_op.cc:L162

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of input matrices
  • B::NDArray-or-SymbolicNode: Tensor of input matrices
  • transpose_a::boolean, optional, default=0: Multiply with transposed of first input (A).
  • transpose_b::boolean, optional, default=0: Multiply with transposed of second input (B).
  • alpha::double, optional, default=1: Scalar factor multiplied with A*B.
  • axis::int, optional, default='-2': Axis corresponding to the matrix row indices.

source

# MXNet.mx.linalg_inverseMethod.

linalg_inverse(A)

linalg_inverse is an alias of _linalg_inverse.

Compute the inverse of a matrix. Input is a tensor A of dimension n >= 2.

If n=2, A is a square matrix. We compute:

out = A^-1

If n>2, inverse is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix inverse
A = [[1., 4.], [2., 3.]]
inverse(A) = [[-0.6, 0.8], [0.4, -0.2]]

Batch matrix inverse
A = [[[1., 4.], [2., 3.]], [[1., 3.], [2., 4.]]]
inverse(A) = [[[-0.6, 0.8], [0.4, -0.2]], [[-2., 1.5], [1., -0.5]]]
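
As an illustration only, the following sketch inverts one square matrix with Gauss-Jordan elimination and partial pivoting, then maps it over the leading axis for batch mode. The choice of algorithm and the name `inverse_ref` are assumptions; the documentation does not say which method the operator uses internally.

```python
def inverse_ref(a):
    """Invert a square list-of-rows matrix by Gauss-Jordan elimination with partial pivoting."""
    n = len(a)
    # Augment A with the identity: [A | I].
    aug = [[float(v) for v in row] + [1.0 if i == j else 0.0 for j in range(n)]
           for i, row in enumerate(a)]
    for col in range(n):
        # Use the largest remaining entry in this column as the pivot.
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        if aug[pivot][col] == 0.0:
            raise ValueError("matrix is singular")
        aug[col], aug[pivot] = aug[pivot], aug[col]
        p = aug[col][col]
        aug[col] = [v / p for v in aug[col]]
        # Eliminate this column from every other row.
        for r in range(n):
            if r != col:
                f = aug[r][col]
                aug[r] = [x - f * y for x, y in zip(aug[r], aug[col])]
    # The augmented matrix is now [I | A^-1]; return the right half.
    return [row[n:] for row in aug]


single = inverse_ref([[1.0, 4.0], [2.0, 3.0]])  # close to [[-0.6, 0.8], [0.4, -0.2]]
batch = [inverse_ref(m) for m in [[[1.0, 4.0], [2.0, 3.0]], [[1.0, 3.0], [2.0, 4.0]]]]
```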

Defined in src/operator/tensor/la_op.cc:L919

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrices

# MXNet.mx.linalg_makediag (Method)

linalg_makediag(A, offset)

linalg_makediag is an alias of _linalg_makediag.

Constructs a square matrix with the input as diagonal. Input is a tensor A of dimension n >= 1.

If n=1, then A represents the diagonal entries of a single square matrix. This matrix will be returned as a 2-dimensional tensor. If n>1, then A represents a batch of diagonals of square matrices. The batch of diagonal matrices will be returned as an n+1-dimensional tensor.

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single diagonal matrix construction
A = [1.0, 2.0]

makediag(A)    = [[1.0, 0.0],
                  [0.0, 2.0]]

makediag(A, 1) = [[0.0, 1.0, 0.0],
                  [0.0, 0.0, 2.0],
                  [0.0, 0.0, 0.0]]

Batch diagonal matrix construction
A = [[1.0, 2.0],
     [3.0, 4.0]]

makediag(A) = [[[1.0, 0.0],
                [0.0, 2.0]],
               [[3.0, 0.0],
                [0.0, 4.0]]]
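
A hypothetical plain-Python reference for makediag, following the offset convention given in the argument description below (positive values above, negative values below the main diagonal); `makediag_ref` is an illustrative name, not part of the API.

```python
def makediag_ref(diag, offset=0):
    """Square matrix with `diag` on the diagonal `offset` (>0 above, <0 below the main one)."""
    n = len(diag) + abs(offset)
    out = [[0.0] * n for _ in range(n)]
    for i, v in enumerate(diag):
        if offset >= 0:
            out[i][i + offset] = v      # shift right: a diagonal above the main one
        else:
            out[i - offset][i] = v      # shift down: a diagonal below the main one
    return out


print(makediag_ref([1.0, 2.0]))     # [[1.0, 0.0], [0.0, 2.0]]
print(makediag_ref([1.0, 2.0], 1))  # [[0.0, 1.0, 0.0], [0.0, 0.0, 2.0], [0.0, 0.0, 0.0]]
batch = [makediag_ref(d) for d in [[1.0, 2.0], [3.0, 4.0]]]
```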

Defined in src/operator/tensor/la_op.cc:L546

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of diagonal entries
  • offset::int, optional, default='0': Offset of the diagonal versus the main diagonal. 0 corresponds to the main diagonal, a negative/positive value to diagonals below/above the main diagonal.

# MXNet.mx.linalg_maketrian (Method)

linalg_maketrian(A, offset, lower)

linalg_maketrian is an alias of _linalg_maketrian.

Constructs a square matrix with the input representing a specific triangular sub-matrix. This is basically the inverse of linalg.extracttrian. Input is a tensor A of dimension n >= 1.

If n=1, then A represents the entries of a triangular matrix, which is lower triangular if offset<0, or if offset=0 and lower=true, and upper triangular otherwise. The resulting matrix is obtained by first constructing the square triangular matrix from these entries, with all entries outside the triangle set to zero, and then adding |offset| all-zero diagonals so that the triangle starts |offset| diagonals away from the main diagonal. The result is therefore |offset| larger in each dimension.

If n>1, then A represents a batch of triangular sub-matrices. The batch of corresponding square matrices is returned as an n+1-dimensional tensor.

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix construction
A = [1.0, 2.0, 3.0]

maketrian(A)              = [[1.0, 0.0],
                             [2.0, 3.0]]

maketrian(A, lower=false) = [[1.0, 2.0],
                             [0.0, 3.0]]

maketrian(A, offset=1)    = [[0.0, 1.0, 2.0],
                             [0.0, 0.0, 3.0],
                             [0.0, 0.0, 0.0]]
maketrian(A, offset=-1)   = [[0.0, 0.0, 0.0],
                             [1.0, 0.0, 0.0],
                             [2.0, 3.0, 0.0]]

Batch matrix construction
A = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]

maketrian(A)           = [[[1.0, 0.0],
                           [2.0, 3.0]],
                          [[4.0, 0.0],
                           [5.0, 6.0]]]

maketrian(A, offset=1) = [[[0.0, 1.0, 2.0],
                           [0.0, 0.0, 3.0],
                           [0.0, 0.0, 0.0]],
                          [[0.0, 4.0, 5.0],
                           [0.0, 0.0, 6.0],
                           [0.0, 0.0, 0.0]]]
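
The following plain-Python sketch rebuilds the square matrix from its triangle entries. It assumes, as the examples suggest, that the entries are stored row by row and that the triangle size k satisfies len(A) = k*(k+1)/2; `maketrian_ref` is a hypothetical name and this is not the library's code.

```python
import math


def maketrian_ref(entries, offset=0, lower=True):
    """Rebuild a square matrix from row-major triangle entries (illustrative reference)."""
    m = len(entries)
    k = (math.isqrt(8 * m + 1) - 1) // 2        # triangle size: m = k * (k + 1) / 2
    if k * (k + 1) // 2 != m:
        raise ValueError("entry count is not a triangular number")
    use_lower = offset < 0 or (offset == 0 and lower)
    n = k + abs(offset)
    out = [[0.0] * n for _ in range(n)]
    it = iter(entries)
    for i in range(k):
        cols = range(0, i + 1) if use_lower else range(i, k)
        for j in cols:
            v = next(it)
            if offset >= 0:
                out[i][j + offset] = v          # shift right for offset > 0
            else:
                out[i - offset][j] = v          # shift down for offset < 0
    return out


A = [1.0, 2.0, 3.0]
upper = maketrian_ref(A, lower=False)           # [[1.0, 2.0], [0.0, 3.0]]
shifted = maketrian_ref(A, offset=1)
```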

Defined in src/operator/tensor/la_op.cc:L672

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of triangular matrices stored as vectors
  • offset::int, optional, default='0': Offset of the diagonal versus the main diagonal. 0 corresponds to the main diagonal, a negative/positive value to diagonals below/above the main diagonal.
  • lower::boolean, optional, default=1: Refer to the lower triangular matrix if lower=true, refer to the upper otherwise. Only relevant when offset=0

# MXNet.mx.linalg_potrf (Method)

linalg_potrf(A)

linalg_potrf is an alias of _linalg_potrf.

Performs Cholesky factorization of a symmetric positive-definite matrix. Input is a tensor A of dimension n >= 2.

If n=2, the Cholesky factor B of the symmetric, positive definite matrix A is computed. B is triangular (entries of upper or lower triangle are all zero), has positive diagonal entries, and:

A = B * B^T  if lower = true
A = B^T * B  if lower = false

If n>2, potrf is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix factorization
A = [[4.0, 1.0], [1.0, 4.25]]
potrf(A) = [[2.0, 0], [0.5, 2.0]]

Batch matrix factorization
A = [[[4.0, 1.0], [1.0, 4.25]], [[16.0, 4.0], [4.0, 17.0]]]
potrf(A) = [[[2.0, 0], [0.5, 2.0]], [[4.0, 0], [1.0, 4.0]]]
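
For reference, here is a textbook Cholesky-Banachiewicz factorization in plain Python that reproduces the lower = true examples above. `potrf_ref` is an illustrative name; the operator's actual implementation is not described in this documentation.

```python
import math


def potrf_ref(a):
    """Lower Cholesky factor L of a symmetric positive-definite matrix, with A = L * L^T."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                d = a[i][i] - s
                if d <= 0.0:
                    raise ValueError("matrix is not positive definite")
                L[i][j] = math.sqrt(d)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L


print(potrf_ref([[4.0, 1.0], [1.0, 4.25]]))  # [[2.0, 0.0], [0.5, 2.0]]
batch = [potrf_ref(m) for m in [[[4.0, 1.0], [1.0, 4.25]], [[16.0, 4.0], [4.0, 17.0]]]]
```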

Defined in src/operator/tensor/la_op.cc:L213

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of input matrices to be decomposed

# MXNet.mx.linalg_potri (Method)

linalg_potri(A)

linalg_potri is an alias of _linalg_potri.

Performs matrix inversion from a Cholesky factorization. Input is a tensor A of dimension n >= 2.

If n=2, A is a triangular matrix (entries of upper or lower triangle are all zero) with positive diagonal. We compute:

out = A^{-T} * A^{-1}  if lower = true
out = A^{-1} * A^{-T}  if lower = false

In other words, if A is the Cholesky factor of a symmetric positive definite matrix B (obtained by potrf), then

out = B^{-1}

If n>2, potri is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

.. note:: Use this operator only if you are certain you need the inverse of B, and cannot use the Cholesky factor A (potrf), together with backsubstitution (trsm). The latter is numerically much safer, and also cheaper.

Examples::

Single matrix inverse
A = [[2.0, 0], [0.5, 2.0]]
potri(A) = [[0.26563, -0.0625], [-0.0625, 0.25]]

Batch matrix inverse
A = [[[2.0, 0], [0.5, 2.0]], [[4.0, 0], [1.0, 4.0]]]
potri(A) = [[[0.26563, -0.0625], [-0.0625, 0.25]], [[0.06641, -0.01562], [-0.01562, 0.0625]]]
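
A plain-Python sketch of the lower = true formula out = A^{-T} * A^{-1}: invert the triangular factor by forward substitution, then multiply the transpose of that inverse by the inverse itself. The helper names `tril_inverse` and `potri_ref` are hypothetical, and the note above still applies: for solving systems, triangular solves are preferable to forming this inverse.

```python
def tril_inverse(L):
    """Inverse of a lower triangular matrix, one column at a time by forward substitution."""
    n = len(L)
    inv = [[0.0] * n for _ in range(n)]
    for col in range(n):
        for i in range(col, n):
            rhs = 1.0 if i == col else 0.0
            s = sum(L[i][k] * inv[k][col] for k in range(col, i))
            inv[i][col] = (rhs - s) / L[i][i]
    return inv


def potri_ref(L):
    """B^{-1} = L^{-T} * L^{-1}, where L is the lower Cholesky factor of B."""
    Li = tril_inverse(L)
    n = len(L)
    # (L^{-T} L^{-1})[i][j] = sum over k of Li[k][i] * Li[k][j]
    return [[sum(Li[k][i] * Li[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


print(potri_ref([[2.0, 0.0], [0.5, 2.0]]))  # [[0.265625, -0.0625], [-0.0625, 0.25]]
second = potri_ref([[4.0, 0.0], [1.0, 4.0]])
```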

Defined in src/operator/tensor/la_op.cc:L274

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of lower triangular matrices

# MXNet.mx.linalg_slogdet (Method)

linalg_slogdet(A)

linalg_slogdet is an alias of _linalg_slogdet.

Computes the sign and the log of the absolute value of the determinant of a matrix. Input is a tensor A of dimension n >= 2.

If n=2, A is a square matrix. We compute:

sign = sign(det(A))
logabsdet = log(abs(det(A)))

If n>2, slogdet is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.
.. note:: The gradient with respect to sign is not well defined, so no gradient is propagated through it.
.. note:: No gradient is propagated when A is non-invertible. See the documentation of the det operator for details.

Examples::

Single matrix signed log determinant
A = [[2., 3.], [1., 4.]]
sign, logabsdet = slogdet(A)
sign = [1.]
logabsdet = [1.609438]

Batch matrix signed log determinant
A = [[[2., 3.], [1., 4.]], [[1., 2.], [2., 4.]], [[1., 2.], [4., 3.]]]
sign, logabsdet = slogdet(A)
sign = [1., 0., -1.]
logabsdet = [1.609438, -inf, 1.609438]
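
One common way to obtain sign and logabsdet without overflow is Gaussian elimination with partial pivoting, accumulating the logarithms of the pivots. The sketch below uses that approach as an assumption; `slogdet_ref` is a hypothetical name, and it returns plain Python floats rather than shape-(1,) arrays.

```python
import math


def slogdet_ref(a):
    """Return (sign, log|det A|) via Gaussian elimination with partial pivoting."""
    m = [[float(v) for v in row] for row in a]
    n = len(m)
    sign, logabsdet = 1.0, 0.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if m[pivot][col] == 0.0:
            return 0.0, -math.inf              # singular matrix: det(A) = 0
        if pivot != col:
            m[col], m[pivot] = m[pivot], m[col]
            sign = -sign                       # each row swap flips the sign of det(A)
        p = m[col][col]
        sign *= 1.0 if p > 0 else -1.0
        logabsdet += math.log(abs(p))
        for r in range(col + 1, n):
            f = m[r][col] / p
            m[r] = [x - f * y for x, y in zip(m[r], m[col])]
    return sign, logabsdet


results = [slogdet_ref(A) for A in [[[2.0, 3.0], [1.0, 4.0]],
                                    [[1.0, 2.0], [2.0, 4.0]],
                                    [[1.0, 2.0], [4.0, 3.0]]]]
```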

Defined in src/operator/tensor/la_op.cc:L1033

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrices

# MXNet.mx.linalg_sumlogdiag (Method)

linalg_sumlogdiag(A)

linalg_sumlogdiag is an alias of _linalg_sumlogdiag.

Computes the sum of the logarithms of the diagonal elements of a square matrix. Input is a tensor A of dimension n >= 2.

If n=2, A must be square with positive diagonal entries. The natural logarithms of the diagonal elements are summed; the result has shape (1,).

If n>2, sumlogdiag is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix reduction
A = [[1.0, 1.0], [1.0, 7.0]]
sumlogdiag(A) = [1.9459]

Batch matrix reduction
A = [[[1.0, 1.0], [1.0, 7.0]], [[3.0, 0], [0, 17.0]]]
sumlogdiag(A) = [1.9459, 3.9318]
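
A one-function plain-Python equivalent for a single matrix, with batch mode as a map over the leading axis; `sumlogdiag_ref` is an illustrative name only.

```python
import math


def sumlogdiag_ref(a):
    """Sum of the natural logarithms of the diagonal of a square matrix (diagonal must be positive)."""
    return sum(math.log(a[i][i]) for i in range(len(a)))


batch = [sumlogdiag_ref(m) for m in [[[1.0, 1.0], [1.0, 7.0]], [[3.0, 0.0], [0.0, 17.0]]]]
print([round(v, 4) for v in batch])  # [1.9459, 3.9318]
```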

Defined in src/operator/tensor/la_op.cc:L444

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of square matrices

# MXNet.mx.linalg_syrk (Method)

linalg_syrk(A, transpose, alpha)

linalg_syrk is an alias of _linalg_syrk.

Multiplies a matrix by its transpose. Input is a tensor A of dimension n >= 2.

If n=2, the operator performs the BLAS3 function syrk:

out = alpha * A * A^T

if transpose=False, or

out = alpha * A^T * A

if transpose=True.

If n>2, syrk is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix multiply
A = [[1., 2., 3.], [4., 5., 6.]]
syrk(A, alpha=1., transpose=False) = [[14., 32.], [32., 77.]]
syrk(A, alpha=1., transpose=True) = [[17., 22., 27.], [22., 29., 36.], [27., 36., 45.]]

Batch matrix multiply
A = [[[1., 1.]], [[0.1, 0.1]]]
syrk(A, alpha=2., transpose=False) = [[[4.]], [[0.04]]]
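
A hypothetical plain-Python reference for syrk: entry (i, j) of alpha * A * A^T is alpha times the dot product of rows i and j of A, and for transpose=True the columns of A take the place of the rows. `syrk_ref` is an illustrative name only.

```python
def syrk_ref(a, transpose=False, alpha=1.0):
    """alpha * A * A^T (transpose=False) or alpha * A^T * A (transpose=True)."""
    # Rows of A give A * A^T; rows of A^T (the columns of A) give A^T * A.
    rows = [list(col) for col in zip(*a)] if transpose else a
    return [[alpha * sum(x * y for x, y in zip(ri, rj)) for rj in rows] for ri in rows]


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(syrk_ref(A))                  # [[14.0, 32.0], [32.0, 77.0]]
print(syrk_ref(A, transpose=True))  # [[17.0, 22.0, 27.0], [22.0, 29.0, 36.0], [27.0, 36.0, 45.0]]
```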

Defined in src/operator/tensor/la_op.cc:L729

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of input matrices
  • transpose::boolean, optional, default=0: Use transpose of input matrix.
  • alpha::double, optional, default=1: Scalar factor to be applied to the result.

# MXNet.mx.linalg_trmm (Method)

linalg_trmm(A, B, transpose, rightside, lower, alpha)

linalg_trmm is an alias of _linalg_trmm.

Performs multiplication with a lower triangular matrix (or an upper triangular one if lower=false). Inputs are tensors A, B, each of dimension n >= 2 and having the same shape on the leading n-2 dimensions.

If n=2, A must be triangular. The operator performs the BLAS3 function trmm:

out = alpha * op(A) * B

if rightside=False, or

out = alpha * B * op(A)

if rightside=True. Here, alpha is a scalar parameter, and op() is either the identity or the matrix transposition (depending on transpose).

If n>2, trmm is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single triangular matrix multiply
A = [[1.0, 0], [1.0, 1.0]]
B = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
trmm(A, B, alpha=2.0) = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]

Batch triangular matrix multiply
A = [[[1.0, 0], [1.0, 1.0]], [[1.0, 0], [1.0, 1.0]]]
B = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]]
trmm(A, B, alpha=2.0) = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]], [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]]
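
A plain-Python sketch of trmm written as a dense product. It assumes the unused triangle of A is already zero, so the lower flag (which selects the triangle that is read) is not modelled; `trmm_ref` is a hypothetical name.

```python
def trmm_ref(a, b, transpose=False, rightside=False, alpha=1.0):
    """alpha * op(A) * B, or alpha * B * op(A) when rightside=True; A is assumed triangular."""
    op_a = [list(col) for col in zip(*a)] if transpose else a
    left, right = (b, op_a) if rightside else (op_a, b)
    # Dense product: entries outside the triangle of A are assumed to be zero already.
    return [[alpha * sum(x * y for x, y in zip(row, col)) for col in zip(*right)]
            for row in left]


A = [[1.0, 0.0], [1.0, 1.0]]
B = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
print(trmm_ref(A, B, alpha=2.0))  # [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]
```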

Defined in src/operator/tensor/la_op.cc:L332

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of lower triangular matrices
  • B::NDArray-or-SymbolicNode: Tensor of matrices
  • transpose::boolean, optional, default=0: Use the transpose of the triangular matrix.
  • rightside::boolean, optional, default=0: Multiply the triangular matrix from the right onto the non-triangular one.
  • lower::boolean, optional, default=1: True if the triangular matrix is lower triangular, false if it is upper triangular.
  • alpha::double, optional, default=1: Scalar factor to be applied to the result.

# MXNet.mx.linalg_trsm (Method)

linalg_trsm(A, B, transpose, rightside, lower, alpha)

linalg_trsm is an alias of _linalg_trsm.

Solves a matrix equation involving a lower triangular matrix (or an upper triangular one if lower=false). Inputs are tensors A, B, each of dimension n >= 2 and having the same shape on the leading n-2 dimensions.

If n=2, A must be triangular. The operator performs the BLAS3 function trsm, solving for out in:

op(A) * out = alpha * B

if rightside=False, or

out * op(A) = alpha * B

if rightside=True. Here, alpha is a scalar parameter, and op() is either the identity or the matrix transposition (depending on transpose).

If n>2, trsm is performed separately on the trailing two dimensions for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

Examples::

Single matrix solve
A = [[1.0, 0], [1.0, 1.0]]
B = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]
trsm(A, B, alpha=0.5) = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]

Batch matrix solve
A = [[[1.0, 0], [1.0, 1.0]], [[1.0, 0], [1.0, 1.0]]]
B = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]], [[4.0, 4.0, 4.0], [8.0, 8.0, 8.0]]]
trsm(A, B, alpha=0.5) = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]
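
For the default case (lower triangular A, rightside=False, transpose=False), solving op(A) * out = alpha * B reduces to forward substitution on each column of B. This sketch covers only that case; `trsm_lower_left_ref` is a hypothetical name.

```python
def trsm_lower_left_ref(a, b, alpha=1.0):
    """Solve A * X = alpha * B for X, with A lower triangular (rightside=False, transpose=False)."""
    n, m = len(a), len(b[0])
    x = [[0.0] * m for _ in range(n)]
    for c in range(m):              # each column of B is an independent right-hand side
        for i in range(n):          # forward substitution, top row first
            s = sum(a[i][k] * x[k][c] for k in range(i))
            x[i][c] = (alpha * b[i][c] - s) / a[i][i]
    return x


A = [[1.0, 0.0], [1.0, 1.0]]
B = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]
print(trsm_lower_left_ref(A, B, alpha=0.5))  # [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
```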

Defined in src/operator/tensor/la_op.cc:L395

Arguments

  • A::NDArray-or-SymbolicNode: Tensor of lower triangular matrices
  • B::NDArray-or-SymbolicNode: Tensor of matrices
  • transpose::boolean, optional, default=0: Use the transpose of the triangular matrix.
  • rightside::boolean, optional, default=0: Place the triangular matrix to the right of out, i.e. solve out * op(A) = alpha * B.
  • lower::boolean, optional, default=1: True if the triangular matrix is lower triangular, false if it is upper triangular.
  • alpha::double, optional, default=1: Scalar factor to be applied to the result.

# MXNet.mx.logical_not (Method)

logical_not(data)

Returns the result of the logical NOT (!) function, element-wise.

Example: logical_not([-2., 0., 1.]) = [0., 1., 0.]

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# MXNet.mx.make_loss (Method)

make_loss(data)

Make your own loss function in network construction.

This operator accepts a customized loss function symbol as a terminal loss and the symbol should be an operator with no backward dependency. The output of this function is the gradient of loss with respect to the input data.

For example, suppose you are making a cross-entropy loss function. Assume out is the predicted output and label is the true label; then the cross entropy can be defined as::

cross_entropy = -(label * log(out) + (1 - label) * log(1 - out))
loss = make_loss(cross_entropy)

Use make_loss when creating your own loss function or when combining multiple loss functions. You may also want to stop some variables' gradients from backpropagating; see BlockGrad or stop_gradient for details.

The storage type of make_loss output depends upon the input storage type:

  • make_loss(default) = default
  • make_loss(row_sparse) = row_sparse

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L358

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

# MXNet.mx.moments (Method)

moments(data, axes, keepdims)

Calculate the mean and variance of data.

The mean and variance are calculated by aggregating the contents of data across axes. If x is 1-D and axes = [0] this is just the mean and variance of a vector.

Example:

 x = [[1, 2, 3], [4, 5, 6]]
 mean, var = moments(data=x, axes=[0])
 mean = [2.5, 3.5, 4.5]
 var = [2.25, 2.25, 2.25]
 mean, var = moments(data=x, axes=[1])
 mean = [2.0, 5.0]
 var = [0.66666667, 0.66666667]
 mean, var = moments(data=x, axes=[0, 1])
 mean = [3.5]
 var = [2.9166667]
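
The example values above correspond to the population variance (dividing by N rather than N - 1). A small plain-Python sketch for 2-D lists that reproduces them; `moments_ref` is a hypothetical name, and general N-D axes and keepdims are not handled.

```python
def moments_ref(x, axes):
    """Mean and population variance of a 2-D list x over axes [0], [1] or [0, 1]."""
    axes = tuple(axes)
    if axes == (0,):
        groups = [list(col) for col in zip(*x)]      # reduce down each column
    elif axes == (1,):
        groups = [list(row) for row in x]            # reduce along each row
    elif axes == (0, 1):
        groups = [[v for row in x for v in row]]     # reduce over every element
    else:
        raise ValueError("this sketch only handles 2-D inputs")
    means, variances = [], []
    for g in groups:
        m = sum(g) / len(g)
        means.append(m)
        variances.append(sum((v - m) ** 2 for v in g) / len(g))  # divide by N, not N - 1
    return means, variances


x = [[1, 2, 3], [4, 5, 6]]
print(moments_ref(x, [0]))  # ([2.5, 3.5, 4.5], [2.25, 2.25, 2.25])
```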

Defined in src/operator/nn/moments.cc:L53

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • axes::Shape or None, optional, default=None: Array of ints. Axes along which to compute mean and variance.
  • keepdims::boolean, optional, default=0: produce moments with the same dimensionality as the input.

# MXNet.mx.mp_lamb_update_phase1 (Method)

mp_lamb_update_phase1(weight, grad, mean, var, weight32, beta1, beta2, epsilon, t, bias_correction, wd, rescale_grad, clip_gradient)

Mixed-precision version of Phase I of the LAMB update. It performs the following operations and returns g:

      Link to paper: https://arxiv.org/pdf/1904.00962.pdf

      .. math::
          \begin{gather*}
          grad32 = grad(float16) * rescale_grad
          if (grad32 < -clip_gradient)
          then
               grad32 = -clip_gradient
          if (grad32 > clip_gradient)
          then
               grad32 = clip_gradient

          mean = beta1 * mean + (1 - beta1) * grad32;
          var = beta2 * var + (1. - beta2) * grad32 ^ 2;

          if (bias_correction)
          then
               mean_hat = mean / (1. - beta1^t);
               var_hat = var / (1 - beta2^t);
               g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight32;
          else
               g = mean / (var^(1/2) + epsilon) + wd * weight32;
          \end{gather*}
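
A plain-Python reading of Phase I on flat lists, with the float16 gradient and the float32 master weights both modelled as ordinary Python floats. It follows the pseudocode above and additionally assumes, from the clip_gradient description below, that clipping is skipped when clip_gradient <= 0; `lamb_phase1_ref` is a hypothetical name, and mean/var are updated in place.

```python
import math


def lamb_phase1_ref(weight32, grad, mean, var, t, wd, beta1=0.9, beta2=0.999,
                    epsilon=1e-6, bias_correction=True, rescale_grad=1.0, clip_gradient=-1.0):
    """Element-wise LAMB phase I on plain lists; mutates mean and var, returns g."""
    g = []
    for i, gi in enumerate(grad):
        gi = gi * rescale_grad
        if clip_gradient > 0:                      # clipping is off when clip_gradient <= 0
            gi = max(min(gi, clip_gradient), -clip_gradient)
        mean[i] = beta1 * mean[i] + (1 - beta1) * gi
        var[i] = beta2 * var[i] + (1 - beta2) * gi * gi
        if bias_correction:
            m_hat = mean[i] / (1 - beta1 ** t)
            v_hat = var[i] / (1 - beta2 ** t)
        else:
            m_hat, v_hat = mean[i], var[i]
        g.append(m_hat / (math.sqrt(v_hat) + epsilon) + wd * weight32[i])
    return g


mean, var = [0.0, 0.0], [0.0, 0.0]
g = lamb_phase1_ref([1.0, 1.0], [0.5, -0.5], mean, var, t=1, wd=0.0)
```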

Defined in src/operator/optimizer_op.cc:L1032

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • mean::NDArray-or-SymbolicNode: Moving mean
  • var::NDArray-or-SymbolicNode: Moving variance
  • weight32::NDArray-or-SymbolicNode: Weight32
  • beta1::float, optional, default=0.899999976: The decay rate for the 1st moment estimates.
  • beta2::float, optional, default=0.999000013: The decay rate for the 2nd moment estimates.
  • epsilon::float, optional, default=9.99999997e-07: A small constant for numerical stability.
  • t::int, required: Index update count.
  • bias_correction::boolean, optional, default=1: Whether to use bias correction.
  • wd::float, required: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).

# MXNet.mx.mp_lamb_update_phase2 (Method)

mp_lamb_update_phase2(weight, g, r1, r2, weight32, lr, lower_bound, upper_bound)

Mixed-precision version of Phase II of the LAMB update. It performs the following operations and updates the weights:

      Link to paper: https://arxiv.org/pdf/1904.00962.pdf

      .. math::
          \begin{gather*}
          if (lower_bound >= 0)
          then
               r1 = max(r1, lower_bound)
          if (upper_bound >= 0)
          then
               r1 = min(r1, upper_bound)

          if (r1 == 0 or r2 == 0)
          then
               lr = lr
          else
               lr = lr * (r1/r2)
          weight32 = weight32 - lr * g
          weight(float16) = weight32
          \end{gather*}
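
A matching plain-Python sketch of Phase II on flat lists. The operator receives r1 and r2 as inputs; the usage below assumes, following the LAMB paper, that r1 is the L2 norm of weight32 and r2 the L2 norm of g. The final cast back to the float16 weight is not modelled, and `lamb_phase2_ref` and `l2_norm` are hypothetical names.

```python
import math


def lamb_phase2_ref(weight32, g, r1, r2, lr, lower_bound=-1.0, upper_bound=-1.0):
    """Element-wise LAMB phase II on plain lists; returns the new float32 master weights."""
    if lower_bound >= 0:
        r1 = max(r1, lower_bound)      # clamp r1 from below
    if upper_bound >= 0:
        r1 = min(r1, upper_bound)      # clamp r1 from above
    # Trust ratio r1 / r2 scales the step; use plain lr if either norm is zero.
    scaled_lr = lr if r1 == 0 or r2 == 0 else lr * (r1 / r2)
    return [w - scaled_lr * gi for w, gi in zip(weight32, g)]


def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))


# Assumption (LAMB paper): r1 = ||weight32||, r2 = ||g||.
weight32, g = [3.0, 4.0], [0.6, 0.8]
new_w32 = lamb_phase2_ref(weight32, g, l2_norm(weight32), l2_norm(g), lr=0.1)
```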

Defined in src/operator/optimizer_op.cc:L1074

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • g::NDArray-or-SymbolicNode: Output of mp_lamb_update_phase1
  • r1::NDArray-or-SymbolicNode: r1
  • r2::NDArray-or-SymbolicNode: r2
  • weight32::NDArray-or-SymbolicNode: Weight32
  • lr::float, required: Learning rate
  • lower_bound::float, optional, default=-1: Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set
  • upper_bound::float, optional, default=-1: Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set

# MXNet.mx.mp_nag_mom_update (Method)

mp_nag_mom_update(weight, grad, mom, weight32, lr, momentum, wd, rescale_grad, clip_gradient)

Update function for the multi-precision Nesterov Accelerated Gradient (NAG) optimizer.

Defined in src/operator/optimizer_op.cc:L744

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • mom::NDArray-or-SymbolicNode: Momentum
  • weight32::NDArray-or-SymbolicNode: Weight32
  • lr::float, required: Learning rate
  • momentum::float, optional, default=0: The decay rate of momentum estimates at each epoch.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).

# MXNet.mx.mp_sgd_mom_update (Method)

mp_sgd_mom_update(weight, grad, mom, weight32, lr, momentum, wd, rescale_grad, clip_gradient, lazy_update)

Update function for the multi-precision SGD optimizer.

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • mom::NDArray-or-SymbolicNode: Momentum
  • weight32::NDArray-or-SymbolicNode: Weight32
  • lr::float, required: Learning rate
  • momentum::float, optional, default=0: The decay rate of momentum estimates at each epoch.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • lazy_update::boolean, optional, default=1: If true, lazy updates are applied if gradient's stype is row_sparse and both weight and momentum have the same stype

# MXNet.mx.mp_sgd_update (Method)

mp_sgd_update(weight, grad, weight32, lr, wd, rescale_grad, clip_gradient, lazy_update)

Update function for the multi-precision SGD optimizer.

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: gradient
  • weight32::NDArray-or-SymbolicNode: Weight32
  • lr::float, required: Learning rate
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • lazy_update::boolean, optional, default=1: If true, lazy updates are applied if gradient's stype is row_sparse.

# MXNet.mx.multi_all_finite (Method)

multi_all_finite(data, num_arrays, init_output)

Checks whether all the floating-point numbers in all the arrays are finite (used for AMP, automatic mixed precision).

Defined in src/operator/contrib/all_finite.cc:L132

Arguments

  • data::NDArray-or-SymbolicNode[]: Arrays
  • num_arrays::int, optional, default='1': Number of arrays.
  • init_output::boolean, optional, default=1: Initialize output to 1.

# MXNet.mx.multi_lars (Method)

multi_lars(lrs, weights_sum_sq, grads_sum_sq, wds, eta, eps, rescale_grad)

Computes the LARS coefficients of multiple weights and gradients from their sums of squares.

Defined in src/operator/contrib/multi_lars.cc:L36

Arguments

  • lrs::NDArray-or-SymbolicNode: Learning rates to scale by LARS coefficient
  • weights_sum_sq::NDArray-or-SymbolicNode: sum of square of weights arrays
  • grads_sum_sq::NDArray-or-SymbolicNode: sum of square of gradients arrays
  • wds::NDArray-or-SymbolicNode: weight decays
  • eta::float, required: LARS eta
  • eps::float, required: LARS eps
  • rescale_grad::float, optional, default=1: Gradient rescaling factor

# MXNet.mx.multi_mp_sgd_mom_update (Method)

multi_mp_sgd_mom_update(data, lrs, wds, momentum, rescale_grad, clip_gradient, num_weights)

Momentum update function for multi-precision Stochastic Gradient Descent (SGD) optimizer.

Momentum update has better convergence rates on neural networks. Mathematically it looks like below:

.. math::

v_1 = -\alpha \nabla J(W_0)\\
v_t = \gamma v_{t-1} - \alpha \nabla J(W_{t-1})\\
W_t = W_{t-1} + v_t

It updates the weights using::

v = momentum * v - learning_rate * gradient
weight += v

where the parameter momentum is the decay rate of momentum estimates at each epoch.
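
The momentum rule above, sketched for a single weight vector on plain Python lists. The handling of rescale_grad, clip_gradient and wd is an assumption based on the argument descriptions (wd is added to the gradient as wd * weight); `sgd_mom_step` is a hypothetical name, and the multi-tensor and mixed-precision bookkeeping of the operator is omitted.

```python
def sgd_mom_step(weight, grad, mom, lr, momentum=0.0, wd=0.0, rescale_grad=1.0, clip_gradient=-1.0):
    """One in-place momentum-SGD step: v = momentum * v - lr * g; weight += v."""
    for i in range(len(weight)):
        g = grad[i] * rescale_grad
        if clip_gradient > 0:                       # clipping is off when clip_gradient <= 0
            g = max(min(g, clip_gradient), -clip_gradient)
        g += wd * weight[i]                         # assumption: weight decay added to the gradient
        mom[i] = momentum * mom[i] - lr * g
        weight[i] += mom[i]


w, v = [1.0], [0.0]
for _ in range(2):
    sgd_mom_step(w, [1.0], v, lr=0.1, momentum=0.9)
```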

Defined in src/operator/optimizer_op.cc:L471

Arguments

  • data::NDArray-or-SymbolicNode[]: Weights
  • lrs::tuple of <float>, required: Learning rates.
  • wds::tuple of <float>, required: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • momentum::float, optional, default=0: The decay rate of momentum estimates at each epoch.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • num_weights::int, optional, default='1': Number of updated weights.

# MXNet.mx.multi_mp_sgd_update (Method)

multi_mp_sgd_update(data, lrs, wds, rescale_grad, clip_gradient, num_weights)

Update function for the multi-precision Stochastic Gradient Descent (SGD) optimizer.

It updates the weights using::

weight = weight - learning_rate * (gradient + wd * weight)
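
A plain-Python sketch of this update for one weight vector, applying rescale_grad and clip_gradient as described in the argument list before the weight-decay term. `sgd_step` is a hypothetical name; the multi-weight and float32 master-copy handling of the operator is not modelled.

```python
def sgd_step(weight, grad, lr, wd=0.0, rescale_grad=1.0, clip_gradient=-1.0):
    """One in-place SGD step: weight = weight - lr * (g + wd * weight)."""
    for i in range(len(weight)):
        g = grad[i] * rescale_grad
        if clip_gradient > 0:                       # clipping is off when clip_gradient <= 0
            g = max(min(g, clip_gradient), -clip_gradient)
        weight[i] -= lr * (g + wd * weight[i])


w = [1.0, -2.0]
sgd_step(w, [0.5, 10.0], lr=0.1, clip_gradient=1.0)  # the second gradient is clipped to 1.0
```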

Defined in src/operator/optimizer_op.cc:L416

Arguments

  • data::NDArray-or-SymbolicNode[]: Weights
  • lrs::tuple of <float>, required: Learning rates.
  • wds::tuple of <float>, required: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • num_weights::int, optional, default='1': Number of updated weights.

# MXNet.mx.multi_sgd_mom_update (Method)

multi_sgd_mom_update(data, lrs, wds, momentum, rescale_grad, clip_gradient, num_weights)

Momentum update function for Stochastic Gradient Descent (SGD) optimizer.

Momentum update has better convergence rates on neural networks. Mathematically it looks like below:

.. math::

v_1 = -\alpha \nabla J(W_0)\\
v_t = \gamma v_{t-1} - \alpha \nabla J(W_{t-1})\\
W_t = W_{t-1} + v_t

It updates the weights using::

v = momentum * v - learning_rate * gradient
weight += v

where the parameter momentum is the decay rate of momentum estimates at each epoch.

Defined in src/operator/optimizer_op.cc:L373

Arguments

  • data::NDArray-or-SymbolicNode[]: Weights, gradients and momentum
  • lrs::tuple of <float>, required: Learning rates.
  • wds::tuple of <float>, required: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • momentum::float, optional, default=0: The decay rate of momentum estimates at each epoch.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • num_weights::int, optional, default='1': Number of updated weights.

# MXNet.mx.multi_sgd_update (Method)

multi_sgd_update(data, lrs, wds, rescale_grad, clip_gradient, num_weights)

Update function for the Stochastic Gradient Descent (SGD) optimizer.

It updates the weights using::

weight = weight - learning_rate * (gradient + wd * weight)

Defined in src/operator/optimizer_op.cc:L328

Arguments

  • data::NDArray-or-SymbolicNode[]: Weights
  • lrs::tuple of <float>, required: Learning rates.
  • wds::tuple of <float>, required: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • num_weights::int, optional, default='1': Number of updated weights.

# MXNet.mx.multi_sum_sq (Method)

multi_sum_sq(data, num_arrays)

Computes the sums of squares of multiple arrays.

Defined in src/operator/contrib/multi_sum_sq.cc:L35

Arguments

  • data::NDArray-or-SymbolicNode[]: Arrays
  • num_arrays::int, required: number of input arrays.

# MXNet.mx.nag_mom_update (Method)

nag_mom_update(weight, grad, mom, lr, momentum, wd, rescale_grad, clip_gradient)

Update function for the Nesterov Accelerated Gradient (NAG) optimizer. It updates the weights using the following formula:

.. math:: v_t = \gamma v_{t-1} + \eta \nabla J(W_{t-1} - \gamma v_{t-1})\\ W_t = W_{t-1} - v_t

where \eta is the learning rate of the optimizer, \gamma is the decay rate of the momentum estimate, v_t is the update vector at time step t, and W_t is the weight vector at time step t.
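
The formula above evaluates the gradient at the look-ahead point W_{t-1} - \gamma v_{t-1}. The operator itself receives grad as an input array, so this scalar sketch is only a literal reading of the formula: it takes a hypothetical callable `grad_fn`, and `nag_step` is likewise an illustrative name.

```python
def nag_step(w, v, grad_fn, lr, momentum):
    """One NAG step on a scalar weight, following the formula literally."""
    lookahead = w - momentum * v                 # W_{t-1} - gamma * v_{t-1}
    v_new = momentum * v + lr * grad_fn(lookahead)
    return w - v_new, v_new                      # W_t = W_{t-1} - v_t


# Minimise J(w) = w**2 / 2, whose gradient is w.
w, v = 1.0, 0.0
w, v = nag_step(w, v, lambda x: x, lr=0.1, momentum=0.9)   # w close to 0.9
w, v = nag_step(w, v, lambda x: x, lr=0.1, momentum=0.9)   # w close to 0.729
```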

Defined in src/operator/optimizer_op.cc:L725

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • mom::NDArray-or-SymbolicNode: Momentum
  • lr::float, required: Learning rate
  • momentum::float, optional, default=0: The decay rate of momentum estimates at each epoch.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).

# MXNet.mx.nanprod (Method)

nanprod(data, axis, keepdims, exclude)

Computes the product of array elements over the given axes, treating NaN (Not a Number) values as one.

Defined in src/operator/tensor/broadcast_reduce_prod_value.cc:L46

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: The axis or axes along which to perform the reduction.

    The default, axis=(), will compute over all elements into a scalar array with shape (1,).

    If axis is int, a reduction is performed on a particular axis.

    If axis is a tuple of ints, a reduction is performed on all the axes specified in the tuple.

    If exclude is true, reduction will be performed on the axes that are NOT in axis instead.

    Negative values mean indexing from right to left.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimensions with size one.
  • exclude::boolean, optional, default=0: Whether to perform reduction on the axes that are NOT in axis instead.

# MXNet.mx.nansum (Method)

nansum(data, axis, keepdims, exclude)

Computes the sum of array elements over the given axes, treating NaN (Not a Number) values as zero.

Defined in src/operator/tensor/broadcast_reduce_sum_value.cc:L101

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: The axis or axes along which to perform the reduction.

    The default, axis=(), will compute over all elements into a scalar array with shape (1,).

    If axis is int, a reduction is performed on a particular axis.

    If axis is a tuple of ints, a reduction is performed on all the axes specified in the tuple.

    If exclude is true, reduction will be performed on the axes that are NOT in axis instead.

    Negative values mean indexing from right to left.
  • keepdims::boolean, optional, default=0: If this is set to True, the reduced axes are left in the result as dimensions with size one.
  • exclude::boolean, optional, default=0: Whether to perform reduction on the axes that are NOT in axis instead.
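
For nanprod and nansum, a plain-Python sketch of the NaN handling over a flat list (NaN counts as 1 in a product and as 0 in a sum), plus a column-wise sum that mirrors axis=0 on a 2-D list. All function names are hypothetical; keepdims and exclude are not modelled.

```python
import math


def nansum_ref(values):
    """Sum over an iterable of floats, counting NaN as 0."""
    return sum(0.0 if math.isnan(v) else v for v in values)


def nanprod_ref(values):
    """Product over an iterable of floats, counting NaN as 1."""
    return math.prod(1.0 if math.isnan(v) else v for v in values)


def nansum_axis0(rows):
    """Column-wise NaN-aware sum of a 2-D list, i.e. axis=0 with keepdims=False."""
    return [nansum_ref(col) for col in zip(*rows)]


nan = float("nan")
print(nansum_ref([1.0, nan, 2.0]))             # 3.0
print(nanprod_ref([2.0, nan, 3.0]))            # 6.0
print(nansum_axis0([[1.0, nan], [2.0, 4.0]]))  # [3.0, 4.0]
```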

source

# MXNet.mx.negativeMethod.

negative(data)

Numerical negative of the argument, element-wise.

The storage type of negative output depends upon the input storage type:

  • negative(default) = default
  • negative(row_sparse) = row_sparse
  • negative(csr) = csr

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.normalMethod.

normal(loc, scale, shape, ctx, dtype)

normal is an alias of _random_normal.

Draw random samples from a normal (Gaussian) distribution.

.. note:: The existing alias normal is deprecated.

Samples are distributed according to a normal distribution parametrized by loc (mean) and scale (standard deviation).

Example::

normal(loc=0, scale=1, shape=(2,2)) = [[ 1.89171135, -1.16881478], [-1.23474145, 1.55807114]]

Defined in src/operator/random/sample_op.cc:L112

Arguments

  • loc::float, optional, default=0: Mean of the distribution.
  • scale::float, optional, default=1: Standard deviation of the distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx.one_hotMethod.

one_hot(indices, depth, on_value, off_value, dtype)

Returns a one-hot array.

The locations represented by indices take value on_value, while all other locations take value off_value.

one_hot operation with indices of shape (i0, i1) and depth of d would result in an output array of shape (i0, i1, d) with::

output[i,j,:] = off_value
output[i,j,indices[i,j]] = on_value

Examples::

one_hot([1,0,2,0], 3) = [[ 0.  1.  0.]
                         [ 1.  0.  0.]
                         [ 0.  0.  1.]
                         [ 1.  0.  0.]]

one_hot([1,0,2,0], 3, on_value=8, off_value=1, dtype='int32') = [[1 8 1]
                                                                 [8 1 1]
                                                                 [1 1 8]
                                                                 [8 1 1]]

one_hot([[1,0],[1,0],[2,0]], 3) = [[[ 0.  1.  0.]
                                    [ 1.  0.  0.]]

                                   [[ 0.  1.  0.]
                                    [ 1.  0.  0.]]

                                   [[ 0.  0.  1.]
                                    [ 1.  0.  0.]]]

Defined in src/operator/tensor/indexing_op.cc:L882

Arguments

  • indices::NDArray-or-SymbolicNode: array of locations where to set on_value
  • depth::int, required: Depth of the one hot dimension.
  • on_value::double, optional, default=1: The value assigned to the locations represented by indices.
  • off_value::double, optional, default=0: The value assigned to the locations not represented by indices.
  • dtype::{'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'},optional, default='float32': DType of the output
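
The output rule above can be expressed as a plain-Python sketch over nested lists, with 0-based indices as in the examples. The recursive helper one_hot shown here is hypothetical and is not the MXNet.jl function. It appends a trailing axis of length depth. For an index outside [0, depth), the sketch yields a row of all off_value entries. This is a choice of the sketch and is not documented operator behaviour.

```python
def one_hot(indices, depth, on_value=1.0, off_value=0.0):
    """Nested-list one-hot encoding that appends a trailing axis of length depth."""
    if isinstance(indices, list):
        # Recurse until scalar indices are reached, preserving the input shape.
        return [one_hot(i, depth, on_value, off_value) for i in indices]
    return [on_value if d == indices else off_value for d in range(depth)]

print(one_hot([1, 0, 2, 0], 3))
# [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
print(one_hot([1, 0, 2, 0], 3, on_value=8, off_value=1))
# [[1, 8, 1], [8, 1, 1], [1, 1, 8], [8, 1, 1]]
```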

source

# MXNet.mx.ones_likeMethod.

ones_like(data)

Return an array of ones with the same shape and type as the input array.

Examples::

x = [[ 0., 0., 0.], [ 0., 0., 0.]]

ones_like(x) = [[ 1., 1., 1.], [ 1., 1., 1.]]

Arguments

  • data::NDArray-or-SymbolicNode: The input

source

# MXNet.mx.padMethod.

pad(data, mode, pad_width, constant_value)

pad is an alias of Pad.

Pads an input array with a constant value or with the edge values of the array.

.. note:: Pad is deprecated. Use pad instead.

.. note:: Current implementation only supports 4D and 5D input arrays with padding applied only on axes 1, 2 and 3. Expects axes 4 and 5 in pad_width to be zero.

This operation pads an input array with either a constant_value or edge values along each axis of the input array. The amount of padding is specified by pad_width.

pad_width is a tuple of integer padding widths for each axis of the format (before_1, after_1, ... , before_N, after_N). The pad_width should be of length 2*N, where N is the number of dimensions of the array.

For dimension N of the input array, before_N and after_N indicate how many values to add before and after the elements of the array along dimension N. The widths of the higher two dimensions before_1, after_1, before_2, after_2 must be 0.

Example::

x = [[[[  1.   2.   3.]
       [  4.   5.   6.]]

      [[  7.   8.   9.]
       [ 10.  11.  12.]]]


     [[[ 11.  12.  13.]
       [ 14.  15.  16.]]

      [[ 17.  18.  19.]
       [ 20.  21.  22.]]]]

pad(x,mode="edge", pad_width=(0,0,0,0,1,1,1,1)) =

     [[[[  1.   1.   2.   3.   3.]
        [  1.   1.   2.   3.   3.]
        [  4.   4.   5.   6.   6.]
        [  4.   4.   5.   6.   6.]]

       [[  7.   7.   8.   9.   9.]
        [  7.   7.   8.   9.   9.]
        [ 10.  10.  11.  12.  12.]
        [ 10.  10.  11.  12.  12.]]]


      [[[ 11.  11.  12.  13.  13.]
        [ 11.  11.  12.  13.  13.]
        [ 14.  14.  15.  16.  16.]
        [ 14.  14.  15.  16.  16.]]

       [[ 17.  17.  18.  19.  19.]
        [ 17.  17.  18.  19.  19.]
        [ 20.  20.  21.  22.  22.]
        [ 20.  20.  21.  22.  22.]]]]

pad(x, mode="constant", constant_value=0, pad_width=(0,0,0,0,1,1,1,1)) =

     [[[[  0.   0.   0.   0.   0.]
        [  0.   1.   2.   3.   0.]
        [  0.   4.   5.   6.   0.]
        [  0.   0.   0.   0.   0.]]

       [[  0.   0.   0.   0.   0.]
        [  0.   7.   8.   9.   0.]
        [  0.  10.  11.  12.   0.]
        [  0.   0.   0.   0.   0.]]]


      [[[  0.   0.   0.   0.   0.]
        [  0.  11.  12.  13.   0.]
        [  0.  14.  15.  16.   0.]
        [  0.   0.   0.   0.   0.]]

       [[  0.   0.   0.   0.   0.]
        [  0.  17.  18.  19.   0.]
        [  0.  20.  21.  22.   0.]
        [  0.   0.   0.   0.   0.]]]]

Defined in src/operator/pad.cc:L765

Arguments

  • data::NDArray-or-SymbolicNode: An n-dimensional input array.
  • mode::{'constant', 'edge', 'reflect'}, required: Padding type to use. "constant" pads with constant_value, "edge" pads using the edge values of the input array, and "reflect" pads by reflecting values with respect to the edges.
  • pad_width::Shape(tuple), required: Widths of the padding regions applied to the edges of each axis. It is a tuple of integer padding widths for each axis of the format (before_1, after_1, ... , before_N, after_N). It should be of length 2*N, where N is the number of dimensions of the array. This is equivalent to pad_width in numpy.pad, but flattened.
  • constant_value::double, optional, default=0: The value used for padding when mode is "constant".
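
The "edge" and "constant" modes can be checked with a plain-Python sketch that pads a 2-D nested list. The two axes of that list play the role of the last two axes in the 4-D example above, so its (1, 1, 1, 1) widths match the last four entries of that pad_width. The helper pad_2d is hypothetical, and the sketch does not implement "reflect".

```python
def pad_2d(rows, pad_width, mode="constant", constant_value=0.0):
    """Pad a 2-D nested list; pad_width = (before_0, after_0, before_1, after_1)."""
    b0, a0, b1, a1 = pad_width
    n_rows, n_cols = len(rows), len(rows[0])

    def source_index(i, n, before):
        # Map an output position to an input position; positions outside
        # the array are clamped to the nearest edge ("edge" mode).
        return min(max(i - before, 0), n - 1)

    out = []
    for i in range(n_rows + b0 + a0):
        row = []
        for j in range(n_cols + b1 + a1):
            inside = b0 <= i < b0 + n_rows and b1 <= j < b1 + n_cols
            if mode == "constant" and not inside:
                row.append(constant_value)
            elif mode in ("constant", "edge"):
                row.append(rows[source_index(i, n_rows, b0)][source_index(j, n_cols, b1)])
            else:
                raise ValueError("this sketch handles only 'constant' and 'edge'")
        out.append(row)
    return out

print(pad_2d([[1, 2, 3], [4, 5, 6]], (1, 1, 1, 1), mode="edge"))
# [[1, 1, 2, 3, 3], [1, 1, 2, 3, 3], [4, 4, 5, 6, 6], [4, 4, 5, 6, 6]]
```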

source

# MXNet.mx.pickMethod.

pick(data, index, axis, keepdims, mode)

Picks elements from an input array according to the input indices along the given axis.

Given an input array of shape (d0, d1) and indices of shape (i0,), the result will be an output array of shape (i0,) with::

output[i] = input[i, indices[i]]

By default, if any index mentioned is too large, it is replaced by the index that addresses the last element along an axis (the clip mode).

This function supports n-dimensional input and (n-1)-dimensional indices arrays.

Examples::

x = [[ 1., 2.], [ 3., 4.], [ 5., 6.]]

// picks elements with specified indices along axis 0
pick(x, y=[0,1], 0) = [ 1.,  4.]

// picks elements with specified indices along axis 1
pick(x, y=[0,1,0], 1) = [ 1.,  4.,  5.]

// picks elements with specified indices along axis 1 using 'wrap' mode
// to place indices that would normally be out of bounds
pick(x, y=[2,-1,-2], 1, mode='wrap') = [ 1.,  4.,  5.]

y = [[ 1.],
     [ 0.],
     [ 2.]]

// picks elements with specified indices along axis 1 and dims are maintained
pick(x, y, 1, keepdims=True) = [[ 2.],
                                [ 3.],
                                [ 6.]]

Defined in src/operator/tensor/broadcast_reduce_op_index.cc:L150

Arguments

  • data::NDArray-or-SymbolicNode: The input array
  • index::NDArray-or-SymbolicNode: The index array
  • axis::int or None, optional, default='-1': int or None. The axis along which to pick the elements. Negative values mean indexing from right to left. If None, the elements in the index w.r.t. the flattened input will be picked.
  • keepdims::boolean, optional, default=0: If true, the axis where we pick the elements is left in the result as dimension with size one.
  • mode::{'clip', 'wrap'},optional, default='clip': Specify how out-of-bound indices behave. Default is "clip". "clip" means clip to the range: any index that is too large is replaced by the index that addresses the last element along an axis. "wrap" means to wrap around.
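
The two index modes can be illustrated with a plain-Python sketch of picking along axis 1 of a 2-D nested list. It uses 0-based indices, as in the examples above. The helper pick_axis1 is hypothetical and is not the MXNet.jl function. Its clip branch also clamps negative indices to 0, which is an assumption of the sketch.

```python
def pick_axis1(x, index, mode="clip", keepdims=False):
    """output[i] = x[i][index[i]] for a 2-D nested list (picking along axis 1)."""
    n = len(x[0])
    out = []
    for row, k in zip(x, index):
        k = int(k)
        if mode == "clip":
            k = min(max(k, 0), n - 1)  # out-of-range indices snap to the nearest end
        elif mode == "wrap":
            k = k % n                  # Python's % maps negatives into [0, n)
        else:
            raise ValueError("mode must be 'clip' or 'wrap'")
        out.append([row[k]] if keepdims else row[k])
    return out

x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(pick_axis1(x, [0, 1, 0]))                  # [1.0, 4.0, 5.0]
print(pick_axis1(x, [2, -1, -2], mode="wrap"))   # [1.0, 4.0, 5.0]
print(pick_axis1(x, [1, 0, 2], keepdims=True))   # [[2.0], [3.0], [6.0]]
```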

source

# MXNet.mx.preloaded_multi_mp_sgd_mom_updateMethod.

preloaded_multi_mp_sgd_mom_update(data, momentum, rescale_grad, clip_gradient, num_weights)

Momentum update function for multi-precision Stochastic Gradient Descent (SGD) optimizer.

Momentum update has better convergence rates on neural networks. Mathematically, it looks like this:

.. math::

v_1 = -\alpha \nabla J(W_0)\\
v_t = \gamma v_{t-1} - \alpha \nabla J(W_{t-1})\\
W_t = W_{t-1} + v_t

It updates the weights using::

v = momentum * v - learning_rate * gradient
weight += v

where the parameter momentum (γ above) is the decay rate of momentum estimates at each epoch and α is the learning rate.

Defined in src/operator/contrib/preloaded_multi_sgd.cc:L199

Arguments

  • data::NDArray-or-SymbolicNode[]: Weights, gradients, momentums, learning rates and weight decays
  • momentum::float, optional, default=0: The decay rate of momentum estimates at each epoch.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • num_weights::int, optional, default='1': Number of updated weights.

source

# MXNet.mx.preloaded_multi_mp_sgd_updateMethod.

preloaded_multi_mp_sgd_update(data, rescale_grad, clip_gradient, num_weights)

Update function for multi-precision Stochastic Gradient Descent (SGD) optimizer.

It updates the weights using::

weight = weight - learning_rate * (gradient + wd * weight)

Defined in src/operator/contrib/preloaded_multi_sgd.cc:L139

Arguments

  • data::NDArray-or-SymbolicNode[]: Weights, gradients, learning rates and weight decays
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • num_weights::int, optional, default='1': Number of updated weights.

source

# MXNet.mx.preloaded_multi_sgd_mom_updateMethod.

preloaded_multi_sgd_mom_update(data, momentum, rescale_grad, clip_gradient, num_weights)

Momentum update function for Stochastic Gradient Descent (SGD) optimizer.

Momentum update has better convergence rates on neural networks. Mathematically, it looks like this:

.. math::

v_1 = -\alpha \nabla J(W_0)\\
v_t = \gamma v_{t-1} - \alpha \nabla J(W_{t-1})\\
W_t = W_{t-1} + v_t

It updates the weights using::

v = momentum * v - learning_rate * gradient
weight += v

where the parameter momentum (γ above) is the decay rate of momentum estimates at each epoch and α is the learning rate.

Defined in src/operator/contrib/preloaded_multi_sgd.cc:L90

Arguments

  • data::NDArray-or-SymbolicNode[]: Weights, gradients, momentum, learning rates and weight decays
  • momentum::float, optional, default=0: The decay rate of momentum estimates at each epoch.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • num_weights::int, optional, default='1': Number of updated weights.
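
One momentum step for a single weight array can be sketched in plain Python from the update rule above. This is not the MXNet.jl API. The function name and the flat-list representation are illustrative assumptions. The formula above omits weight decay, so the sketch assumes wd is added to the clipped gradient as wd * weight; treat that detail as an assumption as well.

```python
def sgd_mom_update(weight, grad, mom, lr, momentum=0.0, wd=0.0,
                   rescale_grad=1.0, clip_gradient=-1.0):
    """One momentum-SGD step on flat lists; returns (new_weight, new_mom)."""
    new_w, new_m = [], []
    for w, g, v in zip(weight, grad, mom):
        g = rescale_grad * g
        if clip_gradient > 0:
            # Clip to [-clip_gradient, clip_gradient]; <= 0 disables clipping.
            g = max(min(g, clip_gradient), -clip_gradient)
        # v = momentum * v - learning_rate * gradient (wd folded in: an assumption)
        v = momentum * v - lr * (g + wd * w)
        new_m.append(v)
        new_w.append(w + v)  # weight += v
    return new_w, new_m

# Two steps starting from zero momentum.
w1, v1 = sgd_mom_update([1.0], [0.5], [0.0], lr=0.1, momentum=0.9)
w2, v2 = sgd_mom_update(w1, [0.5], v1, lr=0.1, momentum=0.9)
```

The function returns new lists instead of updating in place so that each step is easy to inspect. The operator itself updates its weight and momentum buffers.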

source

# MXNet.mx.preloaded_multi_sgd_updateMethod.

preloaded_multi_sgd_update(data, rescale_grad, clip_gradient, num_weights)

Update function for Stochastic Gradient Descent (SGD) optimizer.

It updates the weights using::

weight = weight - learning_rate * (gradient + wd * weight)

Defined in src/operator/contrib/preloaded_multi_sgd.cc:L41

Arguments

  • data::NDArray-or-SymbolicNode[]: Weights, gradients, learning rates and weight decays
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • num_weights::int, optional, default='1': Number of updated weights.
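
The plain SGD rule weight = weight - learning_rate * (gradient + wd * weight) can be sketched in plain Python for one weight array. Gradient rescaling and clipping are applied first, in the order the arguments above describe. The function name and flat-list form are hypothetical, and this is not the MXNet.jl API.

```python
def sgd_update(weight, grad, lr, wd=0.0, rescale_grad=1.0, clip_gradient=-1.0):
    """One plain SGD step on a flat list: weight - lr * (grad + wd * weight)."""
    out = []
    for w, g in zip(weight, grad):
        g = rescale_grad * g  # grad = rescale_grad * grad
        if clip_gradient > 0:
            # Clip to [-clip_gradient, clip_gradient]; <= 0 disables clipping.
            g = max(min(g, clip_gradient), -clip_gradient)
        out.append(w - lr * (g + wd * w))
    return out

new_w = sgd_update([1.0, -2.0], [0.5, 0.5], lr=0.1, wd=0.1)
```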

source

# MXNet.mx.radiansMethod.

radians(data)

Converts each element of the input array from degrees to radians.

.. math:: radians([0, 90, 180, 270, 360]) = [0, \pi/2, \pi, 3\pi/2, 2\pi]

The storage type of radians output depends upon the input storage type:

  • radians(default) = default
  • radians(row_sparse) = row_sparse
  • radians(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_trig.cc:L351

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.random_exponentialMethod.

random_exponential(lam, shape, ctx, dtype)

random_exponential is an alias of _random_exponential.

Draw random samples from an exponential distribution.

Samples are distributed according to an exponential distribution parametrized by lambda (rate).

Example::

exponential(lam=4, shape=(2,2)) = [[ 0.0097189 , 0.08999364], [ 0.04146638, 0.31715935]]

Defined in src/operator/random/sample_op.cc:L136

Arguments

  • lam::float, optional, default=1: Lambda parameter (rate) of the exponential distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
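
The rate parametrization can be illustrated with inverse-transform sampling in plain Python. Inverse-transform sampling is one standard way to draw exponential variates. It is a swapped-in technique for illustration, not a claim about how MXNet generates samples. The helper name is hypothetical.

```python
import math
import random

def sample_exponential(lam, n, rng):
    """Draw n samples from Exponential(rate=lam) by inverse-CDF sampling."""
    # If U ~ Uniform[0, 1), then -log(1 - U) / lam follows Exponential(rate=lam);
    # using 1 - U keeps the argument of log strictly positive.
    return [-math.log(1.0 - rng.random()) / lam for _ in range(n)]

rng = random.Random(0)                  # fixed seed for reproducibility
draws = sample_exponential(4.0, 100_000, rng)
print(sum(draws) / len(draws))          # close to 1 / lam = 0.25
```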

source

# MXNet.mx.random_gammaMethod.

random_gamma(alpha, beta, shape, ctx, dtype)

random_gamma is an alias of _random_gamma.

Draw random samples from a gamma distribution.

Samples are distributed according to a gamma distribution parametrized by alpha (shape) and beta (scale).

Example::

gamma(alpha=9, beta=0.5, shape=(2,2)) = [[ 7.10486984, 3.37695289], [ 3.91697288, 3.65933681]]

Defined in src/operator/random/sample_op.cc:L124

Arguments

  • alpha::float, optional, default=1: Alpha parameter (shape) of the gamma distribution.
  • beta::float, optional, default=1: Beta parameter (scale) of the gamma distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx.random_generalized_negative_binomialMethod.

random_generalized_negative_binomial(mu, alpha, shape, ctx, dtype)

random_generalized_negative_binomial is an alias of _random_generalized_negative_binomial.

Draw random samples from a generalized negative binomial distribution.

Samples are distributed according to a generalized negative binomial distribution parametrized by mu (mean) and alpha (dispersion). alpha is defined as 1/k where k is the failure limit of the number of unsuccessful experiments (generalized to real numbers). Samples will always be returned as a floating point data type.

Example::

generalizednegativebinomial(mu=2.0, alpha=0.3, shape=(2,2)) = [[ 2., 1.], [ 6., 4.]]

Defined in src/operator/random/sample_op.cc:L178

Arguments

  • mu::float, optional, default=1: Mean of the negative binomial distribution.
  • alpha::float, optional, default=1: Alpha (dispersion) parameter of the negative binomial distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
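
The (mu, alpha) parametrization corresponds to a gamma-Poisson mixture: draw a rate from Gamma(shape=1/alpha, scale=mu*alpha), then draw a Poisson count with that rate. The result has mean mu and variance mu + alpha*mu^2. The plain-Python sketch below uses that construction. It is not claimed to be MXNet's sampler. The helper names are hypothetical, and the Knuth Poisson sampler is only suitable for small rates.

```python
import math
import random

def sample_poisson(lam, rng):
    """Knuth's multiplication method; fine for small lam, slow for large lam."""
    limit, k, p = math.exp(-lam), 0, 1.0
    while True:
        p *= rng.random()
        if p <= limit:
            return k
        k += 1

def sample_gen_neg_binomial(mu, alpha, rng):
    """One draw with mean mu and dispersion alpha (= 1 / k), returned as a float."""
    # Gamma(shape=1/alpha, scale=mu*alpha) has mean mu; a Poisson draw with that
    # random rate has variance mu + alpha * mu**2.
    rate = rng.gammavariate(1.0 / alpha, mu * alpha)
    return float(sample_poisson(rate, rng))

rng = random.Random(42)
draws = [sample_gen_neg_binomial(2.0, 0.3, rng) for _ in range(20_000)]
mean = sum(draws) / len(draws)
var = sum((d - mean) ** 2 for d in draws) / (len(draws) - 1)
print(mean, var)  # expected near 2.0 and 2.0 + 0.3 * 2.0**2 = 3.2
```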

source

# MXNet.mx.random_negative_binomialMethod.

random_negative_binomial(k, p, shape, ctx, dtype)

random_negative_binomial is an alias of _random_negative_binomial.

Draw random samples from a negative binomial distribution.

Samples are distributed according to a negative binomial distribution parametrized by k (limit of unsuccessful experiments) and p (failure probability in each experiment). Samples will always be returned as a floating point data type.

Example::

negative_binomial(k=3, p=0.4, shape=(2,2)) = [[ 4., 7.], [ 2., 5.]]

Defined in src/operator/random/sample_op.cc:L163

Arguments

  • k::int, optional, default='1': Limit of unsuccessful experiments.
  • p::float, optional, default=1: Failure probability in each experiment.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx.random_normalMethod.

random_normal(loc, scale, shape, ctx, dtype)

random_normal is an alias of _random_normal.

Draw random samples from a normal (Gaussian) distribution.

.. note:: The existing alias normal is deprecated.

Samples are distributed according to a normal distribution parametrized by loc (mean) and scale (standard deviation).

Example::

normal(loc=0, scale=1, shape=(2,2)) = [[ 1.89171135, -1.16881478], [-1.23474145, 1.55807114]]

Defined in src/operator/random/sample_op.cc:L112

Arguments

  • loc::float, optional, default=0: Mean of the distribution.
  • scale::float, optional, default=1: Standard deviation of the distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx.random_pdf_dirichletMethod.

random_pdf_dirichlet(sample, alpha, is_log)

random_pdf_dirichlet is an alias of _random_pdf_dirichlet.

Computes the value of the PDF of samples of Dirichlet distributions with parameter alpha.

The shape of alpha must match the leftmost subshape of sample. That is, sample can have the same shape as alpha, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the value of alpha at index i.

Examples::

random_pdf_dirichlet(sample=[[1,2],[2,3],[3,4]], alpha=[2.5, 2.5]) =
    [38.413498, 199.60245, 564.56085]

sample = [[[1, 2, 3], [10, 20, 30], [100, 200, 300]],
          [[0.1, 0.2, 0.3], [0.01, 0.02, 0.03], [0.001, 0.002, 0.003]]]

random_pdf_dirichlet(sample=sample, alpha=[0.1, 0.4, 0.9]) =
    [[2.3257459e-02, 5.8420084e-04, 1.4674458e-05],
     [9.2589635e-01, 3.6860607e+01, 1.4674468e+03]]

Defined in src/operator/random/pdf_op.cc:L315

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • alpha::NDArray-or-SymbolicNode: Concentration parameters of the distributions.
  • is_log::boolean, optional, default=0: If set, return the logarithm of the density instead of the density itself.
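
The density can be sketched in plain Python from the standard Dirichlet formula, computed in log space with math.lgamma. Evaluating the formula directly reproduces the values of the first example above, even though rows such as [1, 2] do not lie on the probability simplex. The helper name is hypothetical, and the sketch handles one sample vector at a time.

```python
import math

def dirichlet_pdf(x, alpha, is_log=False):
    """Gamma(sum a) / prod Gamma(a_i) * prod x_i**(a_i - 1), evaluated in log space."""
    log_norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    log_pdf = log_norm + sum((a - 1.0) * math.log(xi) for xi, a in zip(x, alpha))
    return log_pdf if is_log else math.exp(log_pdf)

# One density per row of the first example (alpha shared by every row).
print([dirichlet_pdf(row, [2.5, 2.5]) for row in [[1, 2], [2, 3], [3, 4]]])
```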

source

# MXNet.mx.random_pdf_exponentialMethod.

random_pdf_exponential(sample, lam, is_log)

random_pdf_exponential is an alias of _random_pdf_exponential.

Computes the value of the PDF of samples of exponential distributions with parameters lam (rate).

The shape of lam must match the leftmost subshape of sample. That is, sample can have the same shape as lam, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the value of lam at index i.

Examples::

random_pdf_exponential(sample=[[1, 2, 3]], lam=[1]) =
    [[0.36787945, 0.13533528, 0.04978707]]

sample = [[1,2,3],
          [1,2,3],
          [1,2,3]]

random_pdf_exponential(sample=sample, lam=[1,0.5,0.25]) =
    [[0.36787945, 0.13533528, 0.04978707],
     [0.30326533, 0.18393973, 0.11156508],
     [0.1947002,  0.15163267, 0.11809164]]

Defined in src/operator/random/pdf_op.cc:L304

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • lam::NDArray-or-SymbolicNode: Lambda (rate) parameters of the distributions.
  • is_log::boolean, optional, default=0: If set, return the logarithm of the density instead of the density itself.
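
The rate-parametrized density lam * exp(-lam * x) can be written as a plain-Python sketch that reproduces the examples above. Each row of the second example uses its own lam for every sample in that row. The helper name is hypothetical.

```python
import math

def exponential_pdf(x, lam, is_log=False):
    """Density of Exponential(rate=lam) at x >= 0, i.e. lam * exp(-lam * x)."""
    log_pdf = math.log(lam) - lam * x
    return log_pdf if is_log else math.exp(log_pdf)

lams = [1.0, 0.5, 0.25]
print([[exponential_pdf(x, lam) for x in [1, 2, 3]] for lam in lams])
```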

source

# MXNet.mx.random_pdf_gammaMethod.

random_pdf_gamma(sample, alpha, is_log, beta)

random_pdf_gamma is an alias of _random_pdf_gamma.

Computes the value of the PDF of samples of gamma distributions with parameters alpha (shape) and beta (rate).

alpha and beta must have the same shape, which must match the leftmost subshape of sample. That is, sample can have the same shape as alpha and beta, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the values of alpha and beta at index i.

Examples::

random_pdf_gamma(sample=[[1,2,3,4,5]], alpha=[5], beta=[1]) =
    [[0.01532831, 0.09022352, 0.16803136, 0.19536681, 0.17546739]]

sample = [[1, 2, 3, 4, 5],
          [2, 3, 4, 5, 6],
          [3, 4, 5, 6, 7]]

random_pdf_gamma(sample=sample, alpha=[5,6,7], beta=[1,1,1]) =
    [[0.01532831, 0.09022352, 0.16803136, 0.19536681, 0.17546739],
     [0.03608941, 0.10081882, 0.15629345, 0.17546739, 0.16062315],
     [0.05040941, 0.10419563, 0.14622283, 0.16062315, 0.14900276]]

Defined in src/operator/random/pdf_op.cc:L302

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • alpha::NDArray-or-SymbolicNode: Alpha (shape) parameters of the distributions.
  • is_log::boolean, optional, default=0: If set, return the logarithm of the density instead of the density itself.
  • beta::NDArray-or-SymbolicNode: Beta (scale) parameters of the distributions.

source

# MXNet.mx.random_pdf_generalized_negative_binomialMethod.

random_pdf_generalized_negative_binomial(sample, mu, is_log, alpha)

random_pdf_generalized_negative_binomial is an alias of _random_pdf_generalized_negative_binomial.

Computes the value of the PDF of samples of generalized negative binomial distributions with parameters mu (mean) and alpha (dispersion). This can be understood as a reparameterization of the negative binomial, where k = 1 / alpha and p = 1 / (mu * alpha + 1).

mu and alpha must have the same shape, which must match the leftmost subshape of sample. That is, sample can have the same shape as mu and alpha, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the values of mu and alpha at index i.

Examples::

random_pdf_generalized_negative_binomial(sample=[[1, 2, 3, 4]], alpha=[1], mu=[1]) =
    [[0.25, 0.125, 0.0625, 0.03125]]

sample = [[1,2,3,4],
          [1,2,3,4]]
random_pdf_generalized_negative_binomial(sample=sample, alpha=[1, 0.6666], mu=[1, 1.5]) =
    [[0.25,       0.125,      0.0625,     0.03125   ],
     [0.26517063, 0.16573331, 0.09667706, 0.05437994]]

Defined in src/operator/random/pdf_op.cc:L313

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • mu::NDArray-or-SymbolicNode: Means of the distributions.
  • is_log::boolean, optional, default=0: If set, return the logarithm of the density instead of the density itself.
  • alpha::NDArray-or-SymbolicNode: Alpha (dispersion) parameters of the distributions.
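
The reparameterization k = 1 / alpha, p = 1 / (mu * alpha + 1) can be checked with a plain-Python sketch. The sketch uses the negative binomial mass function Gamma(x + k) / (Gamma(k) * x!) * p^k * (1 - p)^x, with real-valued k allowed. That mass function reproduces the example values for both this operator and random_pdf_negative_binomial. Both helper names are hypothetical.

```python
import math

def neg_binomial_pmf(x, k, p):
    """Gamma(x + k) / (Gamma(k) * x!) * p**k * (1 - p)**x; k may be real-valued."""
    log_pmf = (math.lgamma(x + k) - math.lgamma(k) - math.lgamma(x + 1)
               + k * math.log(p) + x * math.log1p(-p))
    return math.exp(log_pmf)

def gen_neg_binomial_pmf(x, mu, alpha):
    """Generalized form via the reparameterization k = 1/alpha, p = 1/(mu*alpha + 1)."""
    return neg_binomial_pmf(x, 1.0 / alpha, 1.0 / (mu * alpha + 1.0))

print([gen_neg_binomial_pmf(x, 1.0, 1.0) for x in [1, 2, 3, 4]])
```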

source

# MXNet.mx.random_pdf_negative_binomialMethod.

random_pdf_negative_binomial(sample, k, is_log, p)

random_pdf_negative_binomial is an alias of _random_pdf_negative_binomial.

Computes the value of the PDF of samples of negative binomial distributions with parameters k (failure limit) and p (failure probability).

k and p must have the same shape, which must match the leftmost subshape of sample. That is, sample can have the same shape as k and p, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the values of k and p at index i.

Examples::

random_pdf_negative_binomial(sample=[[1,2,3,4]], k=[1], p=[0.5]) =
    [[0.25, 0.125, 0.0625, 0.03125]]

# Note that k may be real-valued
sample = [[1,2,3,4],
          [1,2,3,4]]
random_pdf_negative_binomial(sample=sample, k=[1, 1.5], p=[0.5, 0.5]) =
    [[0.25,       0.125,      0.0625,     0.03125   ],
     [0.26516506, 0.16572815, 0.09667476, 0.05437956]]

Defined in src/operator/random/pdf_op.cc:L309

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • k::NDArray-or-SymbolicNode: Limits of unsuccessful experiments.
  • is_log::boolean, optional, default=0: If set, return the logarithm of the density instead of the density itself.
  • p::NDArray-or-SymbolicNode: Failure probabilities in each experiment.

source

# MXNet.mx.random_pdf_normalMethod.

random_pdf_normal(sample, mu, is_log, sigma)

random_pdf_normal is an alias of _random_pdf_normal.

Computes the value of the PDF of samples of normal distributions with parameters mu (mean) and sigma (standard deviation).

mu and sigma must have the same shape, which must match the leftmost subshape of sample. That is, sample can have the same shape as mu and sigma, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the values of mu and sigma at index i.

Examples::

sample = [[-2, -1, 0, 1, 2]]
random_pdf_normal(sample=sample, mu=[0], sigma=[1]) =
    [[0.05399097, 0.24197073, 0.3989423, 0.24197073, 0.05399097]]

random_pdf_normal(sample=sample*2, mu=[0,0], sigma=[1,2]) =
    [[0.05399097, 0.24197073, 0.3989423,  0.24197073, 0.05399097],
     [0.12098537, 0.17603266, 0.19947115, 0.17603266, 0.12098537]]

Defined in src/operator/random/pdf_op.cc:L299

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • mu::NDArray-or-SymbolicNode: Means of the distributions.
  • is_log::boolean, optional, default=0: If set, return the logarithm of the density instead of the density itself.
  • sigma::NDArray-or-SymbolicNode: Standard deviations of the distributions.
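
The normal density with mean mu and standard deviation sigma can be written as a plain-Python sketch that reproduces the example values above. The helper name is hypothetical.

```python
import math

def normal_pdf(x, mu, sigma, is_log=False):
    """Density of a normal distribution with mean mu and standard deviation sigma."""
    z = (x - mu) / sigma
    log_pdf = -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
    return log_pdf if is_log else math.exp(log_pdf)

print([normal_pdf(x, 0.0, 2.0) for x in [-2, -1, 0, 1, 2]])  # second row of the example
```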

source

# MXNet.mx.random_pdf_poissonMethod.

random_pdf_poisson(sample, lam, is_log)

random_pdf_poisson is an alias of _random_pdf_poisson.

Computes the value of the PDF of samples of Poisson distributions with parameters lam (rate).

The shape of lam must match the leftmost subshape of sample. That is, sample can have the same shape as lam, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the value of lam at index i.

Examples::

random_pdf_poisson(sample=[[0,1,2,3]], lam=[1]) =
    [[0.36787945, 0.36787945, 0.18393973, 0.06131324]]

sample = [[0,1,2,3],
          [0,1,2,3],
          [0,1,2,3]]

random_pdf_poisson(sample=sample, lam=[1,2,3]) =
    [[0.36787945, 0.36787945, 0.18393973, 0.06131324],
     [0.13533528, 0.27067056, 0.27067056, 0.18044704],
     [0.04978707, 0.14936121, 0.22404182, 0.22404182]]

Defined in src/operator/random/pdf_op.cc:L306

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • lam::NDArray-or-SymbolicNode: Lambda (rate) parameters of the distributions.
  • is_log::boolean, optional, default=0: If set, return the logarithm of the density instead of the density itself.
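
The Poisson mass function lam^x * exp(-lam) / x! can be written as a plain-Python sketch. It computes in log space with math.lgamma and reproduces the rows of the example above. The helper name is hypothetical.

```python
import math

def poisson_pmf(x, lam, is_log=False):
    """P(X = x) = lam**x * exp(-lam) / x! for a non-negative integer x."""
    log_pmf = x * math.log(lam) - lam - math.lgamma(x + 1)
    return log_pmf if is_log else math.exp(log_pmf)

print([[poisson_pmf(x, lam) for x in range(4)] for lam in [1.0, 2.0, 3.0]])
```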

source

# MXNet.mx.random_pdf_uniformMethod.

random_pdf_uniform(sample, low, is_log, high)

random_pdf_uniform is an alias of _random_pdf_uniform.

Computes the value of the PDF of samples of uniform distributions on the intervals given by [low,high).

low and high must have the same shape, which must match the leftmost subshape of sample. That is, sample can have the same shape as low and high, in which case the output contains one density per distribution, or sample can be a tensor of tensors with that shape, in which case the output is a tensor of densities such that the densities at index i in the output are given by the samples at index i in sample parameterized by the values of low and high at index i.

Examples::

random_pdf_uniform(sample=[[1,2,3,4]], low=[0], high=[10]) = [[0.1, 0.1, 0.1, 0.1]]

sample = [[[1, 2, 3],
           [1, 2, 3]],
          [[1, 2, 3],
           [1, 2, 3]]]
low  = [[0, 0],
        [0, 0]]
high = [[ 5, 10],
        [15, 20]]
random_pdf_uniform(sample=sample, low=low, high=high) =
    [[[0.2,        0.2,        0.2    ],
      [0.1,        0.1,        0.1    ]],
     [[0.06667,    0.06667,    0.06667],
      [0.05,       0.05,       0.05   ]]]

Defined in src/operator/random/pdf_op.cc:L297

Arguments

  • sample::NDArray-or-SymbolicNode: Samples from the distributions.
  • low::NDArray-or-SymbolicNode: Lower bounds of the distributions.
  • is_log::boolean, optional, default=0: If set, return the logarithm of the density instead of the density itself.
  • high::NDArray-or-SymbolicNode: Upper bounds of the distributions.

source

# MXNet.mx.random_poissonMethod.

random_poisson(lam, shape, ctx, dtype)

random_poisson is an alias of _random_poisson.

Draw random samples from a Poisson distribution.

Samples are distributed according to a Poisson distribution parametrized by lambda (rate). Samples will always be returned as a floating point data type.

Example::

poisson(lam=4, shape=(2,2)) = [[ 5., 2.], [ 4., 6.]]

Defined in src/operator/random/sample_op.cc:L149

Arguments

  • lam::float, optional, default=1: Lambda parameter (rate) of the Poisson distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx.random_randintMethod.

random_randint(low, high, shape, ctx, dtype)

random_randint is an alias of _random_randint.

Draw random samples from a discrete uniform distribution.

Samples are uniformly distributed over the half-open interval [low, high) (includes low, but excludes high).

Example::

randint(low=0, high=5, shape=(2,2)) = [[ 0, 2], [ 3, 1]]

Defined in src/operator/random/sample_op.cc:L193

Arguments

  • low::long, required: Lower bound of the distribution.
  • high::long, required: Upper bound of the distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'int32', 'int64'}, optional, default='None': DType of the output in case this can't be inferred. Defaults to int32 if not defined (dtype=None).
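Julia ranges are inclusive at both ends, so the half-open interval [low, high) described above corresponds to the Julia range low:(high - 1). The sketch below only uses the Random standard library; randint_sketch is a hypothetical name for illustration, not an MXNet.jl function.

```julia
using Random

# Hypothetical helper, not an MXNet.jl function: draws integers uniformly from the
# half-open interval [low, high), which is the inclusive Julia range low:(high - 1).
function randint_sketch(rng::AbstractRNG, low::Integer, high::Integer, dims::Integer...)
    low < high || throw(ArgumentError("need low < high"))
    return rand(rng, low:(high - 1), dims...)
end

rng = MersenneTwister(0)
samples = randint_sketch(rng, 0, 5, 2, 2)   # 2×2 matrix with entries in 0:4
```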

source

# MXNet.mx.random_uniformMethod.

random_uniform(low, high, shape, ctx, dtype)

random_uniform is an alias of _random_uniform.

Draw random samples from a uniform distribution.

.. note:: The existing alias ``uniform`` is deprecated.

Samples are uniformly distributed over the half-open interval [low, high) (includes low, but excludes high).

Example::

uniform(low=0, high=1, shape=(2,2)) = [[ 0.60276335, 0.85794562], [ 0.54488319, 0.84725171]]

Defined in src/operator/random/sample_op.cc:L95

Arguments

  • low::float, optional, default=0: Lower bound of the distribution.
  • high::float, optional, default=1: Upper bound of the distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'}, optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx.ravel_multi_indexMethod.

ravel_multi_index(data, shape)

ravel_multi_index is an alias of _ravel_multi_index.

Converts a batch of index arrays into an array of flat indices. The operator follows NumPy conventions, so a single multi-index is given by a column of the input matrix. The leading dimension may be left unspecified by using -1 as a placeholder.

Examples::

A = [[3,6,6],[4,5,1]]
ravel(A, shape=(7,6)) = [22,41,37]
ravel(A, shape=(-1,6)) = [22,41,37]

Defined in src/operator/tensor/ravel.cc:L41

Arguments

  • data::NDArray-or-SymbolicNode: Batch of multi-indices
  • shape::Shape(tuple), optional, default=None: Shape of the array into which the multi-indices apply.
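The flat index is computed in row-major (C) order with 0-based indices, which reproduces the example above. The plain-Julia sketch below (ravel_multi_index_sketch is a hypothetical name, not the MXNet.jl API) accumulates it in Horner form; because the size of the leading dimension is only ever multiplied by zero, it also shows why -1 works as a placeholder for that dimension.

```julia
# Hypothetical reference, not the MXNet.jl API: each column of `A` is one
# 0-based multi-index; the result is its 0-based row-major (C-order) flat index.
function ravel_multi_index_sketch(A::AbstractMatrix{<:Integer}, shape::Tuple)
    size(A, 1) == length(shape) || throw(DimensionMismatch("one row per dimension"))
    flat = zeros(Int, size(A, 2))
    for j in axes(A, 2)
        acc = 0
        for k in eachindex(shape)
            # Horner form of sum(idx[k] * stride[k]); shape[1] only ever multiplies 0,
            # so the leading dimension may be given as -1.
            acc = acc * shape[k] + A[k, j]
        end
        flat[j] = acc
    end
    return flat
end

A = [3 6 6;
     4 5 1]
ravel_multi_index_sketch(A, (7, 6))    # [22, 41, 37]
ravel_multi_index_sketch(A, (-1, 6))   # [22, 41, 37]
```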

source

# MXNet.mx.rcbrtMethod.

rcbrt(data)

Returns element-wise inverse cube-root value of the input.

.. math:: rcbrt(x) = 1/\sqrt[3]{x}

Example::

rcbrt([1,8,-125]) = [1.0, 0.5, -0.2]

Defined in src/operator/tensor/elemwise_unary_op_pow.cc:L323

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
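The example above includes a negative input. A plain-Julia sketch of the same element-wise computation should use cbrt, which is defined for negative reals, rather than x^(1/3), which throws a DomainError for a negative Float64 base. rcbrt_ref is a hypothetical name used only for illustration.

```julia
# Hypothetical element-wise reference for rcbrt (not the MXNet operator).
# cbrt handles negative inputs; (-125.0)^(1/3) would throw a DomainError.
rcbrt_ref(x::AbstractArray{<:Real}) = 1 ./ cbrt.(x)

rcbrt_ref([1.0, 8.0, -125.0])   # ≈ [1.0, 0.5, -0.2]
```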

source

# MXNet.mx.reciprocalMethod.

reciprocal(data)

Returns the reciprocal of the argument, element-wise.

Calculates 1/x.

Example::

reciprocal([-2, 1, 3, 1.6, 0.2]) = [-0.5, 1.0, 0.33333334, 0.625, 5.0]

Defined in src/operator/tensor/elemwise_unary_op_pow.cc:L43

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.reset_arraysMethod.

reset_arrays(data, num_arrays)

Sets multiple arrays to zero.

Defined in src/operator/contrib/reset_arrays.cc:L35

Arguments

  • data::NDArray-or-SymbolicNode[]: Arrays
  • num_arrays::int, required: number of input arrays.

source

# MXNet.mx.reshape_likeMethod.

reshape_like(lhs, rhs, lhs_begin, lhs_end, rhs_begin, rhs_end)

Reshape some or all dimensions of lhs to have the same shape as some or all dimensions of rhs.

Returns a view of the lhs array with a new shape without altering any data.

Example::

x = [1, 2, 3, 4, 5, 6]
y = [[0, -4], [3, 2], [2, 2]]
reshape_like(x, y) = [[1, 2], [3, 4], [5, 6]]

More precise control over how dimensions are inherited is achieved by specifying slices over the lhs and rhs array dimensions. Only the sliced lhs dimensions are reshaped to the rhs sliced dimensions, with the non-sliced lhs dimensions staying the same.

Examples::

  • lhs shape = (30,7), rhs shape = (15,2,4), lhs_begin=0, lhs_end=1, rhs_begin=0, rhs_end=2, output shape = (15,2,7)
  • lhs shape = (3, 5), rhs shape = (1,15,4), lhs_begin=0, lhs_end=2, rhs_begin=1, rhs_end=2, output shape = (15)

Negative indices are supported, and None can be used for either lhs_end or rhs_end to indicate the end of the range.

Example::

  • lhs shape = (30, 12), rhs shape = (4, 2, 2, 3), lhs_begin=-1, lhs_end=None, rhs_begin=1, rhs_end=None, output shape = (30, 2, 2, 3)

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L511

Arguments

  • lhs::NDArray-or-SymbolicNode: First input.
  • rhs::NDArray-or-SymbolicNode: Second input.
  • lhs_begin::int or None, optional, default='None': Defaults to 0. The beginning index along which the lhs dimensions are to be reshaped. Supports negative indices.
  • lhs_end::int or None, optional, default='None': Defaults to None. The ending index along which the lhs dimensions are to be used for reshaping. Supports negative indices.
  • rhs_begin::int or None, optional, default='None': Defaults to 0. The beginning index along which the rhs dimensions are to be used for reshaping. Supports negative indices.
  • rhs_end::int or None, optional, default='None': Defaults to None. The ending index along which the rhs dimensions are to be used for reshaping. Supports negative indices.
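The slice-based shape rule can be expressed as a small function. The plain-Julia sketch below (resolve and reshape_like_shape are hypothetical names, not MXNet.jl functions) works on shape tuples exactly as written in the examples above, with 0-based, end-exclusive ranges, negative indices counted from the end, and nothing standing in for None.

```julia
# Hypothetical helpers: the output shape of reshape_like for shape tuples written
# as in the examples (0-based, end-exclusive ranges; `nothing` means None).
resolve(i, n, default) = i === nothing ? default : (i < 0 ? i + n : i)

function reshape_like_shape(lhs::Tuple, rhs::Tuple; lhs_begin=0, lhs_end=nothing,
                            rhs_begin=0, rhs_end=nothing)
    lb, le = resolve(lhs_begin, length(lhs), 0), resolve(lhs_end, length(lhs), length(lhs))
    rb, re = resolve(rhs_begin, length(rhs), 0), resolve(rhs_end, length(rhs), length(rhs))
    middle = rhs[rb+1:re]              # rhs dimensions that replace the lhs slice
    # The replaced lhs dimensions must hold the same number of elements.
    foldl(*, lhs[lb+1:le]; init=1) == foldl(*, middle; init=1) ||
        throw(DimensionMismatch("sliced dimensions have different sizes"))
    return (lhs[1:lb]..., middle..., lhs[le+1:end]...)
end

reshape_like_shape((30, 7), (15, 2, 4); lhs_end=1, rhs_end=2)               # (15, 2, 7)
reshape_like_shape((3, 5), (1, 15, 4); lhs_end=2, rhs_begin=1, rhs_end=2)   # (15,)
reshape_like_shape((30, 12), (4, 2, 2, 3); lhs_begin=-1, rhs_begin=1)       # (30, 2, 2, 3)
```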

source

# MXNet.mx.rintMethod.

rint(data)

Rounds each element of the input to the nearest integer.

.. note::

  • For input ``n.5``, ``rint`` returns ``n`` while ``round`` returns ``n+1``.
  • For input ``-n.5``, both ``rint`` and ``round`` return ``-n-1``.

Example::

rint([-1.5, 1.5, -1.9, 1.9, 2.1]) = [-2., 1., -2., 2., 2.]

The storage type of ``rint`` output depends upon the input storage type:

  • rint(default) = default
  • rint(row_sparse) = row_sparse
  • rint(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L798

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
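The tie rule in the note above (an exact half goes to the lower neighbour for rint, while round moves halves away from zero) can be sketched in plain Julia. rint_ref and round_ref are hypothetical names. Julia's own round(x) rounds ties to even by default, so the away-from-zero behaviour described for round corresponds to round(x, RoundNearestTiesAway).

```julia
# Hypothetical reference for the tie rule in the note: an exact half goes to the
# lower neighbour (floor), otherwise the nearest integer is returned.
function rint_ref(a::AbstractFloat)
    f, c = floor(a), ceil(a)
    return (a - f) <= (c - a) ? f : c
end

# The behaviour described for round (halves away from zero) in Julia terms:
round_ref(a::AbstractFloat) = round(a, RoundNearestTiesAway)

rint_ref.([-1.5, 1.5, -1.9, 1.9, 2.1])   # [-2.0, 1.0, -2.0, 2.0, 2.0]
round_ref.([-1.5, 1.5])                   # [-2.0, 2.0]
```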

source

# MXNet.mx.rmsprop_updateMethod.

rmsprop_update(weight, grad, n, lr, gamma1, epsilon, wd, rescale_grad, clip_gradient, clip_weights)

Update function for RMSProp optimizer.

RMSprop is a variant of stochastic gradient descent where the gradients are divided by a cache which grows with the sum of squares of recent gradients.

RMSProp is similar to AdaGrad, a popular variant of SGD which adaptively tunes the learning rate of each parameter. AdaGrad lowers the learning rate for each parameter monotonically over the course of training. While this is analytically motivated for convex optimization, it may not be ideal for non-convex problems. RMSProp deals with this heuristically by allowing the learning rates to rebound as the denominator decays over time.

Define the Root Mean Square (RMS) error criterion of the gradient as :math:`RMS[g]_t = \sqrt{E[g^2]_t + \epsilon}`, where :math:`g` represents the gradient and :math:`E[g^2]_t` is the decaying average of past squared gradients.

The term :math:`E[g^2]_t` is given by:

.. math:: E[g^2]_t = \gamma E[g^2]_{t-1} + (1 - \gamma) g_t^2

The update step is

.. math:: \theta_{t+1} = \theta_t - \frac{\eta}{RMS[g]_t} g_t

The RMSProp code follows the version in http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf by Tieleman & Hinton, 2012.

Hinton suggests setting the momentum term :math:`\gamma` to 0.9 and the learning rate :math:`\eta` to 0.001.

Defined in src/operator/optimizer_op.cc:L796

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • n::NDArray-or-SymbolicNode: n
  • lr::float, required: Learning rate
  • gamma1::float, optional, default=0.949999988: The decay rate of momentum estimates.
  • epsilon::float, optional, default=9.99999994e-09: A small constant for numerical stability.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • clip_weights::float, optional, default=-1: Clip weights to the range of [-clip_weights, clip_weights]. If clip_weights <= 0, weight clipping is turned off. weights = max(min(weights, clip_weights), -clip_weights).
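The two equations above amount to the following per-element step. This is a plain-Julia sketch, not MXNet.jl's implementation: rmsprop_step! is a hypothetical name, it rescales and then clips the gradient as described in the argument list (that order is an assumption), and it leaves out wd and clip_weights because the order in which the operator applies them is not stated here.

```julia
# Hypothetical sketch of one RMSProp step following the equations above.
# `n` holds E[g^2] and is updated in place; wd and clip_weights are not modelled.
function rmsprop_step!(weight, grad, n; lr, gamma1=0.95, epsilon=1e-8,
                       rescale_grad=1.0, clip_gradient=-1.0)
    g = rescale_grad .* grad
    if clip_gradient > 0                      # clip_gradient <= 0 disables clipping
        g = clamp.(g, -clip_gradient, clip_gradient)
    end
    @. n = gamma1 * n + (1 - gamma1) * g^2    # E[g^2]_t
    @. weight -= lr * g / sqrt(n + epsilon)   # theta_{t+1} = theta_t - eta / RMS[g]_t * g_t
    return weight
end

w, n = [1.0], [0.0]
rmsprop_step!(w, [0.5], n; lr=0.1, gamma1=0.9)
```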

source

# MXNet.mx.rmspropalex_updateMethod.

rmspropalex_update(weight, grad, n, g, delta, lr, gamma1, gamma2, epsilon, wd, rescale_grad, clip_gradient, clip_weights)

Update function for RMSPropAlex optimizer.

RMSPropAlex is the centered version of RMSProp.

Let :math:`E[g^2]_t` be the decaying average of past squared gradients and :math:`E[g]_t` the decaying average of past gradients:

.. math::

  E[g^2]_t = \gamma_1 E[g^2]_{t-1} + (1 - \gamma_1) g_t^2 \\
  E[g]_t = \gamma_1 E[g]_{t-1} + (1 - \gamma_1) g_t \\
  \Delta_t = \gamma_2 \Delta_{t-1} - \frac{\eta}{\sqrt{E[g^2]_t - E[g]_t^2 + \epsilon}} g_t

The update step is

.. math:: \theta_{t+1} = \theta_t + \Delta_t

The RMSPropAlex code follows the version in http://arxiv.org/pdf/1308.0850v5.pdf, Eq. (38) to Eq. (45), by Alex Graves, 2013.

Graves suggests setting the momentum term :math:`\gamma_1` to 0.95, :math:`\gamma_2` to 0.9 and the learning rate :math:`\eta` to 0.0001.

Defined in src/operator/optimizer_op.cc:L835

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • n::NDArray-or-SymbolicNode: n
  • g::NDArray-or-SymbolicNode: g
  • delta::NDArray-or-SymbolicNode: delta
  • lr::float, required: Learning rate
  • gamma1::float, optional, default=0.949999988: Decay rate.
  • gamma2::float, optional, default=0.899999976: Decay rate.
  • epsilon::float, optional, default=9.99999994e-09: A small constant for numerical stability.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • clip_weights::float, optional, default=-1: Clip weights to the range of [-clip_weights, clip_weights]. If clip_weights <= 0, weight clipping is turned off. weights = max(min(weights, clip_weights), -clip_weights).
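Following the three update equations and the final step above, one RMSPropAlex step can be sketched per element as below. rmspropalex_step! is a hypothetical name; the state arrays n, g and delta are assumed to play the roles of E[g^2], E[g] and Delta, and rescaling, clipping and wd are left out.

```julia
# Hypothetical sketch of one RMSPropAlex step following the equations above.
# n, g and delta are assumed to hold E[g^2], E[g] and Delta; all update in place.
function rmspropalex_step!(weight, grad, n, g, delta; lr, gamma1=0.95, gamma2=0.9,
                           epsilon=1e-8)
    @. n = gamma1 * n + (1 - gamma1) * grad^2                          # E[g^2]_t
    @. g = gamma1 * g + (1 - gamma1) * grad                            # E[g]_t
    @. delta = gamma2 * delta - lr * grad / sqrt(n - g^2 + epsilon)    # Delta_t
    @. weight += delta                                                 # theta_{t+1}
    return weight
end

w = [1.0]
rmspropalex_step!(w, [1.0], [0.0], [0.0], [0.0]; lr=0.1, gamma1=0.5, gamma2=0.5)
```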

source

# MXNet.mx.rsqrtMethod.

rsqrt(data)

Returns element-wise inverse square-root value of the input.

.. math:: rsqrt(x) = 1/\sqrt{x}

Example::

rsqrt([4,9,16]) = [0.5, 0.33333334, 0.25]

The storage type of ``rsqrt`` output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_pow.cc:L221

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.sample_exponentialMethod.

sample_exponential(lam, shape, dtype)

sample_exponential is an alias of _sample_exponential.

Concurrent sampling from multiple exponential distributions with parameters lambda (rate).

The parameters of the distributions are provided as an input array. Let [s] be the shape of the input array, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input array, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input value at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input array.

Examples::

lam = [ 1.0, 8.5 ]

// Draw a single sample for each distribution
sample_exponential(lam) = [ 0.51837951, 0.09994757]

// Draw a vector containing two samples for each distribution
sample_exponential(lam, shape=(2)) = [[ 0.51837951, 0.19866663], [ 0.09994757, 0.50447971]]

Defined in src/operator/random/multisample_op.cc:L283

Arguments

  • lam::NDArray-or-SymbolicNode: Lambda (rate) parameters of the distributions.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'}, optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
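The [s]x[t] shape rule shared by the sample_* operators can be illustrated in plain Julia for the exponential case. sample_exponential_sketch is a hypothetical name, not MXNet.jl's implementation; it relies on randexp from the Random standard library, which draws from Exp(1), and divides by the rate to obtain draws with rate lam[i].

```julia
using Random

# Hypothetical sketch, not the MXNet.jl operator: for every index i of `lam`,
# fill out[i, ...] with prod(t) draws from an exponential distribution of rate lam[i].
function sample_exponential_sketch(rng::AbstractRNG, lam::AbstractArray{<:Real},
                                   t::Tuple=())
    out = Array{Float64}(undef, size(lam)..., t...)   # shape [s]x[t]
    rest = ntuple(_ -> Colon(), length(t))
    for i in CartesianIndices(lam)
        # randexp gives Exp(1) draws; dividing by the rate gives rate lam[i].
        out[i, rest...] = randexp(rng, t...) ./ lam[i]
    end
    return out
end

rng = MersenneTwister(0)
lam = [1.0, 8.5]
one_each = sample_exponential_sketch(rng, lam)         # size (2,)
two_each = sample_exponential_sketch(rng, lam, (2,))   # size (2, 2)
```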

source

# MXNet.mx.sample_gammaMethod.

sample_gamma(alpha, shape, dtype, beta)

sample_gamma is an alias of _sample_gamma.

Concurrent sampling from multiple gamma distributions with parameters alpha (shape) and beta (scale).

The parameters of the distributions are provided as input arrays. Let [s] be the shape of the input arrays, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input arrays, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input values at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input arrays.

Examples::

alpha = [ 0.0, 2.5 ]
beta = [ 1.0, 0.7 ]

// Draw a single sample for each distribution
sample_gamma(alpha, beta) = [ 0. , 2.25797319]

// Draw a vector containing two samples for each distribution
sample_gamma(alpha, beta, shape=(2)) = [[ 0. , 0. ], [ 2.25797319, 1.70734084]]

Defined in src/operator/random/multisample_op.cc:L281

Arguments

  • alpha::NDArray-or-SymbolicNode: Alpha (shape) parameters of the distributions.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'}, optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
  • beta::NDArray-or-SymbolicNode: Beta (scale) parameters of the distributions.

source

# MXNet.mx.sample_generalized_negative_binomialMethod.

sample_generalized_negative_binomial(mu, shape, dtype, alpha)

sample_generalized_negative_binomial is an alias of _sample_generalized_negative_binomial.

Concurrent sampling from multiple generalized negative binomial distributions with parameters mu (mean) and alpha (dispersion).

The parameters of the distributions are provided as input arrays. Let [s] be the shape of the input arrays, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input arrays, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input values at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input arrays.

Samples will always be returned as a floating point data type.

Examples::

mu = [ 2.0, 2.5 ]
alpha = [ 1.0, 0.1 ]

// Draw a single sample for each distribution
sample_generalized_negative_binomial(mu, alpha) = [ 0., 3.]

// Draw a vector containing two samples for each distribution
sample_generalized_negative_binomial(mu, alpha, shape=(2)) = [[ 0., 3.], [ 3., 1.]]

Defined in src/operator/random/multisample_op.cc:L292

Arguments

  • mu::NDArray-or-SymbolicNode: Means of the distributions.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'}, optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
  • alpha::NDArray-or-SymbolicNode: Alpha (dispersion) parameters of the distributions.

source

# MXNet.mx.sample_multinomialMethod.

sample_multinomial(data, shape, get_prob, dtype)

sample_multinomial is an alias of _sample_multinomial.

Concurrent sampling from multiple multinomial distributions.

data is an n-dimensional array whose last dimension has length k, where k is the number of possible outcomes of each multinomial distribution. This operator will draw ``shape`` samples from each distribution. If shape is empty, one sample will be drawn from each distribution.

If get_prob is true, a second array containing the log likelihood of the drawn samples will also be returned. This is usually used for reinforcement learning, where you can provide the reward as the head gradient for this array to estimate the gradient.

Note that the input distribution must be normalized, i.e. data must sum to 1 along its last axis.

Examples::

probs = [[0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0]]

// Draw a single sample for each distribution
sample_multinomial(probs) = [3, 0]

// Draw a vector containing two samples for each distribution
sample_multinomial(probs, shape=(2)) = [[4, 2], [0, 0]]

// Request the log likelihood as well
sample_multinomial(probs, get_prob=True) = [2, 1], [log(0.2), log(0.3)]

Arguments

  • data::NDArray-or-SymbolicNode: Distribution probabilities. Must sum to one on the last axis.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • get_prob::boolean, optional, default=0: Whether to also return the log probability of sampled result. This is usually used for differentiating through stochastic variables, e.g. in reinforcement learning.
  • dtype::{'float16', 'float32', 'float64', 'int32', 'uint8'}, optional, default='int32': DType of the output in case this can't be inferred.
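For a single draw per distribution, the behaviour described above, including the log-probabilities returned when get_prob is true, can be sketched with inverse-CDF sampling over each row. This is a plain-Julia illustration, not MXNet.jl's operator: sample_multinomial_sketch is a hypothetical name, each row of probs is one distribution, and outcomes are reported as 0-based indices to match the examples.

```julia
using Random

# Hypothetical sketch: one draw per row of `probs` (each row sums to 1) by
# inverse-CDF sampling. Returns 0-based outcome indices and their log-probabilities.
function sample_multinomial_sketch(rng::AbstractRNG, probs::AbstractMatrix{<:Real})
    nrows, k = size(probs)
    idx = zeros(Int, nrows)
    logp = zeros(Float64, nrows)
    for r in 1:nrows
        u = rand(rng)                              # uniform in [0, 1)
        cdf = cumsum(probs[r, :])
        j = something(findfirst(>(u), cdf), k)     # fall back to the last outcome on rounding
        idx[r] = j - 1                             # 0-based, as in the examples
        logp[r] = log(probs[r, j])
    end
    return idx, logp
end

probs = [0.0 0.1 0.2 0.3 0.4;
         0.4 0.3 0.2 0.1 0.0]
idx, logp = sample_multinomial_sketch(MersenneTwister(0), probs)
```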

source

# MXNet.mx.sample_negative_binomialMethod.

sample_negative_binomial(k, shape, dtype, p)

sample_negative_binomial is an alias of _sample_negative_binomial.

Concurrent sampling from multiple negative binomial distributions with parameters k (failure limit) and p (failure probability).

The parameters of the distributions are provided as input arrays. Let [s] be the shape of the input arrays, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input arrays, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input values at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input arrays.

Samples will always be returned as a floating point data type.

Examples::

k = [ 20, 49 ]
p = [ 0.4 , 0.77 ]

// Draw a single sample for each distribution
sample_negative_binomial(k, p) = [ 15., 16.]

// Draw a vector containing two samples for each distribution
sample_negative_binomial(k, p, shape=(2)) = [[ 15., 50.], [ 16., 12.]]

Defined in src/operator/random/multisample_op.cc:L288

Arguments

  • k::NDArray-or-SymbolicNode: Limits of unsuccessful experiments.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'}, optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
  • p::NDArray-or-SymbolicNode: Failure probabilities in each experiment.

source

# MXNet.mx.sample_normalMethod.

sample_normal(mu, shape, dtype, sigma)

sample_normal is an alias of _sample_normal.

Concurrent sampling from multiple normal distributions with parameters mu (mean) and sigma (standard deviation).

The parameters of the distributions are provided as input arrays. Let [s] be the shape of the input arrays, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input arrays, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input values at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input arrays.

Examples::

mu = [ 0.0, 2.5 ]
sigma = [ 1.0, 3.7 ]

// Draw a single sample for each distribution
sample_normal(mu, sigma) = [-0.56410581, 0.95934606]

// Draw a vector containing two samples for each distribution
sample_normal(mu, sigma, shape=(2)) = [[-0.56410581, 0.2928229 ], [ 0.95934606, 4.48287058]]

Defined in src/operator/random/multisample_op.cc:L278

Arguments

  • mu::NDArray-or-SymbolicNode: Means of the distributions.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'}, optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
  • sigma::NDArray-or-SymbolicNode: Standard deviations of the distributions.

source

# MXNet.mx.sample_poissonMethod.

sample_poisson(lam, shape, dtype)

sample_poisson is an alias of _sample_poisson.

Concurrent sampling from multiple Poisson distributions with parameters lambda (rate).

The parameters of the distributions are provided as an input array. Let [s] be the shape of the input array, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input array, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input value at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input array.

Samples will always be returned as a floating point data type.

Examples::

lam = [ 1.0, 8.5 ]

// Draw a single sample for each distribution
sample_poisson(lam) = [ 0., 13.]

// Draw a vector containing two samples for each distribution
sample_poisson(lam, shape=(2)) = [[ 0., 4.], [ 13., 8.]]

Defined in src/operator/random/multisample_op.cc:L285

Arguments

  • lam::NDArray-or-SymbolicNode: Lambda (rate) parameters of the distributions.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'}, optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx.sample_uniformMethod.

sample_uniform(low, shape, dtype, high)

sample_uniform is an alias of _sample_uniform.

Concurrent sampling from multiple uniform distributions on the intervals given by [low, high).

The parameters of the distributions are provided as input arrays. Let [s] be the shape of the input arrays, n be the dimension of [s], [t] be the shape specified as the parameter of the operator, and m be the dimension of [t]. Then the output will be a (n+m)-dimensional array with shape [s]x[t].

For any valid n-dimensional index i with respect to the input arrays, output[i] will be an m-dimensional array that holds randomly drawn samples from the distribution which is parameterized by the input values at index i. If the shape parameter of the operator is not set, then one sample will be drawn per distribution and the output array has the same shape as the input arrays.

Examples::

low = [ 0.0, 2.5 ]
high = [ 1.0, 3.7 ]

// Draw a single sample for each distribution
sample_uniform(low, high) = [ 0.40451524, 3.18687344]

// Draw a vector containing two samples for each distribution
sample_uniform(low, high, shape=(2)) = [[ 0.40451524, 0.18017688], [ 3.18687344, 3.68352246]]

Defined in src/operator/random/multisample_op.cc:L276

Arguments

  • low::NDArray-or-SymbolicNode: Lower bounds of the distributions.
  • shape::Shape(tuple), optional, default=[]: Shape to be sampled from each random distribution.
  • dtype::{'None', 'float16', 'float32', 'float64'}, optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).
  • high::NDArray-or-SymbolicNode: Upper bounds of the distributions.

source

# MXNet.mx.scatter_ndMethod.

scatter_nd(data, indices, shape)

Scatters data into a new tensor according to indices.

Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with shape (M, Y_0, ..., Y_{K-1}), the output will have shape (X_0, X_1, ..., X_{N-1}), where M <= N. If M == N, data shape should simply be (Y_0, ..., Y_{K-1}).

The elements in output are defined as follows::

output[indices[0, y_0, ..., y_{K-1}], ..., indices[M-1, y_0, ..., y_{K-1}], x_M, ..., x_{N-1}] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]

All other entries in output are 0.

.. warning::

If the indices have duplicates, the result will be non-deterministic and
the gradient of `scatter_nd` will not be correct.

Examples::

data = [2, 3, 0]
indices = [[1, 1, 0], [0, 1, 0]]
shape = (2, 2)
scatter_nd(data, indices, shape) = [[0, 0], [2, 3]]

data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
indices = [[0, 1], [1, 1]]
shape = (2, 2, 2, 2)
scatter_nd(data, indices, shape) = [[[[0, 0], [0, 0]], [[1, 2], [3, 4]]],
                                    [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]]

Arguments

  • data::NDArray-or-SymbolicNode: data
  • indices::NDArray-or-SymbolicNode: indices
  • shape::Shape(tuple), required: Shape of output.
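For the case K = 1, where indices is an M × Y_0 matrix as in both examples above, the defining equation can be sketched in plain Julia. scatter_nd_sketch is a hypothetical name, not the MXNet.jl operator; the 0-based indices from the examples are converted to Julia's 1-based indexing, and the nested lists are written as Julia arrays indexed in the same order.

```julia
# Hypothetical sketch for K = 1 (indices is an M×Y0 matrix of 0-based indices):
# out[indices[:, y]..., :, ...] = data[y, :, ...]; every other entry stays 0.
# With duplicate index columns the last write wins here, whereas the operator's
# result is documented as non-deterministic.
function scatter_nd_sketch(data::AbstractArray, indices::AbstractMatrix{<:Integer},
                           shape::Tuple)
    M, Y0 = size(indices)
    out = zeros(eltype(data), shape...)
    out_rest = ntuple(_ -> Colon(), length(shape) - M)    # the X_M ... X_{N-1} axes
    data_rest = ntuple(_ -> Colon(), ndims(data) - 1)
    for y in 1:Y0
        target = Tuple(indices[:, y] .+ 1)                 # 0-based -> 1-based
        out[target..., out_rest...] = data[y, data_rest...]
    end
    return out
end

scatter_nd_sketch([2, 3, 0], [1 1 0; 0 1 0], (2, 2))     # [0 0; 2 3]

# data[y, r, c] mirrors [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
data = [4 * (y - 1) + 2 * (r - 1) + c for y in 1:2, r in 1:2, c in 1:2]
out = scatter_nd_sketch(data, [0 1; 1 1], (2, 2, 2, 2))
out[1, 2, :, :]    # [1 2; 3 4]
out[2, 2, :, :]    # [5 6; 7 8]
```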

source

# MXNet.mx.sgd_mom_updateMethod.

sgd_mom_update(weight, grad, mom, lr, momentum, wd, rescale_grad, clip_gradient, lazy_update)

Momentum update function for Stochastic Gradient Descent (SGD) optimizer.

Momentum update has better convergence rates on neural networks. Mathematically, it looks like this:

.. math::

v_1 = -\alpha \nabla J(W_0) \\
v_t = \gamma v_{t-1} - \alpha \nabla J(W_{t-1}) \\
W_t = W_{t-1} + v_t

It updates the weights using::

v = momentum * v - learning_rate * gradient
weight += v

Here, the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage type is the same as momentum's storage type, only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::

for row in gradient.indices:
    v[row] = momentum * v[row] - learning_rate * gradient[row]
    weight[row] += v[row]

Defined in src/operator/optimizer_op.cc:L564

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • mom::NDArray-or-SymbolicNode: Momentum
  • lr::float, required: Learning rate
  • momentum::float, optional, default=0: The decay rate of momentum estimates at each epoch.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • lazy_update::boolean, optional, default=1: If true, lazy updates are applied if gradient's stype is row_sparse and both weight and momentum have the same stype.
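The dense rule and the lazy row_sparse rule above can be sketched in plain Julia. Both function names are hypothetical; weight decay, gradient rescaling and clipping are omitted, and the row_sparse gradient is modelled as a list of 1-based row numbers plus a matrix holding those gradient rows, an assumption made only for illustration.

```julia
# Hypothetical sketches of the momentum update (wd, rescaling and clipping omitted).
function sgd_mom_step!(weight, grad, mom; lr, momentum=0.0)
    @. mom = momentum * mom - lr * grad    # v = momentum * v - learning_rate * gradient
    @. weight += mom                       # weight += v
    return weight
end

# Lazy variant: only the listed rows of weight and mom are touched. `rows` holds
# 1-based row numbers and grad_rows[k, :] is the gradient for rows[k].
function sgd_mom_lazy_step!(weight::AbstractMatrix, mom::AbstractMatrix, rows,
                            grad_rows::AbstractMatrix; lr, momentum=0.0)
    for (k, r) in enumerate(rows)
        @views mom[r, :] .= momentum .* mom[r, :] .- lr .* grad_rows[k, :]
        @views weight[r, :] .+= mom[r, :]
    end
    return weight
end

w, m = ones(3, 2), zeros(3, 2)
sgd_mom_lazy_step!(w, m, [2], [1.0 1.0]; lr=0.1, momentum=0.9)   # only row 2 changes
```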

source

# MXNet.mx.sgd_updateMethod.

sgd_update(weight, grad, lr, wd, rescale_grad, clip_gradient, lazy_update)

Update function for Stochastic Gradient Descent (SGD) optimizer.

It updates the weights using::

weight = weight - learning_rate * (gradient + wd * weight)

However, if gradient is of ``row_sparse`` storage type and ``lazy_update`` is True, only the row slices whose indices appear in grad.indices are updated::

for row in gradient.indices:
    weight[row] = weight[row] - learning_rate * (gradient[row] + wd * weight[row])

Defined in src/operator/optimizer_op.cc:L523

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • lr::float, required: Learning rate
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • lazy_update::boolean, optional, default=1: If true, lazy updates are applied if gradient's stype is row_sparse.

source

# MXNet.mx.shape_arrayMethod.

shape_array(data)

Returns a 1D int64 array containing the shape of data.

Example::

shape_array([[1,2,3,4], [5,6,7,8]]) = [2,4]

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L573

Arguments

  • data::NDArray-or-SymbolicNode: Input Array.

source

# MXNet.mx.signsgd_updateMethod.

signsgd_update(weight, grad, lr, wd, rescale_grad, clip_gradient)

Update function for SignSGD optimizer.

.. math::

g_t = \nabla J(W_{t-1}) \\
W_t = W_{t-1} - \eta_t \text{sign}(g_t)

It updates the weights using::

weight = weight - learning_rate * sign(gradient)

.. note::

  • Sparse NDArray is not supported by this optimizer yet.

Defined in src/operator/optimizer_op.cc:L62

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • lr::float, required: Learning rate
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).

source

# MXNet.mx.signum_updateMethod.

signum_update(weight, grad, mom, lr, momentum, wd, rescale_grad, clip_gradient, wd_lh)

SIGN momentUM (Signum) optimizer.

.. math::

g_t = \nabla J(W_{t-1}) \\
m_t = \beta m_{t-1} + (1 - \beta) g_t \\
W_t = W_{t-1} - \eta_t \text{sign}(m_t)

It updates the weights using::

state = momentum * state + (1-momentum) * gradient
weight = weight - learning_rate * sign(state)

Here, the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

.. note::

  • Sparse NDArray is not supported by this optimizer yet.

Defined in src/operator/optimizer_op.cc:L91

Arguments

  • weight::NDArray-or-SymbolicNode: Weight
  • grad::NDArray-or-SymbolicNode: Gradient
  • mom::NDArray-or-SymbolicNode: Momentum
  • lr::float, required: Learning rate
  • momentum::float, optional, default=0: The decay rate of momentum estimates at each epoch.
  • wd::float, optional, default=0: Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
  • rescale_grad::float, optional, default=1: Rescale gradient to grad = rescale_grad*grad.
  • clip_gradient::float, optional, default=-1: Clip gradient to the range of [-clip_gradient, clip_gradient]. If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
  • wd_lh::float, optional, default=0: The amount of weight decay that does not go into gradient/momentum calculations. Otherwise, weight decay is applied algorithmically only.
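Following the update rule above, one Signum step can be sketched element-wise in plain Julia. signum_step! is a hypothetical name, and wd, wd_lh, rescaling and clipping are left out. With momentum = 0 the state equals the current gradient, so the step reduces to the SignSGD rule of the previous entry.

```julia
# Hypothetical sketch of one Signum step (wd, wd_lh, rescaling and clipping omitted).
function signum_step!(weight, grad, state; lr, momentum=0.0)
    @. state = momentum * state + (1 - momentum) * grad   # m_t
    @. weight -= lr * sign(state)                          # W_t = W_{t-1} - eta * sign(m_t)
    return weight
end

w, s = [1.0, 1.0], [0.0, 0.0]
signum_step!(w, [0.3, -2.0], s; lr=0.01, momentum=0.9)   # ≈ [0.99, 1.01]
```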

source

# MXNet.mx.size_arrayMethod.

size_array(data)

Returns a 1D int64 array containing the size of data.

Example::

size_array([[1,2,3,4], [5,6,7,8]]) = [8]

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L624

Arguments

  • data::NDArray-or-SymbolicNode: Input Array.

source

# MXNet.mx.sliceMethod.

slice(data, begin, end, step)

Slices a region of the array.

.. note:: `crop` is deprecated. Use `slice` instead.

This function returns a sliced array between the indices given by `begin` and `end` with the corresponding `step`.

For an input array of `shape=(d_0, d_1, ..., d_n-1)`, a slice operation with `begin=(b_0, b_1, ..., b_m-1)`, `end=(e_0, e_1, ..., e_m-1)`, and `step=(s_0, s_1, ..., s_m-1)`, where m <= n, results in an array with the shape `(|e_0-b_0|/|s_0|, ..., |e_m-1-b_m-1|/|s_m-1|, d_m, ..., d_n-1)`.

The resulting array's k-th dimension contains elements from the k-th dimension of the input array starting from index `b_k` (inclusive) with step `s_k` until reaching `e_k` (exclusive).

If the k-th elements of `begin`, `end`, or `step` are None, the following rules set the default values:

  • If s_k is None, set s_k=1.
  • If s_k > 0, set b_k=0, e_k=d_k; else, set b_k=d_k-1, e_k=-1.

The storage type of `slice` output depends on storage types of inputs:

  • slice(csr) = csr
  • otherwise, `slice` generates output with default storage

.. note:: When input data storage type is csr, it only supports step=(), or step=(None,), or step=(1,) to generate a csr output. For other step parameter values, it falls back to slicing a dense tensor.

Example::

x = [[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 9., 10., 11., 12.]]

slice(x, begin=(0,1), end=(2,4)) = [[ 2., 3., 4.], [ 6., 7., 8.]]

slice(x, begin=(None, 0), end=(None, 3), step=(-1, 2)) = [[9., 11.], [5., 7.], [1., 3.]]

Defined in src/operator/tensor/matrix_op.cc:L481

Arguments

  • data::NDArray-or-SymbolicNode: Source input
  • begin::Shape(tuple), required: starting indices for the slice operation, supports negative indices.
  • end::Shape(tuple), required: ending indices for the slice operation, supports negative indices.
  • step::Shape(tuple), optional, default=[]: step for the slice operation, supports negative values.
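
The default-filling rules for `begin`, `end` and `step` can be checked with a plain-Python sketch on nested lists (not MXNet.jl code). `slice_ref` and `_axis_indices` are hypothetical helpers. As an assumption, negative user indices are normalized by adding the axis length before the defaults are applied, and no bounds checking is done.

```python
def _axis_indices(dim, b, e, s):
    """Indices picked along one axis, following the default rules above."""
    s = 1 if s is None else s
    # Normalize user-supplied negative indices first.
    if b is not None and b < 0:
        b += dim
    if e is not None and e < 0:
        e += dim
    if b is None:
        b = 0 if s > 0 else dim - 1
    if e is None:
        e = dim if s > 0 else -1   # -1 here is a sentinel: "stop after index 0"
    return range(b, e, s)


def slice_ref(x, begin, end, step=()):
    """Slice a nested list; axes beyond len(begin) are kept whole."""
    m = len(begin)
    step = tuple(step) + (None,) * (m - len(step))

    def rec(arr, k):
        if k == m:
            return arr
        idx = _axis_indices(len(arr), begin[k], end[k], step[k])
        return [rec(arr[i], k + 1) for i in idx]

    return rec(x, 0)


x = [[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]]
print(slice_ref(x, (None, 0), (None, 3), (-1, 2)))  # [[9.0, 11.0], [5.0, 7.0], [1.0, 3.0]]
```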

source

# MXNet.mx.slice_axisMethod.

slice_axis(data, axis, begin, end)

Slices along a given axis.

Returns an array slice along a given axis starting from the begin index to the end index.

Examples::

x = [[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 9., 10., 11., 12.]]

slice_axis(x, axis=0, begin=1, end=3) = [[ 5., 6., 7., 8.], [ 9., 10., 11., 12.]]

slice_axis(x, axis=1, begin=0, end=2) = [[ 1., 2.], [ 5., 6.], [ 9., 10.]]

slice_axis(x, axis=1, begin=-3, end=-1) = [[ 2., 3.], [ 6., 7.], [ 10., 11.]]

Defined in src/operator/tensor/matrix_op.cc:L570

Arguments

  • data::NDArray-or-SymbolicNode: Source input
  • axis::int, required: Axis along which to be sliced, supports negative indexes.
  • begin::int, required: The beginning index along the axis to be sliced, supports negative indexes.
  • end::int or None, required: The ending index along the axis to be sliced, supports negative indexes.
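
On a single axis, the behavior in the examples matches ordinary Python slicing. The plain-Python sketch below (not MXNet.jl code; `slice_axis_ref` is a hypothetical name) applies it at the chosen depth of a nested list.

```python
def _ndim(x):
    """Rank of a rectangular nested list."""
    n = 0
    while isinstance(x, list):
        x, n = x[0], n + 1
    return n


def slice_axis_ref(x, axis, begin, end):
    if axis < 0:
        axis += _ndim(x)

    def rec(arr, depth):
        if depth == axis:
            # end=None slices to the end; negative bounds count from the back
            return arr[begin:end]
        return [rec(sub, depth + 1) for sub in arr]

    return rec(x, 0)
```

Python slicing silently clamps out-of-range bounds. The real operator may reject such inputs, so treat this only as a model of the in-range behavior.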

source

# MXNet.mx.slice_likeMethod.

slice_like(data, shape_like, axes)

Slices a region of the array like the shape of another array.

This function is similar to `slice`; however, `begin` is always all zeros and the `end` of specific axes is inferred from the second input `shape_like`.

Given the second input `shape_like` of `shape=(d_0, d_1, ..., d_n-1)`, a `slice_like` operator with default empty `axes` performs the following operation:

`out = slice(input, begin=(0, 0, ..., 0), end=(d_0, d_1, ..., d_n-1))`.

When `axes` is not empty, it specifies which axes are sliced. Given a 4-d input data, a `slice_like` operator with `axes=(0, 2, -1)` performs the following operation:

`out = slice(input, begin=(0, 0, 0, 0), end=(d_0, None, d_2, d_3))`.

The first and second inputs may have different numbers of dimensions; however, you have to make sure the axes are specified and do not exceed the dimension limits.

For example, given `a` with `shape=(2,3,4,5)` and `b` with `shape=(1,2,3)`, it is not allowed to use `out = slice_like(a, b)`, because the ndim of `a` is 4 and the ndim of `b` is 3.

The following is allowed in this situation: `out = slice_like(a, b, axes=(0, 2))`

Example::

x = [[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 9., 10., 11., 12.]]

y = [[ 0., 0., 0.], [ 0., 0., 0.]]

slice_like(x, y) = [[ 1., 2., 3.], [ 5., 6., 7.]]

slice_like(x, y, axes=(0, 1)) = [[ 1., 2., 3.], [ 5., 6., 7.]]

slice_like(x, y, axes=(0)) = [[ 1., 2., 3., 4.], [ 5., 6., 7., 8.]]

slice_like(x, y, axes=(-1)) = [[ 1., 2., 3.], [ 5., 6., 7.], [ 9., 10., 11.]]

Defined in src/operator/tensor/matrix_op.cc:L624

Arguments

  • data::NDArray-or-SymbolicNode: Source input
  • shape_like::NDArray-or-SymbolicNode: Shape like input
  • axes::Shape(tuple), optional, default=[]: List of axes on which input data will be sliced according to the corresponding size of the second input. By default will slice on all axes. Negative axes are supported.
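
Here is a plain-Python sketch of how the `end` values are taken from `shape_like` (not MXNet.jl code; `slice_like_ref` is hypothetical). It assumes a negative axis is normalized against the rank of `x`. The same index is then used to read the target size from `like`.

```python
def _shape(x):
    s = []
    while isinstance(x, list):
        s.append(len(x))
        x = x[0]
    return tuple(s)


def slice_like_ref(x, like, axes=()):
    xs, ls = _shape(x), _shape(like)
    if not axes:
        if len(xs) != len(ls):
            raise ValueError("axes must be given when the ranks differ")
        axes = range(len(xs))
    # Assumption: negative axes are normalized against x's rank, and the
    # same index then selects the target size from `like`.
    ends = {}
    for a in axes:
        a = a + len(xs) if a < 0 else a
        ends[a] = ls[a]

    def rec(arr, depth):
        if depth == len(xs):
            return arr                     # reached a scalar element
        stop = ends.get(depth)             # None keeps the whole axis
        return [rec(sub, depth + 1) for sub in arr[:stop]]

    return rec(x, 0)
```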

source

# MXNet.mx.smooth_l1Method.

smooth_l1(data, scalar)

Calculates the Smooth L1 loss element-wise:

.. math::

f(x) =
\begin{cases}
(\sigma x)^2/2,& \text{if }|x| < 1/\sigma^2\\
|x|-0.5/\sigma^2,& \text{otherwise}
\end{cases}

where x is an element of the input tensor `data` and σ is `scalar`.

Example::

smooth_l1([1, 2, 3, 4]) = [0.5, 1.5, 2.5, 3.5]

smooth_l1([1, 2, 3, 4], scalar=1) = [0.5, 1.5, 2.5, 3.5]

Defined in src/operator/tensor/elemwise_binary_scalar_op_extended.cc:L108

Arguments

  • data::NDArray-or-SymbolicNode: source input
  • scalar::float: scalar input
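
The piecewise formula can be sketched in plain Python (not MXNet.jl code). `smooth_l1_ref` is a hypothetical helper with σ = `scalar`, and |x| is compared against 1/σ². Its default `scalar=1.0` is an assumption that mirrors the first example.

```python
def smooth_l1_ref(data, scalar=1.0):
    """Element-wise Smooth L1 on a flat list, with sigma = scalar."""
    inv_s2 = 1.0 / (scalar * scalar)      # threshold 1 / sigma^2
    out = []
    for x in data:
        if abs(x) < inv_s2:
            out.append((scalar * x) ** 2 / 2)   # quadratic zone near 0
        else:
            out.append(abs(x) - 0.5 * inv_s2)   # linear zone
    return out


print(smooth_l1_ref([1, 2, 3, 4]))  # [0.5, 1.5, 2.5, 3.5]
```

The two branches meet at |x| = 1/σ², where both evaluate to 0.5/σ². The loss is therefore continuous across the switch.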

source

# MXNet.mx.softmax_cross_entropyMethod.

softmax_cross_entropy(data, label)

Calculate cross entropy of softmax output and one-hot label.

  • This operator computes the cross entropy in two steps:

    • Applies the softmax function to the input array.
    • Computes and returns the cross entropy loss between the softmax output and the labels.

  • The softmax function and cross entropy loss are given by:

    • Softmax function:

    .. math:: \text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    • Cross entropy function:

    .. math:: \text{CE}(\text{label}, \text{output}) = - \sum_i \text{label}_i \log(\text{output}_i)

Example::

x = [[1, 2, 3], [11, 7, 5]]

label = [2, 0]

softmax(x) = [[0.09003057, 0.24472848, 0.66524094], [0.97962922, 0.01794253, 0.00242826]]

softmax_cross_entropy(x, label) = - log(0.66524094) - log(0.97962922) = 0.4281871

Defined in src/operator/loss_binary_op.cc:L58

Arguments

  • data::NDArray-or-SymbolicNode: Input data
  • label::NDArray-or-SymbolicNode: Input label
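
The two steps can be reproduced in plain Python to check the worked example (not MXNet.jl code; both function names are hypothetical). As in the example, the per-sample losses are summed over the batch. Labels are class indices such as `[2, 0]`, which corresponds to the one-hot form of the CE formula.

```python
import math


def softmax(row):
    m = max(row)                          # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_cross_entropy_ref(data, label):
    """Sum over the batch of -log(softmax(row)[label])."""
    return sum(-math.log(softmax(row)[int(k)]) for row, k in zip(data, label))


print(softmax_cross_entropy_ref([[1, 2, 3], [11, 7, 5]], [2, 0]))  # about 0.4281871
```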

source

# MXNet.mx.softminMethod.

softmin(data, axis, temperature, dtype, use_length)

Applies the softmin function.

The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1.

.. math:: softmin(\mathbf{z}/t)_j = \frac{e^{-z_j/t}}{\sum_{k=1}^K e^{-z_k/t}}

for j = 1, ..., K

t is the temperature parameter in the softmax function. By default, t equals 1.0.

Example::

x = [[ 1., 2., 3.], [ 3., 2., 1.]]

softmin(x,axis=0) = [[ 0.88079703, 0.5, 0.11920292], [ 0.11920292, 0.5, 0.88079703]]

softmin(x,axis=1) = [[ 0.66524094, 0.24472848, 0.09003057], [ 0.09003057, 0.24472848, 0.66524094]]

Defined in src/operator/nn/softmin.cc:L56

Arguments

  • data::NDArray-or-SymbolicNode: The input array.
  • axis::int, optional, default='-1': The axis along which to compute softmax.
  • temperature::double or None, optional, default=None: Temperature parameter in softmax
  • dtype::{None, 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None).
  • use_length::boolean or None, optional, default=0: Whether to use the length input as a mask over the data input.
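
The softmin formula for a single 1-D slice can be sketched in plain Python (not MXNet.jl code; `softmin_ref` is hypothetical). For `axis=1` in the example it is applied to each row, and for `axis=0` to each column.

```python
import math


def softmin_ref(row, temperature=1.0):
    """softmin(z/t)_j = exp(-z_j/t) / sum_k exp(-z_k/t) for one 1-D row."""
    scaled = [-v / temperature for v in row]
    m = max(scaled)                       # shift for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```

Raising `temperature` flattens the distribution toward uniform. Lowering it concentrates the mass on the smallest input.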

source

# MXNet.mx.softsignMethod.

softsign(data)

Computes softsign of x element-wise.

.. math:: y = x / (1 + abs(x))

The storage type of `softsign` output is always dense.

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L191

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.space_to_depthMethod.

space_to_depth(data, block_size)

Rearranges (permutes) blocks of spatial data into depth. Similar to the ONNX SpaceToDepth operator: https://github.com/onnx/onnx/blob/master/docs/Operators.md#SpaceToDepth

The output is a new tensor where the values from the height and width dimensions are moved to the depth dimension. The reverse of this operation is `depth_to_space`.

Here x is an input tensor with default layout [N, C, H, W] (batch, channels, height, width), and y is the output tensor of layout [N, C * (block_size ^ 2), H / block_size, W / block_size].

Example::

x = [[[[0, 6, 1, 7, 2, 8], [12, 18, 13, 19, 14, 20], [3, 9, 4, 10, 5, 11], [15, 21, 16, 22, 17, 23]]]]

space_to_depth(x, 2) = [[[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]], [[18, 19, 20], [21, 22, 23]]]]

Defined in src/operator/tensor/matrix_op.cc:L1018

Arguments

  • data::NDArray-or-SymbolicNode: Input ndarray
  • block_size::int, required: Blocks of [block_size, block_size] are moved
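
Here is a plain-Python sketch of the rearrangement on nested lists in [N][C][H][W] layout (not MXNet.jl code; `space_to_depth_ref` is hypothetical). The channel ordering is inferred from the example above: the pixel at offset `(bh, bw)` inside each block goes to output channel `(bh * block_size + bw) * C + c`. Treat this ordering as an assumption when C > 1.

```python
def space_to_depth_ref(x, block_size):
    """x: nested list [N][C][H][W]; H and W must be divisible by block_size."""
    bs = block_size
    out = []
    for sample in x:
        c_in = len(sample)
        h, w = len(sample[0]), len(sample[0][0])
        if h % bs or w % bs:
            raise ValueError("H and W must be divisible by block_size")
        channels = []
        # Output channel (bh * bs + bw) * C + c collects the pixels at
        # offset (bh, bw) of every bs x bs block of input channel c.
        for bh in range(bs):
            for bw in range(bs):
                for c in range(c_in):
                    channels.append([[sample[c][i * bs + bh][j * bs + bw]
                                      for j in range(w // bs)]
                                     for i in range(h // bs)])
        out.append(channels)
    return out
```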

source

# MXNet.mx.squareMethod.

square(data)

Returns element-wise squared value of the input.

.. math:: square(x) = x^2

Example::

square([2, 3, 4]) = [4, 9, 16]

The storage type of `square` output depends upon the input storage type:

  • square(default) = default
  • square(row_sparse) = row_sparse
  • square(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_pow.cc:L119

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.squeezeMethod.

squeeze(data, axis)

Removes single-dimensional entries from the shape of an array.

The output shape follows numpy.squeeze in most cases; see the following note for the exception.

Examples::

data = [[[0], [1], [2]]]

squeeze(data) = [0, 1, 2]

squeeze(data, axis=0) = [[0], [1], [2]]

squeeze(data, axis=2) = [[0, 1, 2]]

squeeze(data, axis=(0, 2)) = [0, 1, 2]

.. Note::

The output of this operator always keeps at least one dimension. For example, squeeze([[[4]]]) = [4], whereas numpy.squeeze would return a scalar.

Arguments

  • data::NDArray-or-SymbolicNode: data to squeeze
  • axis::Shape or None, optional, default=None: Selects a subset of the single-dimensional entries in the shape. If an axis is selected with shape entry greater than one, an error is raised.
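
The shape rule, including the "keep at least one dimension" exception, can be sketched in plain Python (not MXNet.jl code). `squeeze_shape` is hypothetical and computes only the output shape, not the data.

```python
def squeeze_shape(shape, axis=None):
    """Output shape of squeeze; never returns an empty shape."""
    nd = len(shape)
    if axis is None:
        drop = {d for d, n in enumerate(shape) if n == 1}
    else:
        drop = set()
        for a in ((axis,) if isinstance(axis, int) else axis):
            if a < 0:
                a += nd
            if shape[a] != 1:
                raise ValueError(f"cannot squeeze axis {a} of size {shape[a]}")
            drop.add(a)
    out = tuple(n for d, n in enumerate(shape) if d not in drop)
    return out if out else (1,)           # keep at least one dimension
```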

source

# MXNet.mx.stackMethod.

stack(data, axis, num_args)

Note: stack takes a variable number of positional inputs. So instead of calling it as stack([x, y, z], num_args=3), one should call it as stack(x, y, z), and num_args will be determined automatically.

Joins a sequence of arrays along a new axis.

The axis parameter specifies the index of the new axis in the dimensions of the result. For example, if axis=0 it will be the first dimension and if axis=-1 it will be the last dimension.

Examples::

x = [1, 2]

y = [3, 4]

stack(x, y) = [[1, 2], [3, 4]]

stack(x, y, axis=1) = [[1, 3], [2, 4]]

Arguments

  • data::NDArray-or-SymbolicNode[]: List of arrays to stack
  • axis::int, optional, default='0': The axis in the result array along which the input arrays are stacked.
  • num_args::int, required: Number of inputs to be stacked.

source

# MXNet.mx.stop_gradientMethod.

stop_gradient(data)

stop_gradient is an alias of BlockGrad.

Stops gradient computation.

Stops the accumulated gradient of the inputs from flowing through this operator in the backward direction. In other words, this operator prevents the contribution of its inputs to be taken into account for computing gradients.

Example::

v1 = [1, 2]
v2 = [0, 1]
a = Variable('a')
b = Variable('b')
b_stop_grad = stop_gradient(3 * b)
loss = MakeLoss(b_stop_grad + a)

executor = loss.simple_bind(ctx=cpu(), a=(1,2), b=(1,2))
executor.forward(is_train=True, a=v1, b=v2)
executor.outputs
[ 1. 5.]

executor.backward()
executor.grad_arrays
[ 0. 0.]
[ 1. 1.]

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L325

Arguments

  • data::NDArray-or-SymbolicNode: The input array.

source

# MXNet.mx.sum_axisMethod.

sum_axis(data, axis, keepdims, exclude)

sum_axis is an alias of sum.

Computes the sum of array elements over given axes.

.. Note::

sum and sum_axis are equivalent. For ndarray of csr storage type summation along axis 0 and axis 1 is supported. Setting keepdims or exclude to True will cause a fallback to dense operator.

Example::

data = [[[1, 2], [2, 3], [1, 3]], [[1, 4], [4, 3], [5, 2]], [[7, 1], [7, 2], [7, 3]]]

sum(data, axis=1) = [[ 4., 8.], [ 10., 9.], [ 21., 6.]]

sum(data, axis=[1,2]) = [ 12., 19., 27.]

data = [[1, 2, 0], [3, 0, 1], [4, 1, 0]]

csr = cast_storage(data, 'csr')

sum(csr, axis=0) = [ 8., 3., 1.]

sum(csr, axis=1) = [ 3., 4., 5.]

Defined in src/operator/tensor/broadcast_reduce_sum_value.cc:L66

Arguments

  • data::NDArray-or-SymbolicNode: The input
  • axis::Shape or None, optional, default=None: The axis or axes along which to perform the reduction.

    The default, `axis=()`, will compute over all elements into a scalar array with shape `(1,)`.

    If `axis` is int, a reduction is performed on a particular axis.

    If `axis` is a tuple of ints, a reduction is performed on all the axes specified in the tuple.

    If `exclude` is true, reduction will be performed on the axes that are NOT in `axis` instead.

    Negative values mean indexing from right to left.
  • keepdims::boolean, optional, default=0: If this is set to `True`, the reduced axes are left in the result as dimensions with size one.
  • exclude::boolean, optional, default=0: Whether to perform reduction on the axes that are NOT in `axis` instead.
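
The `axis` and `exclude` semantics can be sketched in plain Python on dense nested lists (not MXNet.jl code). `sum_ref` is hypothetical and does not model `keepdims` or the csr storage path. A full reduction returns a one-element list, mirroring the shape `(1,)` result.

```python
from itertools import product


def _shape(x):
    s = []
    while isinstance(x, list):
        s.append(len(x))
        x = x[0]
    return tuple(s)


def _get(x, idx):
    for i in idx:
        x = x[i]
    return x


def sum_ref(data, axis=None, exclude=False):
    """Sum a nested list over `axis` (int or tuple); keepdims is not modeled."""
    shape = _shape(data)
    nd = len(shape)
    if axis is None or axis == ():
        axes = set(range(nd))
    else:
        axes = {a % nd for a in ((axis,) if isinstance(axis, int) else axis)}
    if exclude:
        axes = set(range(nd)) - axes      # reduce over the axes NOT listed
    kept = [d for d in range(nd) if d not in axes]
    acc = {}
    for idx in product(*(range(n) for n in shape)):
        key = tuple(idx[d] for d in kept)
        acc[key] = acc.get(key, 0) + _get(data, idx)
    out_shape = [shape[d] for d in kept]

    def build(prefix):
        if len(prefix) == len(out_shape):
            return acc[prefix]
        return [build(prefix + (i,)) for i in range(out_shape[len(prefix)])]

    result = build(())
    return result if kept else [result]   # full reduction -> shape (1,)
```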

source

# MXNet.mx.swapaxesMethod.

swapaxes(data, dim1, dim2)

swapaxes is an alias of SwapAxis.

Interchanges two axes of an array.

Examples::

x = [[1, 2, 3]]

swapaxes(x, 0, 1) = [[ 1], [ 2], [ 3]]

x = [[[ 0, 1], [ 2, 3]], [[ 4, 5], [ 6, 7]]] // (2,2,2) array

swapaxes(x, 0, 2) = [[[ 0, 4], [ 2, 6]], [[ 1, 5], [ 3, 7]]]

Defined in src/operator/swapaxis.cc:L69

Arguments

  • data::NDArray-or-SymbolicNode: Input array.
  • dim1::int, optional, default='0': the first axis to be swapped.
  • dim2::int, optional, default='0': the second axis to be swapped.

source

# MXNet.mx.takeMethod.

take(a, indices, axis, mode)

Takes elements from an input array along the given axis.

This function slices the input array along a particular axis with the provided indices.

Given a data tensor of rank r >= 1 and an indices tensor of rank q, it gathers entries of the `axis` dimension of data (by default the outermost one, axis=0) indexed by indices, and concatenates them in an output tensor of rank q + (r - 1).

Examples::

x = [4. 5. 6.]

// Trivial case, take the second element along the first axis.

take(x, [1]) = [ 5. ]

// The other trivial case, axis=-1, take the third element along the first axis

take(x, [3], axis=-1, mode='clip') = [ 6. ]

x = [[ 1., 2.], [ 3., 4.], [ 5., 6.]]

// In this case we will get rows 0 and 1, then 1 and 2. Along axis 0

take(x, [[0,1],[1,2]]) = [[[ 1., 2.], [ 3., 4.]], [[ 3., 4.], [ 5., 6.]]]

// In this case we will get columns 0 and 1, then 1 and 0 (calculated by wrapping around), along axis 1

take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap') = [[[ 1., 2.], [ 2., 1.]], [[ 3., 4.], [ 4., 3.]], [[ 5., 6.], [ 6., 5.]]]

The storage type of $take$ output depends upon the input storage type:

  • take(default, default) = default
  • take(csr, default, axis=0) = csr

Defined in src/operator/tensor/indexing_op.cc:L776

Arguments

  • a::NDArray-or-SymbolicNode: The input array.
  • indices::NDArray-or-SymbolicNode: The indices of the values to be extracted.
  • axis::int, optional, default='0': The axis of input array to be taken. For input tensor of rank r, it could be in the range of [-r, r-1].
  • mode::{'clip', 'raise', 'wrap'},optional, default='clip': Specify how out-of-bound indices behave. Default is "clip". "clip" means clip to the range. So, if all indices mentioned are too large, they are replaced by the index that addresses the last element along an axis. "wrap" means to wrap around. "raise" means to raise an error when an index is out of range.
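
The gather and the three `mode` values can be sketched in plain Python on nested lists (not MXNet.jl code; `take_ref` and `_fix_index` are hypothetical). In this sketch, "raise" treats any index outside [0, n) as an error, including negative indices. That is an assumption about the operator.

```python
def _ndim(x):
    n = 0
    while isinstance(x, list):
        x, n = x[0], n + 1
    return n


def _fix_index(i, n, mode):
    if mode == "clip":
        return min(max(i, 0), n - 1)     # out of range -> nearest valid index
    if mode == "wrap":
        return i % n                     # Python's % already wraps negatives
    if not 0 <= i < n:                   # mode == "raise"
        raise IndexError(f"index {i} out of range for axis of size {n}")
    return i


def take_ref(a, indices, axis=0, mode="clip"):
    if axis < 0:
        axis += _ndim(a)

    def gather(arr, idx):
        # Nested index lists keep their shape: output rank is q + (r - 1).
        if isinstance(idx, list):
            return [gather(arr, i) for i in idx]
        return arr[_fix_index(idx, len(arr), mode)]

    def walk(arr, depth):
        if depth == axis:
            return gather(arr, indices)
        return [walk(sub, depth + 1) for sub in arr]

    return walk(a, 0)
```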

source

# MXNet.mx.tileMethod.

tile(data, reps)

Repeats the whole array multiple times. If `reps` has length d and the input array has n dimensions, there are three cases:

  • n=d. Repeat the i-th dimension of the input `reps[i]` times::

    x = [[1, 2], [3, 4]]

    tile(x, reps=(2,3)) = [[ 1., 2., 1., 2., 1., 2.], [ 3., 4., 3., 4., 3., 4.], [ 1., 2., 1., 2., 1., 2.], [ 3., 4., 3., 4., 3., 4.]]

  • n>d. `reps` is promoted to length n by prepending 1's to it. Thus for an input shape `(2,3)`, `reps=(2,)` is treated as `(1,2)`::

    tile(x, reps=(2,)) = [[ 1., 2., 1., 2.], [ 3., 4., 3., 4.]]

  • n<d. The input is promoted to be d-dimensional by prepending new axes. So a shape `(2,2)` array is promoted to `(1,2,2)` for 3-D replication::

    tile(x, reps=(2,2,3)) = [[[ 1., 2., 1., 2., 1., 2.], [ 3., 4., 3., 4., 3., 4.], [ 1., 2., 1., 2., 1., 2.], [ 3., 4., 3., 4., 3., 4.]], [[ 1., 2., 1., 2., 1., 2.], [ 3., 4., 3., 4., 3., 4.], [ 1., 2., 1., 2., 1., 2.], [ 3., 4., 3., 4., 3., 4.]]]

Defined in src/operator/tensor/matrix_op.cc:L795

Arguments

  • data::NDArray-or-SymbolicNode: Input data array
  • reps::Shape(tuple), required: The number of times for repeating the tensor a. Each dim size of reps must be a positive integer. If reps has length d, the result will have dimension of max(d, a.ndim); if a.ndim < d, a is promoted to be d-dimensional by prepending new axes. If a.ndim > d, reps is promoted to a.ndim by prepending 1's to it.
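
The three cases reduce to one rule: pad the shorter of `reps` and the input rank with leading 1s, then concatenate `reps[k]` copies along each axis k. Below is a plain-Python sketch on nested lists (not MXNet.jl code; `tile_ref` is hypothetical).

```python
def _ndim(x):
    n = 0
    while isinstance(x, list):
        x, n = x[0], n + 1
    return n


def tile_ref(x, reps):
    reps = tuple(reps)
    nd = _ndim(x)
    for _ in range(len(reps) - nd):          # n < d: prepend new axes to x
        x = [x]
    reps = (1,) * (nd - len(reps)) + reps    # n > d: prepend 1's to reps

    def rec(arr, k):
        if k == len(reps):
            return arr                       # reached a scalar element
        inner = [rec(sub, k + 1) for sub in arr]
        # Concatenate reps[k] copies along axis k.
        return [item for _ in range(reps[k]) for item in inner]

    return rec(x, 0)
```

The repeated sub-lists share references. That is fine for reading or comparing results, but copy them before mutating in place.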

source

# MXNet.mx.topkMethod.

topk(data, axis, k, ret_typ, is_ascend, dtype)

By default, returns the indices of the top k elements in an input array along the given axis. If `ret_typ` is set to 'value', returns the values of the top k elements instead of their indices. If `ret_typ` is 'both', both values and indices are returned. The returned elements will be sorted.

Examples::

x = [[ 0.3, 0.2, 0.4], [ 0.1, 0.3, 0.2]]

// returns an index of the largest element on last axis

topk(x) = [[ 2.], [ 1.]]

// returns the value of top-2 largest elements on last axis

topk(x, ret_typ='value', k=2) = [[ 0.4, 0.3], [ 0.3, 0.2]]

// returns the value of top-2 smallest elements on last axis

topk(x, ret_typ='value', k=2, is_ascend=1) = [[ 0.2 , 0.3], [ 0.1 , 0.2]]

// returns the value of top-2 largest elements on axis 0

topk(x, axis=0, ret_typ='value', k=2) = [[ 0.3, 0.3, 0.4], [ 0.1, 0.2, 0.2]]

// returns a list of both values and indices of the top-2 largest elements on last axis

topk(x, ret_typ='both', k=2) = [[[ 0.4, 0.3], [ 0.3, 0.2]] , [[ 2., 0.], [ 1., 2.]]]

Defined in src/operator/tensor/ordering_op.cc:L67

Arguments

  • data::NDArray-or-SymbolicNode: The input array
  • axis::int or None, optional, default='-1': Axis along which to choose the top k indices. If not given, the flattened array is used. Default is -1.
  • k::int, optional, default='1': Number of top elements to select; should always be smaller than or equal to the element number in the given axis. A global sort is performed if set k < 1.
  • ret_typ::{'both', 'indices', 'mask', 'value'},optional, default='indices': The return type.

    "value" means to return the top k values, "indices" means to return the indices of the top k values, "mask" means to return a mask array containing 0 and 1, where 1 marks the top k values. "both" means to return a list of both values and indices of top k elements.

  • is_ascend::boolean, optional, default=0: Whether to choose k largest or k smallest elements. Top K largest elements will be chosen if set to false.
  • dtype::{'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'},optional, default='float32': DType of the output indices when ret_typ is "indices" or "both". An error will be raised if the selected data type cannot precisely represent the indices.
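
The `ret_typ`, `k` and `is_ascend` combinations for one row can be sketched in plain Python (not MXNet.jl code; `topk_row` is hypothetical). Applying it to every row models the default `axis=-1` on a 2-D input. Two assumptions apply. Indices are returned as Python ints rather than the float32 default `dtype`. Ties keep the lower index first because Python's sort is stable, which may differ from the operator.

```python
def topk_row(row, k=1, ret_typ="indices", is_ascend=False):
    """Top-k of one 1-D row (the last axis of a 2-D input)."""
    if k < 1:
        k = len(row)                     # k < 1 means a full sort
    # Positions ordered by value: largest first unless is_ascend is set.
    order = sorted(range(len(row)), key=lambda i: row[i], reverse=not is_ascend)
    top = order[:k]
    if ret_typ == "indices":
        return top
    if ret_typ == "value":
        return [row[i] for i in top]
    if ret_typ == "mask":
        chosen = set(top)
        return [1 if i in chosen else 0 for i in range(len(row))]
    if ret_typ == "both":
        return [row[i] for i in top], top
    raise ValueError(f"unknown ret_typ {ret_typ!r}")


x = [[0.3, 0.2, 0.4], [0.1, 0.3, 0.2]]
print([topk_row(r, k=2, ret_typ="value") for r in x])  # [[0.4, 0.3], [0.3, 0.2]]
```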

source

# MXNet.mx.uniformMethod.

uniform(low, high, shape, ctx, dtype)

uniform is an alias of _random_uniform.

Draw random samples from a uniform distribution.

.. note:: The existing alias `uniform` is deprecated.

Samples are uniformly distributed over the half-open interval [low, high) (includes low, but excludes high).

Example::

uniform(low=0, high=1, shape=(2,2)) = [[ 0.60276335, 0.85794562], [ 0.54488319, 0.84725171]]

Defined in src/operator/random/sample_op.cc:L95

Arguments

  • low::float, optional, default=0: Lower bound of the distribution.
  • high::float, optional, default=1: Upper bound of the distribution.
  • shape::Shape(tuple), optional, default=None: Shape of the output.
  • ctx::string, optional, default='': Context of output, in format cpu|gpu|cpu_pinned. Only used for imperative calls.
  • dtype::{'None', 'float16', 'float32', 'float64'},optional, default='None': DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None).

source

# MXNet.mx.unravel_indexMethod.

unravel_index(data, shape)

unravel_index is an alias of _unravel_index.

Converts an array of flat indices into a batch of index arrays. The operator follows numpy conventions so a single multi index is given by a column of the output matrix. The leading dimension may be left unspecified by using -1 as placeholder.

Examples::

A = [22,41,37]

unravel_index(A, shape=(7,6)) = [[3,6,6],[4,5,1]]

unravel_index(A, shape=(-1,6)) = [[3,6,6],[4,5,1]]

Defined in src/operator/tensor/ravel.cc:L67

Arguments

  • data::NDArray-or-SymbolicNode: Array of flat indices
  • shape::Shape(tuple), optional, default=None: Shape of the array into which the multi-indices apply.
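
The row-major conversion can be sketched in plain Python (not MXNet.jl code; `unravel_index_ref` is hypothetical). Column j of the result is the multi-index of `flat[j]`. The leading size is never used as a divisor, which is why `-1` works as a placeholder there. The sketch performs no bounds checking.

```python
def unravel_index_ref(flat, shape):
    """Column j of the result is the multi-index of flat[j] (row-major)."""
    out = [[0] * len(flat) for _ in shape]
    for j, idx in enumerate(flat):
        for d in range(len(shape) - 1, 0, -1):
            out[d][j] = idx % shape[d]
            idx //= shape[d]
        out[0][j] = idx                   # whatever remains is the leading index
    return out


print(unravel_index_ref([22, 41, 37], (7, 6)))  # [[3, 6, 6], [4, 5, 1]]
```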

source

# MXNet.mx.whereMethod.

where(condition, x, y)

Return the elements, either from x or y, depending on the condition.

Given three ndarrays, condition, x, and y, return an ndarray with the elements from x or y, depending on whether the elements of condition are true or false. x and y must have the same shape. If condition has the same shape as x, each element in the output array is from x if the corresponding element in the condition is true, and from y if false.

If condition does not have the same shape as x, it must be a 1D array whose size is the same as x's first dimension size. Each row of the output array is from x's row if the corresponding element from condition is true, and from y's row if false.

Note that all non-zero values are interpreted as $True$ in condition.

Examples::

x = [[1, 2], [3, 4]]

y = [[5, 6], [7, 8]]

cond = [[0, 1], [-1, 0]]

where(cond, x, y) = [[5, 2], [3, 8]]

csr_cond = cast_storage(cond, 'csr')

where(csr_cond, x, y) = [[5, 2], [3, 8]]

Defined in src/operator/tensor/control_flow_op.cc:L56

Arguments

  • condition::NDArray-or-SymbolicNode: condition array
  • x::NDArray-or-SymbolicNode: Values selected where condition is true.
  • y::NDArray-or-SymbolicNode: Values selected where condition is false.
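
Both selection modes can be sketched in plain Python on nested lists (not MXNet.jl code; `where_ref` is hypothetical). The first mode is a condition with the same shape as x. The second is a 1-D condition that selects whole rows. Any non-zero value counts as true.

```python
def where_ref(condition, x, y):
    """Pick from x where condition is non-zero, else from y (nested lists)."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same shape")
    out = []
    for c, xi, yi in zip(condition, x, y):
        if isinstance(c, list):          # same shape as x: recurse element-wise
            out.append(where_ref(c, xi, yi))
        else:                            # scalar entry: take the whole element/row
            out.append(xi if c != 0 else yi)
    return out


print(where_ref([[0, 1], [-1, 0]], [[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[5, 2], [3, 8]]
```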

source

# MXNet.mx.zeros_likeMethod.

zeros_like(data)

Return an array of zeros with the same shape, type and storage type as the input array.

The storage type of `zeros_like` output depends on the storage type of the input:

  • zeros_like(row_sparse) = row_sparse
  • zeros_like(csr) = csr
  • zeros_like(default) = default

Examples::

x = [[ 1., 1., 1.], [ 1., 1., 1.]]

zeros_like(x) = [[ 0., 0., 0.], [ 0., 0., 0.]]

Arguments

  • data::NDArray-or-SymbolicNode: The input

source

# Random.shuffleMethod.

shuffle(data)

shuffle is an alias of _shuffle.

Randomly shuffle the elements.

This shuffles the array along the first axis. The order of the elements in each subarray does not change. For example, if a 2D array is given, the order of the rows randomly changes, but the order of the elements in each row does not change.

Arguments

  • data::NDArray-or-SymbolicNode: Data to be shuffled.

source

# Base.argmaxMethod.

argmax(x::NDArray; dims) -> indices

Note that NaN is treated as greater than all other values in argmax.

Examples

julia> x = NDArray([0. 1 2; 3 4 5])
2×3 NDArray{Float64,2} @ CPU0:
 0.0  1.0  2.0
 3.0  4.0  5.0

julia> argmax(x, dims = 1)
1×3 NDArray{Float64,2} @ CPU0:
 2.0  2.0  2.0

julia> argmax(x, dims = 2)
2×1 NDArray{Float64,2} @ CPU0:
 3.0
 3.0

See also argmin.

source

# Base.argminMethod.

argmin(x::NDArray; dims) -> indices

Note that NaN is treated as less than all other values in argmin.

Examples

julia> x = NDArray([0. 1 2; 3 4 5])
2×3 NDArray{Float64,2} @ CPU0:
 0.0  1.0  2.0
 3.0  4.0  5.0

julia> argmin(x, dims = 1)
1×3 NDArray{Float64,2} @ CPU0:
 1.0  1.0  1.0

julia> argmin(x, dims = 2)
2×1 NDArray{Float64,2} @ CPU0:
 1.0
 1.0

See also argmax.

source

# Base.Broadcast.broadcastedMethod.

source

# Base.acosFunction.

acos.(x::NDArray)

Defined in `src/operator/tensor/elemwise_unary_op_trig.cc:L233`

source

# Base.acoshFunction.

acosh.(x::NDArray)

Defined in `src/operator/tensor/elemwise_unary_op_trig.cc:L535`

source

# Base.asinFunction.

asin.(x::NDArray)

Defined in `src/operator/tensor/elemwise_unary_op_trig.cc:L187`

source

# Base.asinhFunction.

asinh.(x::NDArray)

Defined in `src/operator/tensor/elemwise_unary_op_trig.cc:L494`

source

# Base.atanFunction.

atan.(x::NDArray)

Defined in `src/operator/tensor/elemwise_unary_op_trig.cc:L282`

source

# Base.atanhFunction.

atanh.(x::NDArray)

Defined in `src/operator/tensor/elemwise_unary_op_trig.cc:L579`

source

# Base.permutedimsMethod.

Base.permutedims(x::NDArray, axes)

Defined in src/operator/tensor/matrix_op.cc:L327

source

# Base.transposeMethod.

LinearAlgebra.transpose(x::NDArray{T, 1}) where T

Defined in src/operator/tensor/matrix_op.cc:L174

source

# Base.transposeMethod.

LinearAlgebra.transpose(x::NDArray{T, 2}) where T

Defined in src/operator/tensor/matrix_op.cc:L327

source

# LinearAlgebra.dotMethod.

LinearAlgebra.dot(x::NDArray, y::NDArray)

Defined in src/operator/tensor/dot.cc:L77

source

# MXNet.mx._argmaxMethod.

_argmax(x::NDArray, dims)

Defined in src/operator/tensor/broadcast_reduce_op_index.cc:L51

source

# MXNet.mx._argmaxMethod.

_argmax(x::NDArray, ::Colon)

Defined in src/operator/tensor/broadcast_reduce_op_index.cc:L51

source

# MXNet.mx._argminMethod.

_argmin(x::NDArray, dims)

Defined in src/operator/tensor/broadcast_reduce_op_index.cc:L76

source

# MXNet.mx._argminMethod.

_argmin(x::NDArray, ::Colon)

Defined in src/operator/tensor/broadcast_reduce_op_index.cc:L76

source

# MXNet.mx._broadcast_add!Method.

_broadcast_add!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L57

source

# MXNet.mx._broadcast_addMethod.

_broadcast_add(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L57

source

# MXNet.mx._broadcast_div!Method.

_broadcast_div!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L186

source

# MXNet.mx._broadcast_divMethod.

_broadcast_div(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L186

source

# MXNet.mx._broadcast_equal!Method.

_broadcast_equal!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L45

source

# MXNet.mx._broadcast_equalMethod.

_broadcast_equal(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L45

source

# MXNet.mx._broadcast_greater!Method.

_broadcast_greater!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L81

source

# MXNet.mx._broadcast_greaterMethod.

_broadcast_greater(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L81

source

# MXNet.mx._broadcast_greater_equal!Method.

_broadcast_greater_equal!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L99

source

# MXNet.mx._broadcast_greater_equalMethod.

_broadcast_greater_equal(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L99

source

# MXNet.mx._broadcast_hypot!Method.

_broadcast_hypot!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc:L157

source

# MXNet.mx._broadcast_hypotMethod.

_broadcast_hypot(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc:L157

source

# MXNet.mx._broadcast_lesser!Method.

_broadcast_lesser!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L117

source

# MXNet.mx._broadcast_lesserMethod.

_broadcast_lesser(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L117

source

# MXNet.mx._broadcast_lesser_equal!Method.

_broadcast_lesser_equal!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L135

source

# MXNet.mx._broadcast_lesser_equalMethod.

_broadcast_lesser_equal(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L135

source

# MXNet.mx._broadcast_maximum!Method.

_broadcast_maximum!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc:L80

source

# MXNet.mx._broadcast_maximumMethod.

_broadcast_maximum(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc:L80

source

# MXNet.mx._broadcast_minimum!Method.

_broadcast_minimum!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc:L116

source

# MXNet.mx._broadcast_minimumMethod.

_broadcast_minimum(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc:L116

source

# MXNet.mx._broadcast_minus!Method.

_broadcast_minus!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L105

source

# MXNet.mx._broadcast_minusMethod.

_broadcast_minus(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L105

source

# MXNet.mx._broadcast_mod!Method.

_broadcast_mod!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L221

source

# MXNet.mx._broadcast_modMethod.

_broadcast_mod(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L221

source

# MXNet.mx._broadcast_mul!Method.

_broadcast_mul!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L145

source

# MXNet.mx._broadcast_mulMethod.

_broadcast_mul(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:L145

source

# MXNet.mx._broadcast_not_equal!Method.

_broadcast_not_equal!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L63

source

# MXNet.mx._broadcast_not_equalMethod.

_broadcast_not_equal(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_logic.cc:L63

source

# MXNet.mx._broadcast_power!Method.

_broadcast_power!(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc:L44

source

# MXNet.mx._broadcast_powerMethod.

_broadcast_power(x::NDArray, y::NDArray)

Defined in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc:L44

# MXNet.mx._clamp!Method.

_clamp!(x::NDArray, lo::Real, hi::Real)

Defined in src/operator/tensor/matrix_op.cc:L676

# MXNet.mx._clampMethod.

_clamp(x::NDArray, lo::Real, hi::Real)

Defined in src/operator/tensor/matrix_op.cc:L676
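_clamp limits every element of x to the closed interval [lo, hi]. The sketch below applies the same element-wise rule to an ordinary Vector, using Base Julia's clamp. and clamp!. Treating clamp! as the analogue of _clamp! is an assumption based on the ! naming convention.

```julia
x = [-2.0, -0.5, 0.0, 0.7, 3.0]
lo, hi = -1.0, 1.0

# Out-of-place: each element is limited to [lo, hi] in a new array.
y = clamp.(x, lo, hi)

# In-place analogue of the assumed _clamp! behaviour: x itself is modified.
clamp!(x, lo, hi)
```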

# MXNet.mx._docsigMethod.

Generates a docstring from a function signature.

# MXNet.mx._meanMethod.

_mean(x::NDArray, dims)

Defined in src/operator/tensor/./broadcast_reduce_op.h:L83

# MXNet.mx._meanMethod.

_mean(x::NDArray, ::Colon)

Defined in src/operator/tensor/./broadcast_reduce_op.h:L83
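Each reduction wrapper comes in two methods: one that takes dims and one that takes a Colon (:). The sketch below illustrates the difference with Statistics.mean on a plain Matrix. It assumes, without confirmation from the source, that the MXNet.jl wrappers follow Base Julia in keeping reduced dimensions with size 1 and that the Colon method reduces over all elements. Base Julia returns a plain scalar in the Colon case, while the NDArray wrapper may instead return a one-element NDArray.

```julia
using Statistics  # standard library providing mean

x = [1.0 2.0 3.0;
     4.0 5.0 6.0]

m1 = mean(x; dims = 1)   # column means, size (1, 3)
m2 = mean(x; dims = 2)   # row means, size (2, 1)
mall = mean(x)           # Colon (the default): mean of all six elements
```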

# MXNet.mx._minus!Method.

_minus!(x::NDArray, y::NDArray)

# MXNet.mx._minusMethod.

_minus(x::NDArray, y::NDArray)

# MXNet.mx._mod!Method.

_mod!(x::NDArray, y::NDArray)

# MXNet.mx._modMethod.

_mod(x::NDArray, y::NDArray)

# MXNet.mx._mod_scalar!Method.

_mod_scalar!(x::NDArray, y::Real)

# MXNet.mx._mod_scalarMethod.

_mod_scalar(x::NDArray, y::Real)

# MXNet.mx._nd_maximumMethod.

_nd_maximum(x::NDArray, dims)

Defined in src/operator/tensor/./broadcast_reduce_op.h:L31

# MXNet.mx._nd_maximumMethod.

_nd_maximum(x::NDArray, ::Colon)

Defined in src/operator/tensor/./broadcast_reduce_op.h:L31

# MXNet.mx._nd_minimumMethod.

_nd_minimum(x::NDArray, dims)

Defined in src/operator/tensor/./broadcast_reduce_op.h:L46

# MXNet.mx._nd_minimumMethod.

_nd_minimum(x::NDArray, ::Colon)

Defined in src/operator/tensor/./broadcast_reduce_op.h:L46
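_nd_maximum and _nd_minimum follow the same dims / Colon pattern. The plain-Julia illustration below makes no MXNet calls. It uses maximum and minimum to reduce along a chosen dimension or over the whole array, assuming the wrappers mirror this Base behaviour.

```julia
x = [3.0 9.0 1.0;
     7.0 2.0 8.0]

colmax = maximum(x; dims = 1)        # largest value in each column, size (1, 3)
rowmin = minimum(x; dims = 2)        # smallest value in each row, size (2, 1)
extremes = (maximum(x), minimum(x))  # Colon: reduce over every element
```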

# MXNet.mx._plus!Method.

_plus!(x::NDArray, y::NDArray)

# MXNet.mx._plusMethod.

_plus(x::NDArray, y::NDArray)

# MXNet.mx._prodMethod.

_prod(x::NDArray, dims)

Defined in src/operator/tensor/./broadcast_reduce_op.h:L30

# MXNet.mx._prodMethod.

_prod(x::NDArray, ::Colon)

Defined in src/operator/tensor/./broadcast_reduce_op.h:L30

# MXNet.mx._rmod_scalar!Method.

_rmod_scalar!(x::NDArray, y::Real)

# MXNet.mx._rmod_scalarMethod.

_rmod_scalar(x::NDArray, y::Real)
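_mod_scalar and _rmod_scalar appear to differ only in operand order. The sketch below assumes, following libmxnet's naming of reversed scalar operators, that _rmod_scalar(x, y) computes y mod each element of x, while _mod_scalar(x, y) computes each element of x mod y. It uses plain Julia arrays, and only positive values, because this reference does not state the sign convention for negative operands.

```julia
x = [3.0, 4.0, 5.0]
y = 10.0

# _mod_scalar(x, y) is assumed to compute x[i] mod y.
fwd = mod.(x, y)

# _rmod_scalar(x, y) is assumed to swap the operands: y mod x[i].
rev = mod.(y, x)
```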

# MXNet.mx._sumMethod.

_sum(x::NDArray, dims)

Defined in src/operator/tensor/broadcast_reduce_sum_value.cc:L66

# MXNet.mx._sumMethod.

_sum(x::NDArray, ::Colon)

Defined in src/operator/tensor/broadcast_reduce_sum_value.cc:L66
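The dims argument of the reduction wrappers is untyped in the signatures above. Assuming it accepts a tuple of dimensions as Base Julia does, several dimensions can be reduced at once. The sketch below shows this with Base sum on a plain three-dimensional Array.

```julia
# A 2×3×2 array holding 1.0 through 12.0 in column-major order.
x = reshape(collect(1.0:12.0), 2, 3, 2)

# Reduce over the first two dimensions; the third is kept.
s = sum(x; dims = (1, 2))   # size (1, 1, 2)

total = sum(x)              # Colon: sum of all twelve elements
```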

# MXNet.mx.@_remapMacro.

@_remap(sig::Expr, imp::Expr)

Creates a function with signature sig whose implementation is imp.

Arguments

  • sig is the function signature. If the function name ends with !, the generated function invokes the corresponding in-place call.
  • imp is the underlying libmxnet API call.
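The hypothetical macro remap_sketch below is a much-simplified, plain-Julia analogue of @_remap. It defines a function with signature sig whose body is imp. When the function name ends with !, it also copies the result into the first argument. The real macro forwards to libmxnet calls and may handle in-place variants differently. None of the names below belong to MXNet.jl.

```julia
# Hypothetical, simplified analogue of @_remap (not the MXNet.jl implementation).
macro remap_sketch(sig::Expr, imp::Expr)
    sig.head == :call || error("expected a call signature such as f(x, y)")
    name = sig.args[1]
    if endswith(string(name), "!")
        # In-place variant: write the result into the first argument and return it.
        out = sig.args[2]
        # Strip a type annotation such as x::Vector{Float64} down to x.
        out isa Expr && out.head == :(::) && (out = out.args[1])
        body = :($out .= $imp; $out)
    else
        body = imp
    end
    # Emit a short-form definition: sig = body.
    return esc(:($sig = $body))
end

@remap_sketch plus_sketch(x, y) (x .+ y)
@remap_sketch plus_sketch!(x, y) (x .+ y)

a = [1.0, 2.0]
b = [10.0, 20.0]
c = plus_sketch(a, b)   # new array; a is unchanged
plus_sketch!(a, b)      # a is overwritten with a .+ b
```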
