How To Use The view Method To Manage Tensor Shape In PyTorch
Use the PyTorch view method to manage Tensor Shape within a Convolutional Neural Network
A common issue that arises when designing any type of neural network is the output tensor of one layer having the wrong shape to act as the input tensor to the next layer.
Sometimes, this issue will raise an error, but in more insidious cases, the error will only be noticeable when one evaluates the performance of their trained model on a test set.
For the most part, careful management of layer arguments will prevent these issues.
However, there are cases where it is necessary to explicitly reshape tensors as they move through the network.
One such case is when the output tensor of a convolutional layer feeds into a fully connected output layer, as it does in the network shown here.
x = self.layer2(x)
x = x.view(-1, 32 * 16 * 16)
x = self.fully_connected(x)
Let’s quickly trace the shape of the input tensor as it moves through the network, since being able to do this at every step is an important skill.
Initially, if we are using the CIFAR-10 data set, the input images for this network are 32x32 images with three color channels.
So the input tensor has the shape batch size by 3, the number of channels, by 32, the height of the images, by 32, the width of the images.
The first dimension indexes the individual images of the batch, which are stacked together into this one tensor.
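A quick way to confirm this layout is to build a stand-in batch and print its shape. This is a minimal sketch: the batch size of 4 is an arbitrary assumption, and random values stand in for real CIFAR-10 images.

```python
import torch

# Hypothetical stand-in for one CIFAR-10 batch: 4 images, 3 color channels, 32x32 pixels.
batch_size = 4
images = torch.randn(batch_size, 3, 32, 32)

# PyTorch orders image dimensions as (batch, channels, height, width).
print(images.shape)  # torch.Size([4, 3, 32, 32])
```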
The first layer takes the input tensor and applies the Conv2d operation to it with a kernel size of 3 and a padding of 1 producing 16 feature maps from the original 3 channels.
self.layer1.add_module("Conv1", nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=3, padding=1))
This is what is meant when the convolutional layer is said to have 16 output channels.
Hence, the resulting tensor has the shape batch size (which should not have changed) by 16, the new number of channels, by 32 by 32, because a convolutional layer with kernel size 3, padding 1, and the default stride of 1 preserves height and width.
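The same check can be run on the first layer. This sketch assumes input_channels is 3, as it would be for CIFAR-10, and again uses a random batch of 4 in place of real images.

```python
import torch
import torch.nn as nn

images = torch.randn(4, 3, 32, 32)  # assumed batch of 4 CIFAR-sized images

# Same settings as Conv1: 3 input channels -> 16 feature maps, 3x3 kernel, padding 1.
conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
out1 = conv1(images)

# With stride 1, padding 1 exactly offsets the 3x3 kernel, so height and width stay 32.
print(out1.shape)  # torch.Size([4, 16, 32, 32])
```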
The second layer takes output from the first layer and modifies it further.
self.layer2.add_module("Conv2", nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, stride=2))
This time the kernel size is 3 and the padding is 1, but the stride is 2, so the output tensor has the shape batch size by 32, the new number of output channels, by 16 by 16, since a convolution with kernel size 3, padding 1, and stride 2 halves the height and width of this 32x32 input.
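Both results follow from the standard Conv2d output-size formula (with dilation 1): output = floor((size + 2 * padding - kernel_size) / stride) + 1. The helper below, conv_output_size, is a hypothetical name used only for illustration; the sketch then checks the second layer's result against PyTorch itself on an assumed batch of 4.

```python
import torch
import torch.nn as nn

def conv_output_size(size, kernel_size, padding=0, stride=1):
    # Hypothetical helper: Conv2d output size for dilation 1,
    # floor((size + 2 * padding - kernel_size) / stride) + 1
    return (size + 2 * padding - kernel_size) // stride + 1

print(conv_output_size(32, kernel_size=3, padding=1, stride=1))  # 32 (Conv1)
print(conv_output_size(32, kernel_size=3, padding=1, stride=2))  # 16 (Conv2)

# Check against PyTorch, feeding Conv2 a tensor shaped like Conv1's output.
conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, stride=2)
out2 = conv2(torch.randn(4, 16, 32, 32))
print(out2.shape)  # torch.Size([4, 32, 16, 16])
```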
At this point, it would be common, and good practice, to apply global average pooling to each of the 32 channels, reducing each channel to a single value stored as a 1x1 tensor.
This would likely produce a much better-performing network, but the purpose of this model is to illustrate the use of view.
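For comparison only, here is what that pooling step could look like. This is not part of the lesson's network: it swaps in nn.AdaptiveAvgPool2d(1) as one common way to do global average pooling, and a random tensor stands in for the output of layer2.

```python
import torch
import torch.nn as nn

features = torch.randn(4, 32, 16, 16)  # assumed layer2 output for a batch of 4

# Global average pooling: each 16x16 channel is averaged down to one 1x1 value.
pool = nn.AdaptiveAvgPool2d(1)
pooled = pool(features)
print(pooled.shape)  # torch.Size([4, 32, 1, 1])

# A fully connected layer after this would only need 32 inputs instead of 8,192.
flat = pooled.view(-1, 32)
print(flat.shape)  # torch.Size([4, 32])
```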
Our goal is to change the current shape of our tensors so that they can be fed into our fully connected layer.
This is where view comes in.
x = x.view(-1, 32 * 16 * 16)
Because we want to flatten our inputs for the fully connected layer, we can merge the channel, height, and width dimensions into a single dimension.
In this case, 32 * 16 * 16 = 8,192.
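The arithmetic can be checked with numel, which counts the elements in a tensor. A minimal sketch, assuming a random batch of 4 stands in for the output of layer2:

```python
import torch

features = torch.randn(4, 32, 16, 16)  # assumed layer2 output for a batch of 4

# Elements in one sample = channels * height * width.
per_sample = features[0].numel()
print(per_sample)                   # 8192
print(per_sample == 32 * 16 * 16)   # True

# The whole batch holds batch_size times as many elements.
print(features.numel())             # 32768
```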
One caveat: our tensor also has a first dimension, the batch size.
We want to keep the batch size unchanged while we flatten the other dimensions.
Getting this wrong will either raise an error or silently tank model performance.
So we pass -1 as the first argument to view, which tells view to infer the size of that dimension from the total number of elements and the other dimension we specified, 8,192.
x = x.view(-1, 32 * 16 * 16)
view infers that the first dimension must equal the batch size (the total number of elements divided by 8,192), and we are left with a tensor of shape batch size by 8,192.
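The sketch below shows the inference working for several batch sizes, then illustrates the silent failure mentioned above. The batch sizes and the mistaken 8x8 spatial size are assumptions chosen for illustration, and random tensors stand in for real layer2 outputs.

```python
import torch

# -1 lets view infer the first dimension: total elements / 8,192.
for batch_size in (1, 4, 64):
    features = torch.randn(batch_size, 32, 16, 16)  # assumed layer2 output
    flat = features.view(-1, 32 * 16 * 16)
    print(batch_size, tuple(flat.shape))
# 1 (1, 8192)
# 4 (4, 8192)
# 64 (64, 8192)

# Silent failure: suppose the spatial size were wrongly assumed to be 8x8.
# 4 * 8,192 = 32,768 elements still divide evenly by 32 * 8 * 8 = 2,048, so no
# error is raised, but the "batch" dimension becomes 16 and each row now holds
# only a quarter of one image.
features = torch.randn(4, 32, 16, 16)
bad = features.view(-1, 32 * 8 * 8)
print(tuple(bad.shape))  # (16, 2048)

# These spellings pin dimension 0 to the batch, so a wrong feature size cannot be
# absorbed into it; a mismatch would instead surface as an error in nn.Linear.
a = features.view(features.size(0), -1)
b = torch.flatten(features, start_dim=1)
print(tuple(a.shape), tuple(b.shape))  # (4, 8192) (4, 8192)
```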
Lastly, we have our fully connected layer, with an input dimension of 8,192 (32 * 16 * 16, our reshaped dimension) and an output dimension of num_classes, which is 10 for the 10 possible classes in CIFAR-10.
self.fully_connected = nn.Linear(32 * 16 * 16, num_classes)
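Putting the pieces together, the following is a hypothetical reconstruction of the network, not the lesson's exact code: the class name ConvNet and its constructor defaults are assumptions, and the real model may contain further modules (such as activations) that the text does not describe. It includes only the layers and the view call discussed above.

```python
import torch
import torch.nn as nn

class ConvNet(nn.Module):
    # Hypothetical reconstruction; layer settings follow the text above.
    def __init__(self, input_channels=3, num_classes=10):
        super().__init__()
        self.layer1 = nn.Sequential()
        self.layer1.add_module("Conv1", nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=3, padding=1))
        self.layer2 = nn.Sequential()
        self.layer2.add_module("Conv2", nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, stride=2))
        self.fully_connected = nn.Linear(32 * 16 * 16, num_classes)

    def forward(self, x):
        x = self.layer1(x)            # (N, 16, 32, 32)
        x = self.layer2(x)            # (N, 32, 16, 16)
        x = x.view(-1, 32 * 16 * 16)  # (N, 8192)
        x = self.fully_connected(x)   # (N, num_classes)
        return x

model = ConvNet()
logits = model(torch.randn(4, 3, 32, 32))  # assumed random batch of 4
print(logits.shape)  # torch.Size([4, 10])
```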
It is worth noting that specifying all of the arguments for this network required a lot of prior knowledge about the data set we are using.
It is possible to set up a network without this kind of hard coding, and where that is feasible it is a good idea, since it prevents a lot of difficult-to-debug errors.
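One way to avoid hard coding 32 * 16 * 16 is to measure the flattened size with a dummy forward pass when the model is built. This is a sketch of that technique, not the lesson's code: the class name ConvNetAutoSize and the image_size parameter are hypothetical, and the image size must still be known. Recent PyTorch versions also provide nn.LazyLinear, which infers its input size on the first forward pass.

```python
import torch
import torch.nn as nn

class ConvNetAutoSize(nn.Module):
    # Hypothetical variant: the flattened size is measured, not hard coded.
    def __init__(self, input_channels=3, image_size=32, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(input_channels, 16, kernel_size=3, padding=1),
            nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=2),
        )
        # Push one dummy image through the conv layers to count its features.
        with torch.no_grad():
            dummy = torch.zeros(1, input_channels, image_size, image_size)
            flat_size = self.features(dummy).numel()
        self.fully_connected = nn.Linear(flat_size, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # keep the batch dimension, flatten the rest
        return self.fully_connected(x)

model = ConvNetAutoSize(image_size=64)
print(model.fully_connected.in_features)       # 32768 (32 * 32 * 32)
print(model(torch.randn(2, 3, 64, 64)).shape)  # torch.Size([2, 10])
```

With the default image_size of 32, the measured size comes out as 8,192, matching the hard-coded value above.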