How to Use MXNet As an (Almost) Full-function Torch Front End¶
This topic demonstrates how to use MXNet as a front end to two of Torch’s major functionalities:
- Call Torch’s tensor mathematical functions with MXNet.NDArray
- Embed Torch’s neural network modules (layers) into MXNet’s symbolic graph
Compile with Torch¶
- Install Torch using the official guide.
- If you haven’t already done so, copy make/config.mk (Linux) or make/osx.mk (Mac) into the MXNet root folder as config.mk. In config.mk, uncomment the lines TORCH_PATH = $(HOME)/torch and MXNET_PLUGINS += plugin/torch/torch.mk (the edited lines are shown after this list).
- By default, Torch is assumed to be installed in your home folder (hence TORCH_PATH = $(HOME)/torch). If necessary, modify TORCH_PATH to point to your Torch installation.
- Run make clean && make to build MXNet with Torch support.
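After uncommenting, the relevant lines in config.mk should read as follows (shown with the default install location):
TORCH_PATH = $(HOME)/torch
MXNET_PLUGINS += plugin/torch/torch.mk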
Tensor Mathematics¶
The mxnet.th module supports calling Torch’s tensor mathematical functions on mxnet.nd.NDArray objects. See the complete code:
import mxnet as mx

# Draw a 2x2 array of normally distributed values using Torch's randn
x = mx.th.randn(2, 2, ctx=mx.cpu(0))
print(x.asnumpy())
y = mx.th.abs(x)
print(y.asnumpy())

# Passing the same array as both input and output runs the operation in place
x = mx.th.randn(2, 2, ctx=mx.cpu(0))
print(x.asnumpy())
mx.th.abs(x, x)  # in-place
print(x.asnumpy())
For help, use the help(mx.th) command.
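Because mx.th functions return ordinary mx.nd.NDArray objects, their results compose with regular NDArray operations. A minimal sketch:
import mxnet as mx

x = mx.th.randn(2, 2, ctx=mx.cpu(0))  # Torch RNG producing an MXNet NDArray
y = mx.th.abs(x) + 1                  # mix with plain NDArray arithmetic
print(y.asnumpy())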
We’ve added support for most of the common functions listed on Torch’s documentation page. If a function you need is not supported, you can easily register it in mxnet_root/plugin/torch/torch_function.cc by following the existing registrations as examples.
Torch Modules (Layers)¶
MXNet supports Torch’s neural network modules through the mxnet.symbol.TorchModule symbol.
For example, the following code defines a three-layer DNN for classifying MNIST digits (full code):
data = mx.symbol.Variable('data')
fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1')
act1 = mx.symbol.TorchModule(data_0=fc1, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu1')
fc2 = mx.symbol.TorchModule(data_0=act1, lua_string='nn.Linear(128, 64)', num_data=1, num_params=2, num_outputs=1, name='fc2')
act2 = mx.symbol.TorchModule(data_0=fc2, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu2')
fc3 = mx.symbol.TorchModule(data_0=act2, lua_string='nn.Linear(64, 10)', num_data=1, num_params=2, num_outputs=1, name='fc3')
mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
Let’s break it down. First, data = mx.symbol.Variable('data') defines a variable as a placeholder for input. Then, it’s fed through Torch’s nn modules with fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1').
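Because each nn.Linear declares num_params=2 (weight and bias), the Torch parameters surface as ordinary arguments of the composed symbol. You can verify this with a quick check; the exact generated argument names depend on the plugin:
print(mlp.list_arguments())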
To use a Torch criterion as the loss function, you can replace the last line with:
logsoftmax = mx.symbol.TorchModule(data_0=fc3, lua_string='nn.LogSoftMax()', num_data=1, num_params=0, num_outputs=1, name='logsoftmax')
# Torch's labels start from 1
label = mx.symbol.Variable('softmax_label') + 1
mlp = mx.symbol.TorchCriterion(data=logsoftmax, label=label, lua_string='nn.ClassNLLCriterion()', name='softmax')
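Either version of mlp (SoftmaxOutput or TorchCriterion) can then be trained with MXNet’s usual APIs. A minimal sketch, assuming the raw MNIST files are available locally (the file paths are illustrative; the full example linked above shows the real pipeline):
import mxnet as mx

# Hypothetical local paths to the raw MNIST files; flat=True yields
# 784-dimensional vectors matching nn.Linear(784, 128)
train = mx.io.MNISTIter(image='train-images-idx3-ubyte',
                        label='train-labels-idx1-ubyte',
                        batch_size=100, flat=True)

mod = mx.mod.Module(symbol=mlp, data_names=['data'],
                    label_names=['softmax_label'])
mod.fit(train, num_epoch=10,
        optimizer_params={'learning_rate': 0.1})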
The inputs to the nn module are named data_i for i = 0 ... num_data-1. lua_string is a single Lua statement that creates the module object. For Torch’s built-in modules, this is simply nn.module_name(arguments). If you are using a custom module, place it in a .lua script file and load it with require 'module_file.lua' if your script returns a torch.nn object, or with (require 'module_file.lua')() if your script returns a torch.nn class. A sketch of both cases follows.
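For example, assuming a hypothetical script mymodule.lua with no learnable parameters (set num_params to match the parameters your module actually declares):
# mymodule.lua returns an instantiated nn module (an object):
custom = mx.symbol.TorchModule(data_0=data, lua_string="require 'mymodule.lua'",
                               num_data=1, num_params=0, num_outputs=1, name='custom')
# mymodule.lua returns an nn class (a constructor), so call it:
custom = mx.symbol.TorchModule(data_0=data, lua_string="(require 'mymodule.lua')()",
                               num_data=1, num_params=0, num_outputs=1, name='custom')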