How to Use MXNet As an (Almost) Full-function Torch Front End

This topic demonstrates how to use MXNet as a front end to two of Torch’s major functionalities:

  • Call Torch’s tensor mathematical functions with MXNet.NDArray
  • Embed Torch’s neural network modules (layers) into MXNet’s symbolic graph

Compile with Torch

  • Install Torch using the official guide.
  • If you haven’t already done so, copy make/config.mk (Linux) or make/osx.mk (Mac) into the MXNet root folder as config.mk.
  • In config.mk, uncomment the lines TORCH_PATH = $(HOME)/torch and MXNET_PLUGINS += plugin/torch/torch.mk. By default, Torch is installed in your home folder (hence TORCH_PATH = $(HOME)/torch); modify TORCH_PATH to point to your Torch installation, if necessary.
  • Run make clean && make to build MXNet with Torch support. (The whole sequence is condensed in the sketch below.)
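
A condensed sketch of the whole sequence, assuming the default Torch location:

    # from the MXNet root folder
    cp make/config.mk config.mk          # Linux; use make/osx.mk on Mac
    # edit config.mk and uncomment these two lines:
    #   TORCH_PATH = $(HOME)/torch
    #   MXNET_PLUGINS += plugin/torch/torch.mk
    make clean && make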

Tensor Mathematics

The mxnet.th module supports calling Torch’s tensor mathematical functions on mxnet.nd.NDArray. For example:

    import mxnet as mx
    x = mx.th.randn(2, 2, ctx=mx.cpu(0))
    print(x.asnumpy())
    y = mx.th.abs(x)
    print(y.asnumpy())

    x = mx.th.randn(2, 2, ctx=mx.cpu(0))
    print(x.asnumpy())
    mx.th.abs(x, x) # in-place
    print(x.asnumpy())
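
Because the results are ordinary mxnet.nd.NDArray objects, they compose with native NDArray operations; a minimal sketch:

    x = mx.th.randn(2, 2, ctx=mx.cpu(0))
    y = mx.th.abs(x) + 1  # a Torch function followed by a native NDArray op
    print(y.asnumpy())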

For a list of the supported functions, run help(mx.th).

We’ve added support for most of the common functions listed on Torch’s documentation page. If the function you need is not supported, you can easily register it in mxnet_root/plugin/torch/torch_function.cc, using the existing registrations as examples.

Torch Modules (Layers)

MXNet supports Torch’s neural network modules through the mxnet.symbol.TorchModule symbol. For example, the following code defines a three-layer DNN for classifying MNIST digits:

    data = mx.symbol.Variable('data')
    fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1')
    act1 = mx.symbol.TorchModule(data_0=fc1, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu1')
    fc2 = mx.symbol.TorchModule(data_0=act1, lua_string='nn.Linear(128, 64)', num_data=1, num_params=2, num_outputs=1, name='fc2')
    act2 = mx.symbol.TorchModule(data_0=fc2, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu2')
    fc3 = mx.symbol.TorchModule(data_0=act2, lua_string='nn.Linear(64, 10)', num_data=1, num_params=2, num_outputs=1, name='fc3')
    mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')

Let’s break it down. First, data = mx.symbol.Variable('data') defines a Variable as a placeholder for input. Then, it’s fed through Torch’s nn modules with fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1'). Note that num_params=2 because nn.Linear carries two parameter tensors, a weight and a bias, while nn.ReLU carries none. To use a Torch criterion as the loss function, replace the last line with:

    logsoftmax = mx.symbol.TorchModule(data_0=fc3, lua_string='nn.LogSoftMax()', num_data=1, num_params=0, num_outputs=1, name='logsoftmax')
    # Torch's label starts from 1
    label = mx.symbol.Variable('softmax_label') + 1
    mlp = mx.symbol.TorchCriterion(data=logsoftmax, label=label, lua_string='nn.ClassNLLCriterion()', name='softmax')
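
Either way, the resulting mlp is an ordinary MXNet symbol, so it trains like any native network. A minimal sketch, assuming train_iter is a standard MNIST data iterator (not defined in this guide):

    # Hypothetical training setup; train_iter is assumed, not shown here.
    mod = mx.mod.Module(symbol=mlp, context=mx.cpu(),
                        data_names=['data'], label_names=['softmax_label'])
    mod.fit(train_iter, num_epoch=10,
            optimizer='sgd', optimizer_params={'learning_rate': 0.1})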

The inputs to the nn module are named data_i for i = 0 ... num_data-1. lua_string is a single Lua statement that creates the module object; for Torch’s built-in modules, this is simply nn.module_name(arguments). If you are using a custom module, place it in a .lua script file and load it with require 'module_file.lua' if the script returns a constructed torch.nn object, or with (require 'module_file.lua')() if it returns a torch.nn class; a hypothetical sketch follows.
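
For instance, suppose a hypothetical file mymodule.lua returns a constructed torch.nn object (say, its last line is return nn.Linear(784, 128)). It can then be embedded like any built-in module:

    # Hypothetical custom module: mymodule.lua returns a constructed nn object,
    # e.g. its last line is `return nn.Linear(784, 128)`.
    data = mx.symbol.Variable('data')
    custom = mx.symbol.TorchModule(data_0=data, lua_string="require 'mymodule.lua'",
                                   num_data=1, num_params=2,  # nn.Linear: weight + bias
                                   num_outputs=1, name='custom')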