Examine forward results with hooks
There are currently three ways to register a function in an MXNet Gluon Block for execution:
- before forward, via register_forward_pre_hook
- after forward, via register_forward_hook
- as an operator-level callback, via register_op_hook (HybridBlock only, after hybridization)
Pre-forward hook
To run a hook before forward execution, register it with register_forward_pre_hook. The hook must have the signature hook(block, input) -> None and must not modify its input. This is useful, for example, to print a summary of the block before it executes.
import mxnet as mx
from mxnet.gluon import nn
block = nn.Dense(10)
block.initialize()
print("{}".format(block))
# Dense(None -> 10, linear)
def pre_hook(block, input) -> None:  # two arguments: the block and its input (the positional arguments of the call)
    print("{}".format(block))
    return
# register
pre_handle = block.register_forward_pre_hook(pre_hook)
input = mx.nd.ones((3, 5))
print(block(input))
# Dense(None -> 10, linear)
# [[ 0.11254273  0.11162187  0.02200389 -0.04842059  0.09531345  0.00880495
#  -0.07610667  0.1562067   0.14192852  0.04463106]
# [ 0.11254273  0.11162187  0.02200389 -0.04842059  0.09531345  0.00880495
#  -0.07610667  0.1562067   0.14192852  0.04463106]
# [ 0.11254273  0.11162187  0.02200389 -0.04842059  0.09531345  0.00880495
#  -0.07610667  0.1562067   0.14192852  0.04463106]]
# <NDArray 3x10 @cpu(0)>
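Because a pre-forward hook runs before the forward pass, it is a natural place to log what is about to be computed. The sketch below is a pure-Python stand-in that runs without MXNet: only the hook signature hook(block, input) -> None mirrors the Gluon API, while the ToyBlock class, its toy computation and the record_shapes hook are hypothetical illustrations, not MXNet's implementation.

```python
# Hypothetical stand-in for a Gluon Block, NOT MXNet's implementation.
class ToyBlock:
    def __init__(self, name):
        self.name = name
        self._pre_hooks = []

    def register_forward_pre_hook(self, hook):
        # The hook will be called as hook(block, inputs) before forward.
        self._pre_hooks.append(hook)

    def forward(self, x):
        # Toy computation: double every element of a list.
        return [2 * v for v in x]

    def __call__(self, *args):
        for hook in self._pre_hooks:
            hook(self, args)  # hooks see the inputs but return nothing
        return self.forward(*args)


seen = []

def record_shapes(block, inputs):
    # Summarize the call (block name, length of each input) without touching the inputs.
    seen.append((block.name, [len(x) for x in inputs]))

toy = ToyBlock("toy0")
toy.register_forward_pre_hook(record_shapes)
out = toy([1, 2, 3])
print(seen)  # [('toy0', [3])]
print(out)   # [2, 4, 6]
```

The hook only reads its arguments, which is why it can be attached to or removed from a block without changing the block's results.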
To remove the hook, call detach() on the handle returned by register_forward_pre_hook:
pre_handle.detach()
print(block(input))
# [[ 0.11254273  0.11162187  0.02200389 -0.04842059  0.09531345  0.00880495
#  -0.07610667  0.1562067   0.14192852  0.04463106]
# [ 0.11254273  0.11162187  0.02200389 -0.04842059  0.09531345  0.00880495
#  -0.07610667  0.1562067   0.14192852  0.04463106]
# [ 0.11254273  0.11162187  0.02200389 -0.04842059  0.09531345  0.00880495
#  -0.07610667  0.1562067   0.14192852  0.04463106]]
# <NDArray 3x10 @cpu(0)>
Notice that Dense(None -> 10, linear) is no longer printed, because the pre-forward hook has been detached.
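The handle-based detach pattern can be sketched in plain Python as follows. This is an assumption-laden toy, not Gluon's code: ToyHandle and ToyBlock are hypothetical names, and storing hooks in an insertion-ordered dict keyed by an integer id is just one reasonable way to let each handle remove exactly its own hook.

```python
import itertools


class ToyHandle:
    """Hypothetical handle that removes one hook from its owner's registry."""

    def __init__(self, registry, key):
        self._registry = registry
        self._key = key

    def detach(self):
        # pop() with a default makes a second detach() a harmless no-op.
        self._registry.pop(self._key, None)


class ToyBlock:
    _ids = itertools.count()  # unique keys shared by all toy blocks

    def __init__(self):
        self._pre_hooks = {}  # dicts keep insertion order, so hooks run in registration order

    def register_forward_pre_hook(self, hook):
        key = next(self._ids)
        self._pre_hooks[key] = hook
        return ToyHandle(self._pre_hooks, key)

    def __call__(self, x):
        for hook in list(self._pre_hooks.values()):
            hook(self, (x,))
        return x + 1


calls = []
toy_block = ToyBlock()
toy_handle = toy_block.register_forward_pre_hook(lambda blk, inp: calls.append(inp))
toy_block(1)          # hook fires
toy_handle.detach()
toy_block(2)          # hook no longer fires
print(calls)          # [(1,)]
```

Returning a handle rather than asking callers to pass the hook back keeps removal unambiguous even when the same function is registered more than once.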
Post-forward hook
Registering a hook that runs after forward execution works like the pre-forward hook described above, except that the hook signature is hook(block, input, output) -> None. The hook must not modify the input or the output. Continuing with the Dense block from the pre-forward example:
def post_hook(block, input, output) -> None:
    print("{}".format(block))
    return
post_handle = block.register_forward_hook(post_hook)
print(block(input))
# Dense(5 -> 10, linear)
# [[ 0.11254273  0.11162187  0.02200389 -0.04842059  0.09531345  0.00880495
#  -0.07610667  0.1562067   0.14192852  0.04463106]
# [ 0.11254273  0.11162187  0.02200389 -0.04842059  0.09531345  0.00880495
#  -0.07610667  0.1562067   0.14192852  0.04463106]
# [ 0.11254273  0.11162187  0.02200389 -0.04842059  0.09531345  0.00880495
#  -0.07610667  0.1562067   0.14192852  0.04463106]]
# <NDArray 3x10 @cpu(0)>
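Notice that the two hooks print different descriptions of the same block. The pre-forward hook ran before the first forward pass, when the input size was still unknown, so the block printed as Dense(None -> 10, linear). The post-forward hook runs after forward has executed, by which time deferred shape inference has set the input size to 5, so it prints Dense(5 -> 10, linear).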
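The effect of deferred shape inference on what each hook sees can be reproduced with a small pure-Python model. ToyDense below is a hypothetical class, not mxnet.gluon.nn.Dense: it assumes the input size is unknown (None) until the first call and inferred from the first input row, and its "computation" is a placeholder.

```python
class ToyDense:
    """Hypothetical dense-like block with deferred input-size inference."""

    def __init__(self, units):
        self.units = units
        self.in_units = None  # unknown until the first forward pass
        self._pre_hooks = []
        self._post_hooks = []

    def __repr__(self):
        return "ToyDense({} -> {})".format(self.in_units, self.units)

    def register_forward_pre_hook(self, hook):
        self._pre_hooks.append(hook)

    def register_forward_hook(self, hook):
        self._post_hooks.append(hook)

    def __call__(self, x):
        # x is a list of rows, each row a list of features.
        for hook in self._pre_hooks:
            hook(self, (x,))
        if self.in_units is None:
            self.in_units = len(x[0])  # "shape inference" happens here
        out = [[sum(row)] * self.units for row in x]  # placeholder computation
        for hook in self._post_hooks:
            hook(self, (x,), out)
        return out


log = []
layer = ToyDense(10)
layer.register_forward_pre_hook(lambda blk, inp: log.append("pre:  " + repr(blk)))
layer.register_forward_hook(lambda blk, inp, out: log.append("post: " + repr(blk)))
layer([[1.0] * 5] * 3)
print("\n".join(log))
# pre:  ToyDense(None -> 10)
# post: ToyDense(5 -> 10)
```

Both hooks fire on the same first call here, yet they report different input sizes, which is the same ordering effect seen with the real Dense block.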
Callback hook
For a hybridized HybridBlock, register_op_hook(callback, monitor_all=False) registers a callback that is invoked for the operators executed during the forward pass. The callback signature is:
callback(node_name: str, opr_name: str, arr: NDArray) -> None
where node_name is the name of the tensor being inspected, opr_name is the name of the operator producing or consuming that tensor, and arr is the tensor itself (an NDArray). As the sample output below shows, the names can arrive as bytes objects (b'...') rather than str.
import mxnet as mx
from mxnet.gluon import nn
def mon_callback(node_name, opr_name, arr):
    print("{}".format(node_name))
    print("{}".format(opr_name))
    return
model = nn.HybridSequential(prefix="dense_")
with model.name_scope():
    model.add(nn.Dense(2))
model.initialize()
model.hybridize()
model.register_op_hook(mon_callback, monitor_all=True)
print(model(mx.nd.ones((2, 3, 4))))
# b'dense_dense0_fwd_data'
# b'FullyConnected'
# b'dense_dense0_fwd_weight'
# b'FullyConnected'
# b'dense_dense0_fwd_bias'
# b'FullyConnected'
# b'dense_dense0_fwd_output'
# b'FullyConnected'
# [[-0.05979988 -0.16349721]
#  [-0.05979988 -0.16349721]]
# <NDArray 2x2 @cpu(0)>
Setting monitor_all=False prints only the operator's output tensor:
# b'dense_dense0_fwd_output'
# b'FullyConnected'
# [[-0.05979988 -0.16349721]
#  [-0.05979988 -0.16349721]]
# <NDArray 2x2 @cpu(0)>
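The difference between the two settings can be modeled as a filter over the tensors the callback is offered: with monitor_all=True the callback sees each operator's inputs and outputs, and with monitor_all=False only its outputs. The pure-Python sketch below assumes that behavior; the NODES list (copied from the names in the sample output) and the run_monitor dispatcher are hypothetical, not part of MXNet.

```python
# Hypothetical node list for the one-layer model above: (tensor name, operator, is_output).
NODES = [
    ("dense_dense0_fwd_data", "FullyConnected", False),
    ("dense_dense0_fwd_weight", "FullyConnected", False),
    ("dense_dense0_fwd_bias", "FullyConnected", False),
    ("dense_dense0_fwd_output", "FullyConnected", True),
]


def run_monitor(nodes, callback, monitor_all=False):
    """Toy dispatcher: outputs always reach the callback, inputs only if monitor_all."""
    for name, op, is_output in nodes:
        if monitor_all or is_output:
            callback(name, op, None)  # None stands in for the NDArray


def names_seen(monitor_all):
    seen = []
    run_monitor(NODES, lambda name, op, arr: seen.append(name), monitor_all)
    return seen


print(names_seen(True))   # all four tensor names
print(names_seen(False))  # ['dense_dense0_fwd_output']
```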
Note that to get the internal operator node names, one can use model.collect_params().items().
