Converting New Neural Network Architectures to TensorRT — AdderNet to TensorRT

Chinthaka Gamanayake
Published in Analytics Vidhya
7 min read · Mar 9, 2021


When deploying pre-trained neural network models for inference on production hardware and Edge AI devices, we have to reduce the model's memory footprint (size) so that it fits in the device's memory. We also have to reduce the number of computations in the model to achieve higher throughput while maintaining the required accuracy. To that end, we can use the following model compression methodologies.

  • Pruning redundant weights, filters, and layers of the model.
  • Quantizing weights of the model. (float64 to float32, float16, int8)
  • Developing an efficient model architecture.
  • Knowledge distillation.

If you are not familiar with these methodologies, it is worth reading a little of the literature to get a basic idea of them, since that will help you deploy high-performance models in your AI production line.
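To give a quick feel for one of these ideas, here is a minimal sketch of symmetric int8 weight quantization in plain Python. The weight values and the helper names (quantize_int8, dequantize) are made up for illustration; real toolchains use more careful calibration than this single per-tensor scale.

```python
def quantize_int8(weights):
    """Symmetric per-tensor quantization: floats -> int8 codes plus one scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    # Round to the nearest code and clamp to the symmetric int8 range.
    codes = [max(-127, min(127, round(w / scale))) for w in weights]
    return codes, scale


def dequantize(codes, scale):
    """Map int8 codes back to approximate float values."""
    return [c * scale for c in codes]


weights = [0.50, -1.27, 0.03, 0.90]  # made-up example weights
codes, scale = quantize_int8(weights)
restored = dequantize(codes, scale)

print(codes)     # small integers that fit in one byte each
print(restored)  # close to the original weights, but not identical
```

Each weight now needs one byte instead of four or eight, at the cost of a rounding error of at most half a quantization step.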

NVIDIA TensorRT

NVIDIA TensorRT is an SDK for high-performance deep learning inference. It includes a deep learning inference optimizer and runtime that delivers low latency and high throughput for deep learning inference applications. In other words, TensorRT takes care of much of the model optimization work for us. As you may already know, deploying a convolutional neural network (CNN) efficiently on NVIDIA Jetson development kits usually means converting the CNN model to the TensorRT format. It can then run efficiently on Jetson development kits such as the TX1, TX2, Nano, AGX, and NX.

AdderNet

AdderNet is a type of CNN that replaces conventional Convolution Filters with novel Adder Filters. An Add operation is normally computationally cheaper than a Multiply operation. Adder Filters exploit this by replacing the multiplications in Convolution Filters with additions. Convolution Layers consisting of Adder Filters are called Adder Layers.

Convolution Operation: calculates the cross-correlation between the input feature map and the Convolution Filter.

Adder Operation: first calculates the L1 distance (the sum of absolute differences) between the input feature map and the Adder Filter, then negates the result.

Adder Operation -> -(|2-1| + |4-2| + |2-4| + |1-5|) = -(1 + 2 + 2 + 4) = -9

Convolution Operation -> (2*1) + (4*2) + (2*4) + (1*5) = 2 + 8 + 8 + 5 = 23

Adder Op vs Convolution Op
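The same comparison can be reproduced with a few lines of Python. This is only a sketch for a single flattened 2x2 patch; which of the two number lists is the feature map and which is the filter is my assumption here, and the results are the same either way.

```python
def convolution_op(patch, kernel):
    """Cross-correlation of one patch with one filter: sum of products."""
    return sum(x * w for x, w in zip(patch, kernel))


def adder_op(patch, kernel):
    """Adder operation: negative L1 distance between patch and filter."""
    return -sum(abs(x - w) for x, w in zip(patch, kernel))


patch = [2, 4, 2, 1]   # flattened 2x2 input patch
kernel = [1, 2, 4, 5]  # flattened 2x2 filter

print(adder_op(patch, kernel))        # -9
print(convolution_op(patch, kernel))  # 23
```

Note that the adder output can never be positive, which matters for the choice of layers that follow it.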

AdderNet uses a full-precision gradient to update the Adder Filters during backpropagation. It also increases the learning rate of the Adder Filters using an adaptive learning rate methodology. In addition, Adder Layers should be followed by Batch-Normalization layers, which scale the negative outputs to a range that works better with the ReLU activations used in conventional CNN architectures.
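A tiny sketch shows why the normalization step matters. The adder outputs below are made-up values, and the batch_norm helper omits the learned scale and shift parameters for brevity.

```python
import math


def relu(values):
    return [max(0.0, v) for v in values]


def batch_norm(values, eps=1e-5):
    """Normalize to zero mean and unit variance (learned gamma/beta omitted)."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


adder_outputs = [-9.0, -3.0, -12.0, -6.0]  # adder outputs are never positive

print(relu(adder_outputs))              # [0.0, 0.0, 0.0, 0.0]
print(relu(batch_norm(adder_outputs)))  # some activations survive
```

Applied directly, ReLU would zero out every adder output. After normalization, roughly half of the values are positive again.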

AdderNetMnist

For my experiment, I’m using a handcrafted AdderNet neural network architecture implemented in PyTorch for MNIST classification. I’ll call this model AdderNetMnist from here onwards. AdderNetMnist is shown in the following code snippet. The architecture consists of two adder layers, each followed by a batch normalization layer and a max pooling layer, and then two dense layers for SoftMax classification. You can find this implementation in the source file addernet_mnist.py in my GitHub repo.

import torch.nn as nn
import torch.nn.functional as F

import adder  # adder2d layer from the original AdderNet implementation


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Block 1: adder layer -> batch norm -> max pooling
        self.adder1 = adder.adder2d(1, 20, kernel_size=5, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(20)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Block 2: adder layer -> batch norm -> max pooling
        self.adder2 = adder.adder2d(20, 50, kernel_size=5, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(50)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier: two dense layers
        self.fc1 = nn.Linear(800, 500)
        self.relu1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = self.adder1(x)
        x = self.bn1(x)
        x = self.pool1(x)
        x = self.adder2(x)
        x = self.bn2(x)
        x = self.pool2(x)
        x = x.view(-1, 800)  # flatten 50 channels of 4x4 feature maps
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return F.softmax(x, dim=1)
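If you wonder where the 800 input features of fc1 come from, the following sketch traces the feature map size through the network. It assumes a 28x28 MNIST input and that adder2d follows the same output-size rule as an ordinary convolution; conv_output_size is a helper I define here, not part of any library.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution-style or pooling window."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 28                                                # MNIST image is 28x28
size = conv_output_size(size, kernel_size=5)             # adder1 -> 24
size = conv_output_size(size, kernel_size=2, stride=2)   # pool1  -> 12
size = conv_output_size(size, kernel_size=5)             # adder2 -> 8
size = conv_output_size(size, kernel_size=2, stride=2)   # pool2  -> 4

flattened = 50 * size * size  # 50 output channels from adder2
print(size, flattened)        # 4 800
```

Batch normalization does not change the spatial size, so it is left out of the trace.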

The implementation of the Adder Layer (adder2d) comes from the code released with the original AdderNet research paper. That implementation runs on the CPU only; it is not implemented in CUDA for GPU acceleration.

Converting AdderNet to TensorRT

To convert a model architecture to TensorRT, first, we have to train our existing model using the required dataset (MNIST in my case) and save the network with its weights for future usage.

Then we are going to convert the saved model into TensorRT. There are two methods of working with saved models from different DL frameworks in TensorRT.

  1. TensorFlow: tf2onnx conversion, TF-TRT conversion, or TF-UFF conversion.
  2. PyTorch or other frameworks: convert to ONNX, then convert to TRT.

For my experiment, I used the second method. First, I tried converting my PyTorch Model directly to ONNX using the following script addernet_mnist_onnx.py.

import torch

# MnistModel and OPSET are defined elsewhere in the script (not shown here).
model = MnistModel()
model.network.load_state_dict(torch.load('./saved_models/addernet_mnist.pth'))
model.network.to('cuda')

# Dummy input with the MNIST shape: (batch, channels, height, width)
dummy_input = torch.randn(1, 1, 28, 28, device='cuda')

# Convert the PyTorch model to ONNX format
torch.onnx.export(model.network, dummy_input,
                  "./saved_models/addernet_mnist.onnx",
                  verbose=True, opset_version=OPSET)

Unfortunately, the torch.onnx.export method gives an error because ONNX does not support the conversion of the novel Adder Layers in my neural network architecture. If you have a new layer type or a new operation type in your neural network architecture, you will most likely end up at this stage with ONNX errors too. To proceed from here, we need to implement the custom layers or operations that TensorRT does not support from scratch, using the TensorRT APIs, so that the model becomes convertible.

In this case, we can extend TensorRT's functionality by implementing custom layers using the IPluginV2 interface provided by the TensorRT C++ and Python APIs. Custom layers are often referred to as plugins in TensorRT. If you are familiar with object-oriented programming, IPluginV2 is an interface, or parent class. What we have to do is create a child class that inherits from it. Our custom layer class can then implement the virtual methods of the IPluginV2 base class according to the functionality of our custom NN layer.

For my experiment, I have implemented the Adder Layer as Adder2dPlugin.h in my repo. The Adder2dPlugin child class extends the IPluginV2 base class. To create an object from Adder2dPlugin, TensorRT gives us another interface called IPluginCreator. Adder2dPluginCreator is the child class that extends IPluginCreator. The implementation of both classes can be found in the source file Adder2dPlugin.cu. In summary, these two classes are responsible for defining the feature map dimensions and data types, managing GPU memory, and performing the forward inference calculation for our Adder Layer. Please have a look at the Adder2dPlugin::enqueue() method, which does the forward inference calculation using a custom parallel CUDA algorithm that I designed.
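If the plugin and creator pattern is new to you, the following plain-Python analogy may help. It is not the TensorRT API: PluginBase, PluginCreatorBase, ToyAdderPlugin, ToyAdderPluginCreator, and the registry dictionary are all hypothetical stand-ins that only mirror the roles of IPluginV2 and IPluginCreator, and the toy layer works on a 1D list instead of GPU tensors.

```python
from abc import ABC, abstractmethod


class PluginBase(ABC):
    """Hypothetical stand-in for the role IPluginV2 plays."""

    @abstractmethod
    def get_output_dimensions(self, input_length): ...

    @abstractmethod
    def enqueue(self, inputs): ...


class PluginCreatorBase(ABC):
    """Hypothetical stand-in for the role IPluginCreator plays."""

    @abstractmethod
    def get_plugin_name(self): ...

    @abstractmethod
    def create_plugin(self, **fields): ...


class ToyAdderPlugin(PluginBase):
    """A 1D adder layer with a single filter, stride 1, no padding."""

    def __init__(self, kernel):
        self.kernel = list(kernel)

    def get_output_dimensions(self, input_length):
        return input_length - len(self.kernel) + 1

    def enqueue(self, inputs):
        k = len(self.kernel)
        return [
            -sum(abs(x - w) for x, w in zip(inputs[i:i + k], self.kernel))
            for i in range(self.get_output_dimensions(len(inputs)))
        ]


class ToyAdderPluginCreator(PluginCreatorBase):
    def get_plugin_name(self):
        return "ToyAdder"

    def create_plugin(self, **fields):
        return ToyAdderPlugin(fields["kernel"])


# The framework looks up a creator by name and asks it to build the plugin.
registry = {}
creator = ToyAdderPluginCreator()
registry[creator.get_plugin_name()] = creator

plugin = registry["ToyAdder"].create_plugin(kernel=[1, 2])
result = plugin.enqueue([2, 4, 2, 1])
print(result)  # [-3, -3, -2]
```

The design point is the separation of duties: the plugin knows how to compute its outputs and their shape, while the creator knows how to build a configured plugin when the framework asks for one by name.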

Implementing a parallel CUDA algorithm for a custom layer requires knowledge of the parallel programming patterns of the CUDA programming interface. I have mapped the Adder Operation to a CUDA kernel, __global__ void AdderFilter(), which can be found in the source file Adder2dPlugin.cu. The CUDA mapping of my Adder operation is illustrated below.

Mapping Adder Filter to the Concurrent CUDA Kernel
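To make the idea of such a mapping concrete, here is a sequential Python emulation of one simple scheme: one thread per output element. This is an assumption for illustration and not necessarily how AdderFilter() in my repo divides the work. The function names, the single input channel, stride 1, and zero padding are all simplifications.

```python
def adder_filter_thread(thread_id, image, filters, out_h, out_w, output):
    """Work done by one hypothetical thread: compute one output element."""
    k = len(filters[0])
    # Decode the flat thread index into (filter, output row, output column).
    f, rem = divmod(thread_id, out_h * out_w)
    row, col = divmod(rem, out_w)
    total = 0
    for i in range(k):
        for j in range(k):
            total += abs(image[row + i][col + j] - filters[f][i][j])
    output[f][row][col] = -total


def adder2d_forward(image, filters):
    """Launch one emulated thread per output element (run sequentially here)."""
    k = len(filters[0])
    out_h = len(image) - k + 1
    out_w = len(image[0]) - k + 1
    output = [[[0] * out_w for _ in range(out_h)] for _ in filters]
    for thread_id in range(len(filters) * out_h * out_w):
        adder_filter_thread(thread_id, image, filters, out_h, out_w, output)
    return output


image = [[2, 4, 0],
         [2, 1, 3],
         [5, 0, 1]]
filters = [[[1, 2],
            [4, 5]]]

out = adder2d_forward(image, filters)
print(out)  # [[[-9, -10], [-8, -9]]]
```

Because every output element depends only on the input and the filter weights, the emulated threads never write to the same location, which is what makes this kind of mapping safe to run concurrently on a GPU.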

After successfully implementing and testing Adder2dPlugin in C++, we can build a C++ shared library (.so file) containing this functionality. Then we have to make Adder2dPlugin loadable as a Python package using Pybind11 (Adder2dPyTrt.cpp).

The next step is to import this Adder2dPlugin using Python. Please refer to the function named def get_adder2d_plugin() inside the source addernet_mnist_trt.py for more information.

At this point, we have finalized the implementation of our custom NN layer, the Adder Layer. What we have to do now is use this layer to convert the whole AdderNetMnist network architecture to TensorRT. To that end, we replicate the whole PyTorch AdderNetMnist architecture with the TensorRT Python API. The function populate_network() inside addernet_mnist_trt.py does this job for us.
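One detail worth spelling out when replicating a PyTorch network layer by layer is batch normalization. As far as I know, the TensorRT network API has no dedicated batch-normalization layer, so a common approach is to fold the trained statistics into a per-channel scale and shift and feed them to a scale layer (add_scale). Whether populate_network() does exactly this is not something the sketch below claims; fold_batch_norm and the numbers are illustrative only.

```python
import math


def fold_batch_norm(gamma, beta, running_mean, running_var, eps=1e-5):
    """Turn inference-time batch norm into y = x * scale + shift per channel."""
    scale = [g / math.sqrt(v + eps) for g, v in zip(gamma, running_var)]
    shift = [b - m * s for b, m, s in zip(beta, running_mean, scale)]
    return scale, shift


def batch_norm_inference(x, gamma, beta, mean, var, eps=1e-5):
    """Reference formula used by batch norm in evaluation mode."""
    return gamma * (x - mean) / math.sqrt(var + eps) + beta


# Made-up parameters for two channels.
gamma, beta = [1.5, 0.5], [0.1, -0.2]
mean, var = [-7.5, -3.0], [11.25, 4.0]

scale, shift = fold_batch_norm(gamma, beta, mean, var)

x = -9.0  # one activation value fed to both channels
folded = [x * scale[c] + shift[c] for c in range(2)]
reference = [batch_norm_inference(x, gamma[c], beta[c], mean[c], var[c])
             for c in range(2)]
print(folded)
print(reference)  # matches the folded result up to rounding
```

The default eps of 1e-5 matches the default of PyTorch's BatchNorm2d; if the model was built with a different value, the folding has to use that value too.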

The hard part is over now. We can build the TensorRT inference engine using the pre-trained weights of AdderNetMnist from the PyTorch model and run the inference. The whole development pipeline for this work is illustrated in the following diagram.

Results and concluding remarks

  • The results show that, with an accuracy loss of 0.83%, the TensorRT model achieves a far lower inference latency (about 5000x lower) for the AdderNet MNIST model.
  • The high accuracy retained by the TensorRT model indicates that the Adder Layer implementation is correct.
  • This implementation can be extended to deploy other NN architectures with Adder Layers, such as the ResNet20 and ResNet50 variants mentioned in the original research paper.
  • As far as I know, this is currently the only successful CUDA and TensorRT implementation of AdderNet available.
  • This work took 3 weeks for me to complete. I hope this article will help everyone in the AI industry and Academia who try things with NVIDIA Jetson Development Kits.
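If you want to make a similar latency comparison for your own models, a simple timing harness along the following lines is a reasonable starting point. It is only a sketch: mean_latency_ms is a helper defined here, the two model functions are placeholders rather than real networks, and for GPU inference you would also need to synchronize the device before reading the clock, since GPU work is launched asynchronously.

```python
import time


def mean_latency_ms(infer, sample, warmup=3, runs=20):
    """Average wall-clock time of infer(sample) in milliseconds."""
    for _ in range(warmup):  # warm-up runs are not timed
        infer(sample)
    start = time.perf_counter()
    for _ in range(runs):
        infer(sample)
    return (time.perf_counter() - start) * 1000.0 / runs


def baseline_model(x):
    """Placeholder for the slow reference model."""
    return sum(v * v for v in x for _ in range(200))


def optimized_model(x):
    """Placeholder for the optimized model."""
    return sum(v * v for v in x)


sample = [0.1] * 784  # one flattened 28x28 input
baseline_ms = mean_latency_ms(baseline_model, sample)
optimized_ms = mean_latency_ms(optimized_model, sample)

print(f"baseline:  {baseline_ms:.3f} ms")
print(f"optimized: {optimized_ms:.3f} ms")
print(f"speedup:   {baseline_ms / optimized_ms:.1f}x")
```

Averaging over several runs after a warm-up phase gives more stable numbers than timing a single call.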

Next Step

I would like to integrate this CUDA implementation of the Adder Filter into existing DL libraries such as TensorFlow and PyTorch. So far I have only implemented the forward inference in CUDA for deployment, so I still need to implement the backpropagation in CUDA for training purposes.

Give me a clap ❤️ if you liked this post :) I hope you will pull the code and try it yourself: https://github.com/chinthysl/AdderNetTensorRT

Follow my GitHub page and star the projects important to you.

I do deep learning consultancy, research, and development, and love to work on interesting problems. If you have a project that we can collaborate on, then please contact me through my email: chin.pub.acc@gmail.com

Connect with me on LinkedIn for my future work: https://www.linkedin.com/in/chinthaka-gamanayake/

I should mention and thank my colleague, Nuran, who motivated me to complete this project within 2 weeks working a few hours a day!
