PyTorch vs Apache MXNet¶
PyTorch is a popular deep learning framework due to its easy-to-understand API and its completely imperative approach. Apache MXNet includes the Gluon API, which gives you the simplicity and flexibility of PyTorch and allows you to hybridize your network to leverage the performance optimizations of a symbolic graph. As of April 2019, NVIDIA performance benchmarks show that Apache MXNet outperforms PyTorch by ~77% on training ResNet-50: 10,925 images per second vs. 6,175.
In the next 10 minutes, we’ll do a quick comparison between the two frameworks and show how small the learning curve can be when switching from PyTorch to Apache MXNet.
Installation¶
PyTorch uses conda for installation by default, for example:
[ ]:
# !conda install pytorch-cpu -c pytorch
For MXNet we use pip:
[ ]:
# !pip install mxnet
To install Apache MXNet with GPU support, you need to specify the CUDA version. For example, the snippet below installs Apache MXNet with CUDA 9.2 support:
[ ]:
# !pip install mxnet-cuda92
Data manipulation¶
Both PyTorch and Apache MXNet rely on multidimensional arrays as their core data structure. While PyTorch follows Torch’s naming convention and refers to them as “tensors”, Apache MXNet follows NumPy’s conventions and refers to them as “NDArrays”.
In the code snippets below, we create a two-dimensional matrix where each element is initialized to 1, add 1 to each element, and print the result.
PyTorch:
[ ]:
import torch
x = torch.ones(5,3)
y = x + 1
y
MXNet:
[ ]:
from mxnet import nd
x = nd.ones((5,3))
y = x + 1
y
Apart from the package name, the main difference is that MXNet’s shape argument must be passed as a tuple, as in NumPy, whereas PyTorch also accepts the dimensions as separate arguments.
Both frameworks support multiple functions to create and manipulate tensors / NDArrays. You can find more of them in the documentation.
Model training¶
After covering the basics of data creation and manipulation, let’s dive deeper and compare how model training is done in both frameworks. To do so, we solve an image classification task on the MNIST data set using a Multilayer Perceptron (MLP) in both frameworks. We divide the task into 4 steps.
1. Read data¶
The first step is to obtain the data. We download the MNIST data set from the web and load it into memory so that we can read batches one by one.
PyTorch:
[ ]:
from torchvision import datasets, transforms
trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.13,), (0.31,))])
pt_train_data = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True, transform=trans),
    batch_size=128, shuffle=True, num_workers=4)
MXNet:
[ ]:
from mxnet import gluon
from mxnet.gluon.data.vision import datasets, transforms
trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize(0.13, 0.31)])
mx_train_data = gluon.data.DataLoader(
    datasets.MNIST(train=True).transform_first(trans),
    batch_size=128, shuffle=True, num_workers=4)
Both frameworks allow you to download the MNIST data set from their sources and specify that only the training part of the data set is required.
The main difference between the code snippets is that MXNet uses the transform_first method to indicate that the data transformation is applied to the first element of each data sample, the MNIST image, rather than to the second element, the label.
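To make the transform pipeline concrete, here is a pure-Python sketch of what ToTensor, Normalize, and transform_first do to a single (image, label) sample. The helpers to_tensor, normalize, and transform_first are hypothetical stand-ins, not the library implementations. The sketch assumes the image is a nested list in height x width x channel (HWC) layout with values 0-255, which ToTensor turns into channel x height x width (CHW) floats in [0, 1] before Normalize computes (x - mean) / std per channel.

```python
def to_tensor(image_hwc):
    """Convert an H x W x C nested list of 0-255 ints into C x H x W floats in [0, 1]."""
    height, width, channels = len(image_hwc), len(image_hwc[0]), len(image_hwc[0][0])
    return [[[image_hwc[h][w][c] / 255.0 for w in range(width)]
             for h in range(height)]
            for c in range(channels)]


def normalize(tensor_chw, mean, std):
    """Apply (x - mean[c]) / std[c] to every value of channel c."""
    return [[[(value - mean[c]) / std[c] for value in row] for row in channel]
            for c, channel in enumerate(tensor_chw)]


def transform_first(fn):
    """Wrap fn so it is applied only to the data (first element) of a (data, label) sample."""
    def wrapped(sample):
        data, *rest = sample
        return (fn(data), *rest)
    return wrapped


# A hypothetical 2 x 2 single-channel "image" paired with its label.
sample = ([[[0], [255]], [[51], [102]]], 7)
pipeline = transform_first(lambda img: normalize(to_tensor(img), mean=(0.13,), std=(0.31,)))
image, label = pipeline(sample)
print(label)  # 7: the label passes through untouched
```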
2. Creating the model¶
Below we define a Multilayer Perceptron (MLP) with a single hidden layer and 10 units in the output layer.
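A quick sanity check on the model size: a fully connected layer with n_in inputs and n_out outputs has an n_in x n_out weight matrix plus n_out biases, and ReLU has no parameters. The sketch below assumes 28 x 28 = 784 flattened input pixels and the 256-unit hidden layer used in the code that follows.

```python
def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


layer_sizes = [28 * 28, 256, 10]  # input pixels, hidden units, output classes
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)  # [200960, 2570] 203530
```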
PyTorch:
[ ]:
import torch.nn as pt_nn
pt_net = pt_nn.Sequential(
    pt_nn.Linear(28*28, 256),
    pt_nn.ReLU(),
    pt_nn.Linear(256, 10))
MXNet:
[ ]:
import mxnet.gluon.nn as mx_nn
mx_net = mx_nn.Sequential()
mx_net.add(mx_nn.Dense(256, activation='relu'),
           mx_nn.Dense(10))
mx_net.initialize()
We used the Sequential container to stack layers one after the other in order to construct the neural network. Apache MXNet differs from PyTorch in the following ways:
In PyTorch you have to specify the input size as the first argument of the Linear object. Apache MXNet provides extra flexibility in the network structure by automatically inferring the input size at the first forward pass.
In Apache MXNet you can specify activation functions directly in fully connected and convolutional layers.
After the model structure is defined, Apache MXNet requires you to explicitly call the model initialization function.
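Deferred shape inference can be pictured with a small pure-Python sketch. LazyDense is a hypothetical class, not Gluon’s Dense. It records only the number of output units. On the first forward pass it flattens each input sample (Gluon’s Dense also collapses all but the first axis by default) and creates a weight matrix whose input size matches what it saw. Bias and activation are omitted, and the weights are filled with a constant only to keep the example deterministic.

```python
def flatten_sample(x):
    """Recursively flatten one nested-list sample into a flat list of numbers."""
    if isinstance(x, list):
        return [v for item in x for v in flatten_sample(item)]
    return [x]


class LazyDense:
    """Hypothetical fully connected layer that infers its input size on first use."""

    def __init__(self, units, init_value=0.01):
        self.units = units
        self.init_value = init_value
        self.in_units = None  # unknown until data arrives
        self.weight = None    # shape (units, in_units) once initialized

    def __call__(self, batch):
        rows = [flatten_sample(sample) for sample in batch]
        if self.weight is None:
            # Deferred initialization: the input size comes from the data itself.
            self.in_units = len(rows[0])
            self.weight = [[self.init_value] * self.in_units for _ in range(self.units)]
        return [[sum(w * v for w, v in zip(w_row, row)) for w_row in self.weight]
                for row in rows]


# A batch of 2 samples shaped (1, 28, 28), like single-channel MNIST images.
batch = [[[[1.0] * 28 for _ in range(28)]] for _ in range(2)]
layer = LazyDense(256)
out = layer(batch)
print(layer.in_units, len(out), len(out[0]))  # 784 2 256
```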
With a Sequential block, layers are executed one after the other. To use a different execution model in PyTorch, you can inherit from nn.Module and then customize how the .forward() function is executed. Similarly, in Apache MXNet you can inherit from nn.Block and override its forward() method.
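Both frameworks share one mechanism: calling a block runs its forward() method, so a subclass that overrides forward() controls the execution order. The pure-Python sketch below mimics that pattern with hypothetical Block, Scale, Sequential, and Residual classes, not the real nn.Module or nn.Block. Residual shows an execution model, adding the input back to the output, that plain sequential stacking cannot express.

```python
class Block:
    """Hypothetical minimal base class: calling a block runs its forward() method."""

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        raise NotImplementedError


class Scale(Block):
    def __init__(self, factor):
        self.factor = factor

    def forward(self, x):
        return [self.factor * v for v in x]


class Sequential(Block):
    """Runs child blocks one after the other."""

    def __init__(self, *blocks):
        self.blocks = list(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class Residual(Block):
    """Custom execution: add the input back to the output of the wrapped body."""

    def __init__(self, body):
        self.body = body

    def forward(self, x):
        return [a + b for a, b in zip(x, self.body(x))]


seq = Sequential(Scale(2.0), Scale(3.0))
res = Residual(Sequential(Scale(2.0), Scale(3.0)))
print(seq([1.0, 2.0]))  # [6.0, 12.0]
print(res([1.0, 2.0]))  # [7.0, 14.0]
```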
3. Loss function and optimization algorithm¶
The next step is to define the loss function and pick an optimization algorithm. Both PyTorch and Apache MXNet provide multiple options to choose from; for our particular case, we use the cross-entropy loss function and the Stochastic Gradient Descent (SGD) optimization algorithm.
PyTorch:
[ ]:
pt_loss_fn = pt_nn.CrossEntropyLoss()
pt_trainer = torch.optim.SGD(pt_net.parameters(), lr=0.1)
MXNet:
[ ]:
mx_loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
mx_trainer = gluon.Trainer(mx_net.collect_params(),
'sgd', {'learning_rate': 0.1})
The code difference between the frameworks is small. The main difference is that in Apache MXNet we use the Trainer class, which accepts the optimization algorithm as an argument. We also use the .collect_params() method to get the parameters of the network.
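These two lines also hide a difference in the shape of the loss. PyTorch’s CrossEntropyLoss averages over the batch by default (reduction='mean') and returns a scalar. Gluon’s SoftmaxCrossEntropyLoss returns one loss value per sample, which is why the MXNet training loop below calls loss.mean(). The pure-Python sketch computes softmax cross-entropy from raw scores (logits) both ways. The function name and the toy logits are illustrative assumptions, not library APIs.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy of one sample, -log(softmax(logits)[label]), computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # hypothetical batch: 2 samples, 3 classes
labels = [0, 2]

# Gluon-style output: one loss per sample, shape (batch,).
per_sample = [softmax_cross_entropy(z, y) for z, y in zip(logits, labels)]
# PyTorch default (reduction='mean'): a single scalar averaged over the batch.
batch_mean = sum(per_sample) / len(per_sample)
print(len(per_sample), round(batch_mean, 4))
```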
4. Training¶
Finally, we implement the training algorithm. Note that the results for each run may vary because the weights will get different initialization values and the data will be read in a different order due to shuffling.
PyTorch:
[ ]:
import time

for epoch in range(5):
    total_loss = .0
    tic = time.time()
    for X, y in pt_train_data:
        pt_trainer.zero_grad()
        # Flatten each 1x28x28 image into a 784-element vector for the Linear layer
        loss = pt_loss_fn(pt_net(X.view(-1, 28*28)), y)
        loss.backward()
        pt_trainer.step()
        # loss is already the batch mean; .item() converts it to a Python float
        total_loss += loss.item()
    print('epoch %d, avg loss %.4f, time %.2f' % (
        epoch, total_loss/len(pt_train_data), time.time()-tic))
MXNet:
[ ]:
import time
from mxnet import autograd

for epoch in range(5):
    total_loss = .0
    tic = time.time()
    for X, y in mx_train_data:
        # Record the forward pass so that it can be differentiated
        with autograd.record():
            loss = mx_loss_fn(mx_net(X), y)
        loss.backward()
        # Normalize by the actual batch size (the last batch can be smaller than 128)
        mx_trainer.step(batch_size=X.shape[0])
        total_loss += loss.mean().asscalar()
    print('epoch %d, avg loss %.4f, time %.2f' % (
        epoch, total_loss/len(mx_train_data), time.time()-tic))
Some of the differences in Apache MXNet when compared to PyTorch are as follows:
In Apache MXNet, you don’t need to flatten the 4-D input into 2-D when feeding the data into the forward pass.
In Apache MXNet, you need to perform the calculation within the autograd.record() scope so that it can be automatically differentiated in the backward pass.
It is not necessary to clear the gradients every time as with PyTorch’s trainer.zero_grad(), because by default the new gradient is written in, not accumulated.
You need to specify the update step size (usually the batch size) when calling step() on the trainer; the gradients are normalized by 1/batch_size before the update.
You need to call .asscalar() to convert a single-element NDArray into a Python scalar.
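The batch_size argument of step() matters because Gluon losses are per-sample values: calling backward() on that vector leaves the sum of the per-sample gradients, and Trainer.step(batch_size) rescales it by 1/batch_size. The pure-Python sketch below uses a hypothetical one-parameter model with squared-error loss. It checks that summing per-sample gradients and dividing by the batch size gives the same SGD update as differentiating the mean loss, which is what PyTorch’s default reduction does.

```python
def grad_per_sample(w, x, y):
    """d/dw of the squared error (w*x - y)**2 for one sample."""
    return 2 * (w * x - y) * x


w, lr = 0.5, 0.1
xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]  # hypothetical toy batch
batch_size = len(xs)

# MXNet style: backward() on the per-sample loss vector sums the gradients,
# and trainer.step(batch_size) divides that sum by batch_size.
summed_grad = sum(grad_per_sample(w, x, y) for x, y in zip(xs, ys))
w_mxnet = w - lr * summed_grad / batch_size

# PyTorch style: the mean-reduced loss already yields the averaged gradient.
mean_grad = sum(grad_per_sample(w, x, y) for x, y in zip(xs, ys)) / batch_size
w_pytorch = w - lr * mean_grad

print(w_mxnet, w_pytorch)
```

Passing a batch size that is too large, for example a fixed 128 for a smaller final batch, would shrink that batch’s update accordingly.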
In this sample run, Apache MXNet was about twice as fast as PyTorch, but be cautious about drawing conclusions from such toy comparisons: the timings depend on hardware, library versions, and data-loading settings.
Conclusion¶
As we saw above, the Apache MXNet Gluon API and PyTorch have many similarities. The main differences lie in terminology (Tensor vs. NDArray) and in how gradients are handled: by default they are accumulated in PyTorch and overwritten in Apache MXNet. The rest of the code is very similar, and it is quite straightforward to move code from one framework to the other.
Recommended Next Steps¶
While the Apache MXNet Gluon API is very similar to PyTorch, there is some extra functionality that can make your code even faster.
Check out the Hybridize tutorial to learn how to write imperative code that can be converted into a symbolic graph.
Also, check out how to extend Apache MXNet with your own custom layers.
Appendix¶
Below you can find a detailed comparison of various PyTorch functions and their equivalent in Gluon API of Apache MXNet.
Tensor operation¶
Here is a list of PyTorch Tensor functions whose names differ from their Apache MXNet NDArray counterparts.
Function | PyTorch | MXNet Gluon |
---|---|---|
Element-wise inverse cosine | | |
Batch matrix product and accumulation | | |
Element-wise division of t1 by t2, multiplied by v and added to t | | |
Matrix product and accumulation | | |
Outer product of two vectors added to a matrix | | Not available |
Applies a function element-wise | | Not available, but there is |
Element-wise inverse sine | | |
Element-wise inverse tangent | | |
Element-wise two-argument arctangent of two tensors | | Not available |
Batch matrix product | | |
Draws samples from a Bernoulli distribution | | Not available |
Fills a tensor with numbers drawn from a Cauchy distribution | | Not available |
Splits a tensor into chunks along a given dimension | | |
Limits the values of a tensor to the range [min, max] | | |
Returns a copy of the tensor | | |
Cross product | | Not available |
Cumulative product along an axis | | Not available |
Cumulative sum along an axis | | Not available |
Address of the first element | | Not available |
Creates a diagonal tensor | | Not available |
Computes the norm of a tensor | | |
Computes the Gauss error function | | Not available |
Broadcasts/expands a tensor to a new shape | | |
Fills a tensor with samples drawn from an exponential distribution | | |
Element-wise mod | | |
Fractional portion of a tensor | | |
Gathers values along an axis specified by dim | | |
Solves least-squares and least-norm problems | | Not available |
Draws samples from a geometric distribution | | Not available |
Device context of a tensor | | |
Repeats a tensor | | |
Data type of a tensor | | |
Scatter | | |
Returns the shape of a tensor | | |
Number of elements in a tensor | | |
Returns this tensor as a NumPy ndarray | | |
Eigendecomposition of a symmetric matrix | | |
Transpose | | |
Samples uniformly | | |
Inserts a new dimension | | |
Reshape | | |
Views with the shape of a specified tensor | | |
Returns a copy of the tensor cast to a specified type | | |
Copies the values of one tensor into another | | |
Returns a zero tensor with a specified shape | | |
Returns a one tensor with a specified shape | | |
Returns a tensor filled with the scalar value 1, with the same size as the input | | |
Functional¶
GPU¶
Just like a PyTorch Tensor, an MXNet NDArray can be copied to and operated on a GPU. This is done by specifying a context.
Function | PyTorch | MXNet Gluon |
---|---|---|
Copy to GPU | | |
Convert to numpy array | | |
Context scope | | |
Cross-device¶
Just like a PyTorch Tensor, an MXNet NDArray can be copied across multiple GPUs.
Function | PyTorch | MXNet Gluon |
---|---|---|
Copy from GPU 0 to GPU 1 | | |
Copy Tensor/NDArray on different GPUs | | |
Autograd¶
Variable wrapper vs autograd scope¶
The autograd package of PyTorch and MXNet enables automatic differentiation of Tensors and NDArrays.
Function | PyTorch | MXNet Gluon |
---|---|---|
Recording computation | | |
Scope override (pause, train_mode, predict_mode)¶
Some operators (Dropout, BatchNorm, etc.) behave differently during training and when making predictions. In MXNet this can be controlled with the train_mode and predict_mode scopes. The pause scope is for code that does not need gradients to be calculated.
Function | PyTorch | MXNet Gluon |
---|---|---|
Scope override | Not available | |
Batch-end synchronization is needed¶
Apache MXNet uses lazy evaluation to achieve better performance: the Python thread just pushes operations into the backend engine and then returns. During training, batch-end synchronization is needed, e.g., asnumpy(), wait_to_read(), or metric.update(...). Without it, the frontend can queue up work for many batches ahead of the backend, which can use excessive memory.
Function | PyTorch | MXNet Gluon |
---|---|---|
Batch-end synchronization | Not available | |
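A rough mental model of lazy evaluation is a work queue. The Python call returns a handle immediately while a backend worker computes the result, and reading the value blocks until the work is done. The sketch below uses the standard library’s ThreadPoolExecutor as a stand-in for MXNet’s backend engine. push_op and slow_square are hypothetical helpers, and Future.result() plays the role that wait_to_read() or asnumpy() plays for an NDArray.

```python
import time
from concurrent.futures import ThreadPoolExecutor

engine = ThreadPoolExecutor(max_workers=1)  # stand-in for the asynchronous backend engine


def slow_square(x):
    time.sleep(0.05)  # pretend this is an expensive kernel
    return x * x


def push_op(fn, *args):
    """Queue an operation and return at once with a handle to its future result."""
    return engine.submit(fn, *args)


start = time.perf_counter()
handles = [push_op(slow_square, i) for i in range(4)]  # returns almost immediately
enqueue_time = time.perf_counter() - start

results = [h.result() for h in handles]  # batch-end synchronization point
total_time = time.perf_counter() - start
engine.shutdown()
print(results, enqueue_time < total_time)
```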
PyTorch module and Gluon blocks¶
For new block definitions, Gluon needs name_scope¶
name_scope makes Gluon give each parameter an appropriate name that indicates which model it belongs to.
Function | PyTorch | MXNet Gluon |
---|---|---|
New block definition | | |
Parameter and Initializer¶
When you create new layers in PyTorch, you do not need to specify their parameter initializer; different layers have different default initializers. When you create new layers with the Gluon API, you can specify an initializer or leave it unset. Parameters are initialized when you call net.initialize(<init method>): all parameters are initialized with <init method>, except those of layers whose initializer was specified explicitly.
Function | PyTorch | MXNet Gluon |
---|---|---|
Get all parameters | | |
Initialize network | Not available | |
Specify layer initializer | | |
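The rule that net.initialize(<init method>) covers every parameter except those that already carry their own initializer can be sketched in plain Python. Param, constant, and initialize are hypothetical stand-ins, not Gluon classes. Initializers are modelled as functions that take a 2-D shape and return a nested list.

```python
class Param:
    """Hypothetical parameter: a name, a 2-D shape, and an optional layer-specific initializer."""

    def __init__(self, name, shape, init=None):
        self.name, self.shape, self.init, self.data = name, shape, init, None


def constant(value):
    return lambda shape: [[value] * shape[1] for _ in range(shape[0])]


def initialize(params, default_init):
    for p in params:
        # A layer-specific initializer wins; everything else uses the global default.
        p.data = (p.init or default_init)(p.shape)


params = [Param("dense0_weight", (2, 3)), Param("dense1_weight", (1, 2), init=constant(0.5))]
initialize(params, default_init=constant(0.0))
print([p.data for p in params])  # [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[0.5, 0.5]]]
```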
Usage of existing blocks look alike¶
Function | PyTorch | MXNet Gluon |
---|---|---|
Usage of existing blocks | | |
HybridBlock can be hybridized, and allows partial-shape info¶
A HybridBlock can run its forward pass with both Symbol and NDArray inputs. Once hybridized, a HybridBlock creates a symbolic graph representing the forward computation and caches it. Most of the built-in blocks (Dense, Conv2D, MaxPool2D, BatchNorm, etc.) are HybridBlocks.
Instead of explicitly declaring the number of inputs to a layer, we can simply state the number of outputs. The shape will be inferred on the fly once the network is provided with some input.
Function | PyTorch | MXNet Gluon |
---|---|---|
Partial-shape hybridized | Not available | |
SymbolBlock¶
SymbolBlock can construct a block from a symbol. This is useful for using pre-trained models as feature extractors.
Function | PyTorch | MXNet Gluon |
---|---|---|
SymbolBlock | Not available | |
PyTorch optimizer vs Gluon Trainer¶
With the Gluon API, calling zero_grad is not necessary most of the time¶
zero_grad in the optimizer (PyTorch) or Trainer (Gluon API) clears the gradients of all parameters. In the Gluon API, there is no need to clear the gradients every batch if grad_req = 'write' (the default).
Function | PyTorch | MXNet Gluon |
---|---|---|
Clear the gradients | | |
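The difference between overwriting and accumulating gradients can be shown with a hypothetical parameter object. Its grad_req is either 'write', Gluon’s default, or 'add', the accumulating behaviour PyTorch uses for .grad. The pure-Python sketch below runs backward twice without clearing and compares the stored gradients. The Parameter class is illustrative, not Gluon’s.

```python
class Parameter:
    """Hypothetical parameter that stores a gradient according to grad_req."""

    def __init__(self, grad_req="write"):
        self.grad_req = grad_req
        self.grad = 0.0

    def backward(self, new_grad):
        if self.grad_req == "write":
            self.grad = new_grad   # overwrite: the previous batch's gradient is discarded
        elif self.grad_req == "add":
            self.grad += new_grad  # accumulate: must be cleared explicitly between batches
        else:
            raise ValueError(f"unsupported grad_req: {self.grad_req}")

    def zero_grad(self):
        self.grad = 0.0


write_p, add_p = Parameter("write"), Parameter("add")
for g in (1.5, 2.5):  # two batches, no zero_grad() in between
    write_p.backward(g)
    add_p.backward(g)
print(write_p.grad, add_p.grad)  # 2.5 4.0
```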
Multi-GPU training¶
Function | PyTorch | MXNet Gluon |
---|---|---|
Data parallelism | | |
Distributed training¶
Function | PyTorch | MXNet Gluon |
---|---|---|
Distributed data parallelism | | |
Monitoring¶
Apache MXNet has pre-defined metrics¶
Gluon provides several predefined metrics that can evaluate the performance of a learned model online, updating the result batch by batch.
Function | PyTorch | MXNet Gluon |
---|---|---|
Metric | Not available | |
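MXNet metrics follow an update-then-get pattern. update() is called once per batch with labels and predictions, and get() returns a (name, value) pair for everything seen so far. The class below is a hypothetical pure-Python re-creation of an accuracy metric with that interface, not the library’s own Accuracy class. Predictions are per-class scores, and the arg-max is taken as the predicted class.

```python
class Accuracy:
    """Hypothetical streaming accuracy metric with an update()/get()/reset() interface."""

    def __init__(self):
        self.name = "accuracy"
        self.reset()

    def reset(self):
        self.num_correct = 0
        self.num_inst = 0

    def update(self, labels, preds):
        for label, scores in zip(labels, preds):
            predicted = max(range(len(scores)), key=scores.__getitem__)  # arg-max class
            self.num_correct += int(predicted == label)
            self.num_inst += 1

    def get(self):
        value = self.num_correct / self.num_inst if self.num_inst else float("nan")
        return self.name, value


metric = Accuracy()
metric.update(labels=[0, 1], preds=[[0.9, 0.1], [0.8, 0.2]])  # batch 1: one of two correct
metric.update(labels=[1], preds=[[0.3, 0.7]])                 # batch 2: correct
print(metric.get())
```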
Data visualization¶
TensorboardX (PyTorch) and MXBoard (MXNet) can be used to visualize your network and plot quantitative metrics about the execution of your graph.
PyTorch | MXNet |
---|---|
 | |
I/O and deploy¶
Data loading¶
Dataset and DataLoader are the basic components for loading data.
Class | PyTorch | MXNet Gluon |
---|---|---|
Dataset holding arrays | | |
Data loader | | |
Sequentially applied sampler | | |
Random order sampler | | |
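The division of labour between these classes can be sketched in plain Python. A dataset maps an index to a sample, a sampler decides the order of indices, and a data loader groups sampled indices into batches. The classes below are hypothetical, simplified stand-ins named after Gluon’s ArrayDataset, SequentialSampler, RandomSampler, and DataLoader, with no worker processes and no tensor stacking. They only show how the pieces fit together.

```python
import random


class ArrayDataset:
    """Pairs up equally long arrays so dataset[i] returns (data[i], label[i])."""

    def __init__(self, *arrays):
        if len({len(a) for a in arrays}) != 1:
            raise ValueError("all arrays must have the same length")
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, idx):
        return tuple(a[idx] for a in self.arrays)


class SequentialSampler:
    def __init__(self, length):
        self.length = length

    def __iter__(self):
        return iter(range(self.length))


class RandomSampler:
    def __init__(self, length, seed=None):
        self.length = length
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = list(range(self.length))
        self.rng.shuffle(indices)
        return iter(indices)


class DataLoader:
    """Yields lists of samples; the last batch may be smaller than batch_size."""

    def __init__(self, dataset, batch_size, shuffle=False, seed=None):
        self.dataset = dataset
        self.batch_size = batch_size
        length = len(dataset)
        self.sampler = RandomSampler(length, seed) if shuffle else SequentialSampler(length)

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(self.dataset[idx])
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch:
            yield batch


dataset = ArrayDataset(list(range(10)), [i % 2 for i in range(10)])
batches = list(DataLoader(dataset, batch_size=4))
print([len(b) for b in batches])  # [4, 4, 2]
```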
Some commonly used datasets for computer vision are provided in the mx.gluon.data.vision package.
Class | PyTorch | MXNet Gluon |
---|---|---|
MNIST handwritten digits dataset. | | |
CIFAR10 Dataset. | | |
CIFAR100 Dataset. | | |
A generic data loader where the images are arranged in folders. | | |
Serialization¶
Serialization and deserialization are achieved by calling save_parameters and load_parameters.
Class | PyTorch | MXNet Gluon |
---|---|---|
Save model parameters | | |
Load parameters | | |
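Conceptually, saving and loading parameters round-trips a mapping from parameter names to arrays. By default, loading is strict: it fails when the stored names do not match the target network. The sketch below imitates that idea with the standard library’s json module and a temporary directory. save_params and load_params are hypothetical helpers, and the real MXNet and PyTorch file formats are different from this JSON layout.

```python
import json
import os
import tempfile


def save_params(params, path):
    """Write a {name: nested list} mapping to disk as JSON."""
    with open(path, "w") as f:
        json.dump(params, f)


def load_params(path, expected_names):
    """Read parameters back and insist that the names match the target network."""
    with open(path) as f:
        params = json.load(f)
    missing = set(expected_names) - set(params)
    extra = set(params) - set(expected_names)
    if missing or extra:
        raise KeyError(f"missing={sorted(missing)}, extra={sorted(extra)}")
    return params


params = {"dense0_weight": [[0.1, 0.2]], "dense0_bias": [0.0]}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mlp.params.json")
    save_params(params, path)
    restored = load_params(path, expected_names=["dense0_weight", "dense0_bias"])
print(restored == params)  # True
```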