Step 2: Create a neural network¶
In this step, you learn how to use NP on Apache MXNet to create neural networks in Gluon. In addition to the np package that you learned about in the previous step, Step 1: Manipulate data with NP on MXNet, you also need to import the neural network modules from gluon. Gluon includes built-in neural network layers in the following two modules:
mxnet.gluon.nn: the NN module maintained by the MXNet team
mxnet.gluon.contrib.nn: an experimental module contributed by the community
Use the following commands to import the packages required for this step.
[1]:
from mxnet import np, npx
from mxnet.gluon import nn
npx.set_np() # Change MXNet to the numpy-like mode.
Create your neural network’s first layer¶
In this section, you will create a simple neural network with Gluon. One of the simplest networks you can create is a single Dense layer, or densely-connected layer. A dense layer is one in which every node of the input is connected to every node of the next layer. Use the following code example to start with a dense layer with five output units.
[2]:
layer = nn.Dense(5)
layer
# output: Dense(-1 -> 5, linear)
[2]:
Dense(-1 -> 5, linear)
In the example above, the output is Dense(-1 -> 5, linear)
. The -1 in the output denotes that the size of the input layer is not specified during initialization.
You can also call the Dense layer with an in_units parameter if you know the number of input units.
[3]:
layer = nn.Dense(5,in_units=3)
layer
[3]:
Dense(3 -> 5, linear)
In addition to the in_units parameter, you can also add an activation function to the layer using the activation parameter. The Dense layer implements the operation output = activation(dot(input, weight) + bias). Call the Dense layer with an activation parameter to use an activation function.
[4]:
layer = nn.Dense(5, in_units=3,activation='relu')
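As a quick optional check, here is a minimal sketch (using the np and nn imports from the top of this notebook; the names check and x_check are illustrative) that reproduces the Dense computation by hand:
[ ]:
check = nn.Dense(5, in_units=3, activation='relu')
check.initialize()
x_check = np.random.uniform(-1, 1, (10, 3))
out = check(x_check)
# Gluon stores the weight as (units, in_units), so transpose it before the dot product.
manual = np.maximum(np.dot(x_check, check.weight.data().T) + check.bias.data(), 0)
np.abs(out - manual).max()  # expected to be ~0 (floating-point noise only)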
Voila! Congratulations on creating a simple neural network. For most of your use cases, however, you will need a neural network with more than one dense layer or with other types of layers. In addition to the Dense layer, you can find more layers at mxnet nn layers.
Now that you have created a neural network, you are probably wondering how to pass data into it.
First, you need to initialize the network weights. The default initialization method draws random values uniformly from the range \([-0.07, 0.07]\). You can see this in the following example.
Note: Initialization is discussed in more detail in the next notebook.
[5]:
layer.initialize()
[03:51:58] /work/mxnet/src/storage/storage.cc:202: Using Pooled (Naive) StorageManager for CPU
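If you prefer a different scheme, you can pass an initializer explicitly. A minimal optional sketch (layer_normal and the sigma value are illustrative):
[ ]:
from mxnet import init
layer_normal = nn.Dense(5, in_units=3)
# Draw the weights from a normal distribution instead of the default uniform scheme.
layer_normal.initialize(init=init.Normal(sigma=0.01))
layer_normal.weight.data()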
Now that you have initialized your network, you can give it data. Passing data through a network is also called a forward pass. You can do a forward pass with random data, as shown in the following example. First, create a random input x with shape (10,3) and feed it into the layer to compute the output.
[6]:
x = np.random.uniform(-1,1,(10,3))
layer(x)
[6]:
array([[0.00881556, 0.01138476, 0. , 0. , 0.01936117],
[0. , 0.0035577 , 0.06854778, 0.03361227, 0. ],
[0.05338536, 0.01661206, 0. , 0.00864646, 0. ],
[0. , 0. , 0. , 0.03724934, 0.03653988],
[0. , 0.00657675, 0.00472842, 0. , 0.04593495],
[0. , 0.00795121, 0. , 0. , 0.0595542 ],
[0.02296758, 0.01650022, 0. , 0. , 0.03874438],
[0.02207369, 0.02308735, 0.04558432, 0.01468477, 0. ],
[0. , 0.02254252, 0. , 0. , 0.03834942],
[0.08092358, 0.04224379, 0. , 0. , 0. ]])
The layer produces an output of shape (10,5) from your (10,3) input.
When you don't specify the in_units parameter, the system automatically infers it the first time you feed data into the network, during the first forward pass after you create and initialize the weights.
[7]:
layer.params
[7]:
{'weight': Parameter (shape=(5, 3), dtype=float32),
'bias': Parameter (shape=(5,), dtype=float32)}
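To see this deferred shape inference in action, here is a small optional sketch (the name deferred is illustrative): create a layer without in_units and inspect its parameters before and after the first forward pass.
[ ]:
deferred = nn.Dense(5)
deferred.initialize()
print(deferred.params)                      # weight shape is still (5, -1)
deferred(np.random.uniform(size=(10, 3)))   # the first forward pass infers in_units
print(deferred.params)                      # weight shape is now (5, 3)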
The weights and bias can be accessed using the .data() method.
[8]:
layer.weight.data()
[8]:
array([[ 0.01607367, 0.05928481, -0.0319057 ],
[-0.05814854, 0.01664302, -0.02215988],
[-0.04094896, 0.03231322, 0.05914024],
[ 0.05500493, 0.03504761, 0.05073748],
[ 0.00943237, -0.06525595, -0.04184696]])
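The bias can be read the same way; by default it is initialized to zeros:
[ ]:
layer.bias.data()   # shape (5,), all zeros before training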
Chain layers into a neural network using nn.Sequential¶
Sequential provides a special way of rapidly building networks when the network architecture follows a common design pattern: the layers look like a stack of pancakes. Many networks follow this pattern: a bunch of layers, one stacked on top of another, where the output of each layer is fed directly into the input of the next layer. To use Sequential, simply provide a list of layers (pass in the layers by calling net.add(<Layer goes here!>)). To do this, you can use your previous example of Dense layers and create a 3-layer multilayer perceptron. You can create a sequential block using the nn.Sequential() method and add layers using the add() method.
[9]:
net = nn.Sequential()
net.add(nn.Dense(5,in_units=3,activation='relu'),
nn.Dense(25, activation='relu'), nn.Dense(2))
net
[9]:
Sequential(
(0): Dense(3 -> 5, Activation(relu))
(1): Dense(-1 -> 25, Activation(relu))
(2): Dense(-1 -> 2, linear)
)
The layers are ordered exactly the way you defined your neural network, with indices starting from 0. You can access the layers by indexing the network with [].
[10]:
net[1]
[10]:
Dense(-1 -> 25, Activation(relu))
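Each indexed layer is itself a Block, so you can also inspect its parameters directly; a small sketch:
[ ]:
net[0].params   # parameters of the first Dense layer, e.g. a weight of shape (5, 3)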
Create a custom neural network architecture flexibly¶
nn.Sequential() allows you to create your multi-layer neural network from existing layers in gluon.nn. It also includes a pre-defined forward() function that executes the added layers in sequence. But what if the built-in layers are not sufficient for your needs? If you want to create a network like ResNet, which has complex but repeatable components, how do you build it?
In Gluon, every neural network layer is defined by subclassing the base class nn.Block(). A Block has one main job: define a forward method that takes some input x and generates an output. A Block can do something as simple as applying an activation function, combine multiple layers into a single block, or combine a bunch of other Blocks in creative ways to create complex networks like ResNet. In this case, you will construct three Dense layers. The forward() method can then invoke the layers in turn to generate the output.
Create a subclass of nn.Block
and implement two methods by using the following code.
__init__: create the layers
forward: define the forward function
[11]:
class Net(nn.Block):
def __init__(self):
super().__init__()
def forward(self, x):
return x
[12]:
class MLP(nn.Block):
def __init__(self):
super().__init__()
self.dense1 = nn.Dense(5,activation='relu')
self.dense2 = nn.Dense(25,activation='relu')
self.dense3 = nn.Dense(2)
def forward(self, x):
layer1 = self.dense1(x)
layer2 = self.dense2(layer1)
layer3 = self.dense3(layer2)
return layer3
net = MLP()
net
[12]:
MLP(
(dense1): Dense(-1 -> 5, Activation(relu))
(dense2): Dense(-1 -> 25, Activation(relu))
(dense3): Dense(-1 -> 2, linear)
)
[13]:
net.dense1.params
[13]:
{'weight': Parameter (shape=(5, -1), dtype=float32),
'bias': Parameter (shape=(5,), dtype=float32)}
Each layer includes parameters that are stored in a Parameter class. You can access them using the params attribute.
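To gather every parameter of the network at once (for example, to hand them to a gluon.Trainer later), you can use the collect_params() method; a short sketch:
[ ]:
net.collect_params()   # a dict of all parameters in the MLP, keyed by name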
Creating custom layers using Parameters (Blocks API)¶
MXNet provides a Parameter class to hold the parameters of each layer. You can create custom layers using the Parameter class to include computation that is not available in the built-in layers. For example, in a dense layer the weights and biases are created as Parameter objects. If you want to add extra computation to the dense layer, you build it around Parameter objects.
Instantiate a parameter, e.g. a weight with shape (5, -1), using the shape argument.
[14]:
from mxnet.gluon import Parameter
weight = Parameter("custom_parameter_weight",shape=(5,-1))
bias = Parameter("custom_parameter_bias",shape=(5,-1))
weight,bias
[14]:
(Parameter (shape=(5, -1), dtype=<class 'numpy.float32'>),
Parameter (shape=(5, -1), dtype=<class 'numpy.float32'>))
The Parameter class includes a grad_req argument that specifies how you want to capture gradients for this Parameter. Under the hood, it lets Gluon know that it has to call .attach_grad() on the underlying array. By default, grad_req='write', which means the gradient buffer is overwritten every time a new gradient is written to it.
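For example, here is a hedged sketch of controlling gradient capture through grad_req (the parameter names are illustrative):
[ ]:
# A parameter that should not receive gradients during training.
frozen = Parameter("frozen_weight", shape=(5, 3), grad_req='null')
# The default, grad_req='write', overwrites the gradient buffer on every backward pass.
trainable = Parameter("trainable_weight", shape=(5, 3))
frozen, trainable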
Now that you know how parameters work, you are ready to create your very own fully-connected custom layer.
To create a custom layer using parameters, you can use the same skeleton with the nn.Block base class. You will create a custom dense layer that takes an input x and returns dot(x, w) + b, without any activation function.
[15]:
class custom_layer(nn.Block):
def __init__(self, out_units, in_units=0):
super().__init__()
self.weight = Parameter("weight", shape=(in_units,out_units), allow_deferred_init=True)
self.bias = Parameter("bias", shape=(out_units,), allow_deferred_init=True)
def forward(self, x):
return np.dot(x, self.weight.data()) + self.bias.data()
A Parameter can be instantiated before its data is allocated. For example, when you instantiate a Block whose parameter shapes still need to be inferred, the Parameter waits for the shapes to be inferred before allocating memory.
[16]:
dense = custom_layer(3,in_units=5)
dense.initialize()
dense(np.random.uniform(size=(4, 5)))
[16]:
array([[-0.05604633, -0.06238654, 0.02687173],
[-0.02687152, -0.04365591, -0.00518382],
[-0.02849396, -0.09980228, -0.00695815],
[-0.04527343, -0.00275569, -0.01376584]])
Similarly, you can use the following code to implement the well-known LeNet network through nn.Block, first with built-in layers only and then with custom_layer as the last layer.
[17]:
class LeNet(nn.Block):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2D(channels=6, kernel_size=3, activation='relu')
self.pool1 = nn.MaxPool2D(pool_size=2, strides=2)
self.conv2 = nn.Conv2D(channels=16, kernel_size=3, activation='relu')
self.pool2 = nn.MaxPool2D(pool_size=2, strides=2)
self.dense1 = nn.Dense(120, activation="relu")
self.dense2 = nn.Dense(84, activation="relu")
self.dense3 = nn.Dense(10)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.pool2(x)
x = self.dense1(x)
x = self.dense2(x)
x = self.dense3(x)
return x
lenet = LeNet()
[18]:
class LeNet_custom(nn.Block):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2D(channels=6, kernel_size=3, activation='relu')
self.pool1 = nn.MaxPool2D(pool_size=2, strides=2)
self.conv2 = nn.Conv2D(channels=16, kernel_size=3, activation='relu')
self.pool2 = nn.MaxPool2D(pool_size=2, strides=2)
self.dense1 = nn.Dense(120, activation="relu")
self.dense2 = nn.Dense(84, activation="relu")
self.dense3 = custom_layer(10,84)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.pool2(x)
x = self.dense1(x)
x = self.dense2(x)
x = self.dense3(x)
return x
lenet_custom = LeNet_custom()
[19]:
image_data = np.random.uniform(-1,1, (1,1,28,28))
lenet.initialize()
lenet_custom.initialize()
print("Lenet:")
print(lenet(image_data))
print("Custom Lenet:")
print(lenet_custom(image_data))
Lenet:
[[-0.00081668 -0.00340701 0.00199039 -0.00121501 -0.00063157 0.00081476
-0.0011025 -0.00216652 -0.00015363 -0.00110007]]
Custom Lenet:
[[-0.02656163 0.04184292 -0.00254037 -0.06494093 -0.00952166 -0.01921579
0.05423243 -0.02774546 0.06823301 0.00313227]]
You can use the .data() method to access the weights and bias of a particular layer. For example, the following accesses the first convolutional layer's weight and the first dense layer's bias.
[20]:
lenet.conv1.weight.data().shape, lenet.dense1.bias.data().shape
[20]:
((6, 1, 3, 3), (120,))
Using predefined (pretrained) architectures¶
So far, you have seen how to create your own neural network architectures. But what if you want to replicate or baseline your dataset using some of the common models in computer vision or natural language processing (NLP)? Gluon includes common architectures that you can use directly. The Gluon Model Zoo provides a collection of off-the-shelf models, e.g. ResNet, BERT, etc. You can load one of these architectures as follows:
[21]:
from mxnet.gluon import model_zoo
net = model_zoo.vision.resnet50_v2(pretrained=True)
net.hybridize()
dummy_input = np.ones(shape=(1,3,224,224))
output = net(dummy_input)
output.shape
Downloading /home/jenkins_slave/.mxnet/models/resnet50_v2-ecdde353.zip00383814-e655-4621-a110-5ffefe3eb69c from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/resnet50_v2-ecdde353.zip...
[21]:
(1, 1000)
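You can also look a model up by its string name with get_model, which is convenient when the architecture is chosen at runtime; a hedged sketch (resnet18_v2 and classes=10 are illustrative choices, with randomly initialized weights):
[ ]:
net_by_name = model_zoo.vision.get_model('resnet18_v2', classes=10)
net_by_name.initialize()
net_by_name(np.ones(shape=(1, 3, 224, 224))).shape   # (1, 10)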
Deciding the paradigm for your network¶
In MXNet, the Gluon API (imperative programming paradigm) provides a user-friendly way for quick prototyping, easy debugging, and natural control flow for people familiar with Python programming.
On the backend, however, MXNet can also convert the network into a static graph using symbolic (declarative) programming, with low-level optimizations on operators. Static graphs are less flexible because any logic must be encoded into the graph as special operators like scan, while_loop, and cond. They are also harder to debug.
So how can you make use of symbolic programming while getting the flexibility of imperative programming to quickly prototype and debug?
Enter HybridBlock
HybridBlocks can run in a fully imperative way, where you define their computation with real functions acting on real inputs. But they are also capable of running symbolically, acting on placeholders. Gluon hides most of this under the hood, so you only need to know how it works when you want to write your own layers.
[22]:
net_hybrid_seq = nn.HybridSequential()
net_hybrid_seq.add(nn.Dense(5,in_units=3,activation='relu'),
nn.Dense(25, activation='relu'), nn.Dense(2) )
net_hybrid_seq
[22]:
HybridSequential(
(0): Dense(3 -> 5, Activation(relu))
(1): Dense(-1 -> 25, Activation(relu))
(2): Dense(-1 -> 2, linear)
)
To compile and optimize HybridSequential
, you can call its hybridize
method.
[23]:
net_hybrid_seq.hybridize()
Creating custom layers using Parameters (HybridBlocks API)¶
When you instantiated your custom layer, you specified the input dimension in_units, which initializes the weights with the shape given by in_units and out_units. If you leave in_units unknown, you defer the shape inference to the first forward pass. For that, you define an infer_shape() method on the custom layer and let the shape be inferred at runtime.
[24]:
class CustomLayer(nn.HybridBlock):
def __init__(self, out_units, in_units=-1):
super().__init__()
self.weight = Parameter("weight", shape=(in_units, out_units), allow_deferred_init=True)
self.bias = Parameter("bias", shape=(out_units,), allow_deferred_init=True)
def forward(self, x):
print(self.weight.shape, self.bias.shape)
return np.dot(x, self.weight.data()) + self.bias.data()
def infer_shape(self, x):
print(self.weight.shape,x.shape)
self.weight.shape = (x.shape[-1],self.weight.shape[1])
dense = CustomLayer(3)
dense.initialize()
dense(np.random.uniform(size=(4, 5)))
/work/mxnet/python/mxnet/util.py:755: UserWarning: Parameter 'weight' is already initialized, ignoring. Set force_reinit=True to re-initialize.
return func(*args, **kwargs)
/work/mxnet/python/mxnet/util.py:755: UserWarning: Parameter 'bias' is already initialized, ignoring. Set force_reinit=True to re-initialize.
return func(*args, **kwargs)
[24]:
array([[-0.07053316, -0.07457963, 0.01166525],
[-0.04170407, -0.07482161, 0.00179428],
[-0.07503258, 0.00660181, -0.01401043],
[-0.02333996, -0.06775613, 0.01459978]])
Performance¶
To get a sense of the speedup from hybridizing, you can compare the performance before and after hybridizing by measuring the time it takes to make 1000 forward passes through the network.
[25]:
from time import time
def benchmark(net, x):
y = net(x)
start = time()
for i in range(1,1000):
y = net(x)
return time() - start
x_bench = np.random.normal(size=(1,512))
net_hybrid_seq = nn.HybridSequential()
net_hybrid_seq.add(nn.Dense(256,activation='relu'),
nn.Dense(128, activation='relu'),
nn.Dense(2))
net_hybrid_seq.initialize()
print('Before hybridizing: %.4f sec'%(benchmark(net_hybrid_seq, x_bench)))
net_hybrid_seq.hybridize()
print('After hybridizing: %.4f sec'%(benchmark(net_hybrid_seq, x_bench)))
Before hybridizing: 0.6034 sec
After hybridizing: 0.2799 sec
Peeling back another layer, you also have a HybridBlock
which is the hybrid version of the Block
API.
Similar to the Blocks
API, you define a forward
function for HybridBlock
that takes an input x
. MXNet takes care of hybridizing the model at the backend so you don’t have to make changes to your code to convert it to a symbolic paradigm.
[26]:
from mxnet.gluon import HybridBlock
class MLP_Hybrid(HybridBlock):
def __init__(self):
super().__init__()
self.dense1 = nn.Dense(256,activation='relu')
self.dense2 = nn.Dense(128,activation='relu')
self.dense3 = nn.Dense(2)
def forward(self, x):
layer1 = self.dense1(x)
layer2 = self.dense2(layer1)
layer3 = self.dense3(layer2)
return layer3
net_hybrid = MLP_Hybrid()
net_hybrid.initialize()
print('Before hybridizing: %.4f sec'%(benchmark(net_hybrid, x_bench)))
net_hybrid.hybridize()
print('After hybridizing: %.4f sec'%(benchmark(net_hybrid, x_bench)))
Before hybridizing: 0.5799 sec
After hybridizing: 0.2603 sec
Given a HybridBlock whose forward computation consists of going through other HybridBlocks, you can compile that section of the network by calling the HybridBlock's .hybridize() method.
All of MXNet’s predefined layers are HybridBlocks. This means that any network consisting entirely of predefined MXNet layers can be compiled and run at much faster speeds by calling .hybridize()
.
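If your input shape is fixed, you can optionally pass extra flags to hybridize for further optimization; a hedged sketch:
[ ]:
# static_alloc reuses memory buffers between forward passes;
# static_shape additionally assumes the input shape never changes.
net_hybrid.hybridize(static_alloc=True, static_shape=True)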
Saving and Loading your models¶
The Blocks API also lets you save your models during and after training, so that you can host the model for inference or avoid training it again from scratch. Another reason is to train your model using one language (like Python, which has many tools for training) and run inference using a different language.
There are two ways to save your model in MXNet:
1. Save/load the model weights/parameters only
2. Save/load the model weights/parameters and the architecture
1. Save/load the model weights/parameters only¶
You can use the save_parameters and load_parameters methods to save and load the model weights. Take your simplest model, layer, and save its parameters first. The model parameters are the params that you save after you train your model.
[27]:
file_name = 'layer.params'
layer.save_parameters(file_name)
Now load this model again. To load the parameters into a model, you first have to build the model. To do this, create a simple function that builds it.
[28]:
def build_model():
layer = nn.Dense(5, in_units=3,activation='relu')
return layer
layer_new = build_model()
[29]:
layer_new.load_parameters('layer.params')
Note: The save_parameters and load_parameters methods are used for models built with the Block class rather than the HybridBlock class. Such models may have complex architectures where the architecture can change during execution, e.g. a model that uses an if-else conditional statement to choose between two different architectures.
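As a concrete optional sketch for a Block-based model, you can save and reload the parameters of the MLP class defined earlier (the file name mlp.params is illustrative). The new instance must be built with the same architecture, and the original needs one forward pass so its deferred shapes are known before saving:
[ ]:
mlp = MLP()
mlp.initialize()
mlp(np.random.uniform(size=(4, 3)))     # one forward pass so parameter shapes are inferred
mlp.save_parameters('mlp.params')
# Rebuild the same architecture and load the weights back in.
mlp_loaded = MLP()
mlp_loaded.load_parameters('mlp.params')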
2. Save/load the model weights/parameters and the architectures¶
For models that use HybridBlock, the model architecture stays static and does not change during execution. Therefore both the model parameters AND the architecture can be saved and loaded using the export and imports methods.
Now look at your MLP_Hybrid
model and export the model using the export
function. The export function will export the model architecture into a .json
file and model parameters into a .params
file.
[30]:
net_hybrid.export('MLP_hybrid')
[30]:
('MLP_hybrid-symbol.json', 'MLP_hybrid-0000.params')
Similarly, to load this model back, you can use gluon.nn.SymbolBlock
. To demonstrate that, load the network serialized above.
[32]:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
net_loaded = nn.SymbolBlock.imports("MLP_hybrid-symbol.json",
['data'], "MLP_hybrid-0000.params",
device=None)
[33]:
net_loaded(x_bench)
[33]:
array([[ 0.13653663, -0.07247495]])
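As an optional sanity check, the re-imported network should produce the same output as the original hybridized network:
[ ]:
np.abs(net_hybrid(x_bench) - net_loaded(x_bench)).max()   # expected to be ~0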
Visualizing your models¶
In MXNet, the Block.summary() method allows you to view the block's shape arguments and parameters. When you combine multiple blocks into a model, calling summary() on the model lets you view each block's summary, the total parameters, and the order of the blocks within the model. To do this, the Block.summary() method requires one forward pass of data through your network in order to create the graph necessary for capturing the corresponding shapes and parameters. Additionally, this method should be called before the hybridize method, since hybridize converts the graph into a symbolic one, potentially changing the operations for optimal computation.
Look at the following examples:
layer: our single-layer network
lenet: the non-hybridized LeNet network
net_hybrid: our MLP_Hybrid network
[34]:
layer.summary(x)
--------------------------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================================
Input (10, 3) 0
Activation-1 (10, 5) 0
Dense-2 (10, 5) 20
================================================================================
Parameters in forward computation graph, duplicate included
Total params: 20
Trainable params: 20
Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 20
--------------------------------------------------------------------------------
[35]:
lenet.summary(image_data)
--------------------------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================================
Input (1, 1, 28, 28) 0
Activation-1 (1, 6, 26, 26) 0
Conv2D-2 (1, 6, 26, 26) 60
MaxPool2D-3 (1, 6, 13, 13) 0
Activation-4 (1, 16, 11, 11) 0
Conv2D-5 (1, 16, 11, 11) 880
MaxPool2D-6 (1, 16, 5, 5) 0
Activation-7 (1, 120) 0
Dense-8 (1, 120) 48120
Activation-9 (1, 84) 0
Dense-10 (1, 84) 10164
Dense-11 (1, 10) 850
LeNet-12 (1, 10) 0
================================================================================
Parameters in forward computation graph, duplicate included
Total params: 60074
Trainable params: 60074
Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 60074
--------------------------------------------------------------------------------
You are able to print the summaries of the two networks layer and lenet easily since you didn't hybridize them. However, the last network, net_hybrid, was hybridized above and throws an AssertionError if you try net_hybrid.summary(x_bench). To print its summary, create another instance of the same network, call summary() on it first, and only then hybridize it.
[36]:
net_hybrid_summary = MLP_Hybrid()
net_hybrid_summary.initialize()
net_hybrid_summary.summary(x_bench)
net_hybrid_summary.hybridize()
--------------------------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================================
Input (1, 512) 0
Activation-1 (1, 256) 0
Dense-2 (1, 256) 131328
Activation-3 (1, 128) 0
Dense-4 (1, 128) 32896
Dense-5 (1, 2) 258
MLP_Hybrid-6 (1, 2) 0
================================================================================
Parameters in forward computation graph, duplicate included
Total params: 164482
Trainable params: 164482
Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 164482
--------------------------------------------------------------------------------
Next steps:¶
Now that you have created a neural network, learn how to automatically compute the gradients in Step 3: Automatic differentiation with autograd.