Normalize CIFAR10 Dataset Tensor
Use Torchvision Transforms Normalize (transforms.Normalize) to normalize CIFAR10 dataset tensors using the mean and standard deviation of the dataset
Now that we know how to convert CIFAR10 PIL images to PyTorch tensors, we may also want to normalize the resulting tensors.
Normalizing a dataset so that each channel is centered near zero with a comparable scale generally makes training better behaved, and it is widely reported to improve generalization in deep learning models.
We will first want to import PyTorch and Torchvision, along with torchvision.datasets as datasets and torchvision.transforms as transforms.
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
We will also want to check our installed versions; this lesson was written against PyTorch 0.4.0 and Torchvision 0.2.1.
We then define our normalize transform using transforms.Normalize:
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
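The CIFAR10 tensors have three channels (red, green, and blue), so the mean parameter takes one value per channel, in this case 0.5 for all three.
Similarly, the std parameter takes a list of standard deviations, one per channel, which we also set to 0.5 here.
Note that these are not target values the output is rescaled to reach: Normalize subtracts mean from each channel and then divides by std, computing output = (input - mean) / std. Because ToTensor produces values in [0, 1], using 0.5 for both maps every channel into the range [-1, 1].
This tends to be a good starting point; for a closer fit, the per-channel mean and standard deviation computed from the training set itself can be passed instead.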
If we import the CIFAR10 set as usual, transforming the PIL images to tensors on import:
cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
and pick out a tensor:
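datapoint = cifar_trainset[0][0]  # each sample is an (image, label) pair; index 0 is an arbitrary choice, [0] again selects the image tensor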
then we can print that tensor to see what it looks like.
We can then apply our newly defined normalize transform by calling normalize with that tensor as its argument.
Printing the result shows that the normalization transform did in fact alter the tensor.
We could normalize the entire dataset by looping over it and calling normalize on each tensor individually.
However, this is not the cleanest way to include a normalization step when importing datasets from torchvision.
We should instead include normalize in the transform argument when importing the CIFAR10 set, and for that we will need to combine the ToTensor and Normalize transforms using transforms.Compose.