PyTorch's forward hooks
8 July 2019 by Keyur Paralkar
PyTorch’s hooks
If you want to inspect or modify the outputs or grad_outputs of intermediate or final layers, hooks are a good way to do it. A hook is a function that is registered on a module or a Tensor.
There are two types of hooks:
- Forward hook: Called when the forward pass of the module is executed.
- Backward hook: Called when the backward pass is executed and gradients are computed.
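As a rough illustration of the two kinds of hooks, the sketch below registers a forward hook on a module and a gradient hook on a Tensor. The layer, the function names and the `events` list are made up for this example. Module-level backward hooks also exist, but their API has changed between PyTorch releases, so the Tensor variant is used here.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)   # illustrative layer, not part of the article's network
events = []

# Forward hook: called with (module, input, output) after forward() has run
def forward_hook(module, inputs, output):
    events.append(("forward", tuple(output.shape)))

# Tensor hook: called with the gradient of that tensor during backward()
def grad_hook(grad):
    events.append(("backward", tuple(grad.shape)))

fwd_handle = layer.register_forward_hook(forward_hook)

out = layer(torch.rand(3, 4))                # triggers forward_hook
grad_handle = out.register_hook(grad_hook)
out.sum().backward()                         # triggers grad_hook

# Both registration calls return handles that can be removed again
fwd_handle.remove()
grad_handle.remove()
print(events)
```

The forward hook fires as soon as the layer is called, while the gradient hook fires only once `backward()` reaches that tensor.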
In this article, I am going to implement a forward hook that prints the output sizes of intermediate layers.
Printing layer sizes of a NN using hooks
Creating a base network
Importing the required modules
import torch
import torch.nn as nn
import numpy as np
The code below is our base network, which consists of three conv layers and two linear layers.
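class baseNetwork(nn.Module):
    def __init__(self):
        super(baseNetwork, self).__init__()
        # Three 3x3 convolutions without padding: 300 -> 298 -> 296 -> 294
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1)
        )
        self.linear = nn.Sequential(
            # 86436 = 294 * 294, the flattened size of one channel for a 300x300 input
            nn.Linear(in_features=86436, out_features=1024),
            nn.Linear(in_features=1024, out_features=2)
        )

    def forward(self, x):
        x = self.conv(x)
        # Note: x.size(1) is the channel dimension, so each of the 128 channels
        # becomes one row of 86436 values; this only works for a batch size of 1
        x = x.view(x.size(1), -1)
        x = self.linear(x)
        return x

model = baseNetwork()
# Fall back to the CPU when CUDA is not available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)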
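The value in_features=86436 follows from the convolution arithmetic. As a small sketch (the helper name `conv_out_size` is made up for this example), the standard output-size formula for a convolution without dilation gives the spatial size after each of the three layers for a 300x300 input:

```python
def conv_out_size(size, kernel_size, stride=1, padding=0):
    # Standard output-size formula for a convolution without dilation
    return (size + 2 * padding - kernel_size) // stride + 1

size = 300
sizes = []
for _ in range(3):  # three 3x3 convolutions with stride 1 and no padding
    size = conv_out_size(size, kernel_size=3, stride=1)
    sizes.append(size)

print(sizes)        # [298, 296, 294]
print(size * size)  # 86436
```

Because forward() reshapes with x.view(x.size(1), -1), each of the 128 output channels becomes one row of 294 * 294 = 86436 values, which is what the first linear layer expects.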
Implementing hook calls for the model
We are going to apply forward hooks to the conv and linear layers. Once registered, these hooks are called every time the forward function is executed. Here I have created a class Hook that implements this functionality.
class Hook():
    def __init__(self, module, name_seq_layer="all"):
        self.handle = []                # handles returned by register_forward_hook
        self.module = module            # network whose layers get hooked
        self.seq_name = name_seq_layer  # "all" or the name of one sequential block

    def applyHooks(self):
        if self.seq_name == "all":
            # Every top-level sequential block of the network
            blocks = list(self.module.children())
        else:
            # Only the sequential block whose name matches seq_name
            blocks = [m for name, m in self.module.named_modules()
                      if name == self.seq_name]
        for block in blocks:
            for i, x in enumerate(block):
                # Class name of the layer, e.g. "Conv2d", "ReLU", "MaxPool2d", "Linear"
                opType = type(x).__name__
                self.handle.append(x.register_forward_hook(self.layerSizeDisp(i, opType)))
        print("Applied Hooks successfully!!")

    def layerSizeDisp(self, count, operation='x'):
        # Returns the actual hook; count and operation are captured by the closure
        def hook(module, ip, op):
            # op[0] is the first sample of the batch, so the batch dimension is not printed
            print(operation+'('+str(count)+')', '->', list(op[0].size()))
        return hook

    def removeHooks(self):
        for handle in self.handle:
            handle.remove()
        self.handle = []
        print("Removed all hooks")

    def printSizes(self):
        # Random input created on the same device as the model
        device = next(self.module.parameters()).device
        temp_exp = torch.rand(1, 3, 300, 300).to(device)
        self.module(temp_exp)
The Hook class initializes the following things:
- handle: A list that stores the handles of the registered hooks for all the layers.
- module: The module whose layers need to be registered.
- seq_name: A switch that selects whether to register hooks on ‘all’ layers or only on the layers of one named sequential block of the network.
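To see which names the seq_name switch can take, it helps to list what a module reports about its sub-modules. The sketch below uses a small stand-in network (`TinyNet` is made up for this example, with the same two block names as the base network) and the standard `named_children()` and `named_modules()` methods of `nn.Module`:

```python
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical stand-in for the base network
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(3, 4, 3), nn.Conv2d(4, 8, 3))
        self.linear = nn.Sequential(nn.Linear(8, 2))

net = TinyNet()

# Direct children only: the two sequential blocks
top_level = [name for name, _ in net.named_children()]
# Every module in the tree, starting with the network itself (empty name)
all_names = [name for name, _ in net.named_modules()]

print(top_level)  # ['conv', 'linear']
print(all_names)  # ['', 'conv', 'conv.0', 'conv.1', 'linear', 'linear.0']
```

So for a network laid out like this, 'conv' and 'linear' are the block names that the switch is compared against.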
The applyHooks function checks whether our switch variable is ‘all’ or the name of a specific sequential block. If it is ‘all’, it registers a hook on every layer using register_forward_hook, which is PyTorch’s built-in method for registering a forward hook; otherwise it registers hooks only on the layers of the named block. According to PyTorch’s documentation:
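register_forward_hook(hook)
- The hook should have the following signature: hook(module, input, output) -> None
- It should not modify the input or output. (This reflects the documentation at the time of writing; newer PyTorch releases also allow a forward hook to return a modified output.)
- register_forward_hook returns a handle, and calling handle.remove() removes the hook.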
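A minimal sketch of that lifecycle, using a single illustrative Conv2d layer and a hypothetical `record_shape` hook: the hook runs on the first forward call and, once its handle is removed, no longer runs on the second.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)  # illustrative layer
seen = []

def record_shape(module, input, output):
    # Follows the hook(module, input, output) signature and returns nothing
    seen.append(list(output.size()))

handle = conv.register_forward_hook(record_shape)
conv(torch.rand(1, 3, 10, 10))  # hook runs and records the output size
handle.remove()
conv(torch.rand(1, 3, 10, 10))  # hook has been removed, nothing is recorded

print(seen)  # [[1, 8, 8, 8]]
```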
Let’s look at an example of how the above Hook class helps us print layer sizes.
x = Hook(model)
applyHooks() registers a forward hook on ‘all’ the layers present in the module. For each layer it calls layerSizeDisp() with a counter and the type of operation, i.e. Conv2d, Linear, etc. layerSizeDisp() returns the actual hook function, which PyTorch later calls with the module and the input and output of the layer.
x.applyHooks()
Applied Hooks successfully!!
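The reason layerSizeDisp() is called at registration time is that it is a factory: it returns a new hook function that remembers the counter and operation it was created with. A plain-Python sketch of the same pattern, with the made-up names `make_hook` and `messages` and no PyTorch involved:

```python
messages = []

def make_hook(count, operation="x"):
    # count and operation are fixed at the moment the hook is created
    def hook(module, ip, op):
        messages.append(operation + "(" + str(count) + ")")
    return hook

# One hook per layer index, each remembering its own counter
hooks = [make_hook(i, "Conv2d") for i in range(3)]

# Calling the hooks later (as PyTorch would during forward) uses the stored values
for h in hooks:
    h(None, None, None)

print(messages)  # ['Conv2d(0)', 'Conv2d(1)', 'Conv2d(2)']
```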
Registered forward hooks run only during a forward pass of the neural network, so we create a random input of shape (1,3,300,300) and pass it through the network. The printSizes() function takes care of this. The hook prints the size of op[0], so the batch dimension is not shown.
x.printSizes()
Conv2d(0) -> [32, 298, 298]
Conv2d(1) -> [64, 296, 296]
Conv2d(2) -> [128, 294, 294]
Linear(0) -> [1024]
Linear(1) -> [2]
x.removeHooks()
Removed all hooks
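The example above uses the ‘all’ setting. As a sketch of the layer-specific path, the snippet below hooks only the block named "conv" and stores the sizes in a dictionary instead of printing them. `SmallNet`, `save_shape` and `shapes` are made up for this example; the network is deliberately tiny so that it runs quickly on a CPU.

```python
import torch
import torch.nn as nn

class SmallNet(nn.Module):  # hypothetical network with the same block names
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(3, 4, 3), nn.Conv2d(4, 8, 3))
        self.linear = nn.Sequential(nn.Linear(8 * 6 * 6, 2))

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.linear(x)

net = SmallNet()
shapes = {}
handles = []

def save_shape(key):
    def hook(module, ip, op):
        # As in layerSizeDisp, op[0] drops the batch dimension
        shapes[key] = list(op[0].size())
    return hook

# Register hooks only on the layers inside the block named "conv"
for name, block in net.named_modules():
    if name == "conv":
        for i, layer in enumerate(block):
            key = type(layer).__name__ + "(" + str(i) + ")"
            handles.append(layer.register_forward_hook(save_shape(key)))

net(torch.rand(1, 3, 10, 10))
for h in handles:
    h.remove()

print(shapes)  # {'Conv2d(0)': [4, 8, 8], 'Conv2d(1)': [8, 6, 6]}
```

With the Hook class from this article, the same selection would correspond to passing the block name as the second argument, for example Hook(model, "conv").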