PyTorch Matrix Multiplication: How To Do A PyTorch Dot Product
PyTorch Matrix Multiplication - Use torch.mm to do a PyTorch Dot Product
This video will show you how to use PyTorch’s torch.mm operation to do a matrix multiplication, in which each entry of the result is the dot product of a row of the first matrix with a column of the second.
First, we import PyTorch.
Then we check what version of PyTorch we are using.
We are using PyTorch version 0.4.1.
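A minimal sketch of these two steps follows. The version string it prints depends on your installation; the walkthrough used 0.4.1, and the later examples should also run on newer releases.

```python
# Import PyTorch and report which version is installed
import torch

print(torch.__version__)
```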
Let’s create the first matrix we’ll use for the dot product multiplication.
tensor_example_one = torch.Tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
We use torch.Tensor, and it’s going to be a 3x3 matrix.
So the first row is full of 1s, the second row is full of 2s, the third row is full of 3s, and we assign this matrix to the Python variable tensor_example_one.
Let’s print this variable to see what we have.
We see that it’s a PyTorch tensor, we see that all our numbers are there, and we see that each one has a decimal point after it.
This is because torch.Tensor creates a tensor of the default floating point type (32-bit floats, i.e. torch.FloatTensor, unless the default has been changed).
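The steps so far can be sketched as follows. Besides printing the tensor, it also prints the tensor's dtype and shape, which is a quick way to confirm that torch.Tensor produced a 3x3 matrix of floats (float32 under PyTorch's default settings).

```python
import torch

# torch.Tensor builds a tensor of the default floating point type
tensor_example_one = torch.Tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])

print(tensor_example_one)        # every entry is shown with a decimal point
print(tensor_example_one.dtype)  # torch.float32 under default settings
print(tensor_example_one.shape)  # torch.Size([3, 3])
```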
Next, we create our second matrix that we’ll use for the dot product multiplication.
tensor_example_two = torch.Tensor([[4, 5, 6], [4, 5, 6], [4, 5, 6]])
We use torch.Tensor again.
It is a 3x3 matrix.
This time, the first column is full of 4s, the second column is full of 5s, and the third column is full of 6s.
This will make the multiplication easier to check by eye.
We print this tensor to see what’s inside:
And we see that it’s a PyTorch tensor.
We see the 4s, 5s, and 6s, and again, because torch.Tensor creates floating point tensors, every one of those numbers is followed by a decimal point.
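A short sketch of creating and printing the second matrix. Slicing out a column with `[:, j]` shows that each column holds a single repeated value.

```python
import torch

# Each column holds one repeated value: 4s, then 5s, then 6s
tensor_example_two = torch.Tensor([[4, 5, 6], [4, 5, 6], [4, 5, 6]])

print(tensor_example_two)
print(tensor_example_two[:, 0])  # the first column, all 4s
```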
We can now do the PyTorch matrix multiplication with torch.mm, which takes the dot product of each row of our first matrix with each column of our second matrix.
tensor_dot_product = torch.mm(tensor_example_one, tensor_example_two)
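Remember that matrix multiplication does not require the two matrices to have the same size and shape. It requires the number of columns of the first matrix to equal the number of rows of the second: an n x m matrix times an m x p matrix gives an n x p result.
Because we’re multiplying a 3x3 matrix by a 3x3 matrix, the inner dimensions match, so it will work and we don’t have to worry about that.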
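To illustrate the shape requirement for torch.mm, here is a small sketch using hypothetical all-ones matrices `a`, `b`, and `c` that are not part of the walkthrough. The number of columns of the first matrix must equal the number of rows of the second; when it does not, torch.mm raises a RuntimeError.

```python
import torch

# Hypothetical matrices chosen only to illustrate the shape rule
a = torch.ones(2, 3)  # 2 x 3
b = torch.ones(3, 4)  # 3 x 4: inner dimensions (3 and 3) match

product = torch.mm(a, b)
print(product.shape)  # torch.Size([2, 4])

c = torch.ones(2, 4)  # 2 x 4: inner dimensions of a and c (3 and 2) do not match
try:
    torch.mm(a, c)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("torch.mm failed:", err)
```

Note that a 2x3 matrix and a 3x4 matrix have different shapes, yet they multiply fine; only the inner dimensions have to agree.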
Now, let’s visually check the PyTorch matrix multiplication result.
We see 12, 15, 18; 24, 30, 36; 36, 45, 54.
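A self-contained sketch that reproduces the multiplication and compares it against the values read off above. Because every entry is a small integer, the float32 arithmetic is exact and torch.equal can be used for the comparison.

```python
import torch

tensor_example_one = torch.Tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
tensor_example_two = torch.Tensor([[4, 5, 6], [4, 5, 6], [4, 5, 6]])

tensor_dot_product = torch.mm(tensor_example_one, tensor_example_two)
print(tensor_dot_product)

# The values read off in the walkthrough, row by row
expected = torch.Tensor([[12, 15, 18], [24, 30, 36], [36, 45, 54]])
print(torch.equal(tensor_dot_product, expected))  # True
```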
All right, does the 12 in the first row and first column make sense? The first row times the first column is 1x4 + 1x4 + 1x4.
That is just 4+4+4, which is 12.
Next, the first row times the second column is 1x5 + 1x5 + 1x5.
So 5+5+5 is 15.
Similarly, the first row times the third column is 1x6 + 1x6 + 1x6.
So 6+6+6 is 18.
So this result makes sense.
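The hand check above can be sketched in code: each entry in the first row of the result is the dot product (torch.dot, which works on 1-D tensors) of the first row of `tensor_example_one` with one column of `tensor_example_two`.

```python
import torch

tensor_example_one = torch.Tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
tensor_example_two = torch.Tensor([[4, 5, 6], [4, 5, 6], [4, 5, 6]])

first_row = tensor_example_one[0]  # the row of 1s
# Entry (row 0, column j) is the dot product of the first row with column j
first_row_entries = [torch.dot(first_row, tensor_example_two[:, j]).item() for j in range(3)]
print(first_row_entries)  # [12.0, 15.0, 18.0]
```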
Because the first row was all 1s and the second row is all 2s, multiplying the second row by each of the three columns is just 2x instead of 1x, so we would expect these numbers to be double the first row of the result.
So 12x2 is 24, 15x2 is 30, 18x2 is 36.
Perfect! That makes sense.
Finally, the third row contains all 3s, so multiplying it by each of the three columns should give three times the first row of the result.
So 12x3 is 36, 15x3 is 45, 18x3 is 54.
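The row-scaling argument can also be checked directly. This sketch recomputes the product and confirms that the second and third rows of the result are exactly 2 and 3 times the first row.

```python
import torch

tensor_example_one = torch.Tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
tensor_example_two = torch.Tensor([[4, 5, 6], [4, 5, 6], [4, 5, 6]])
tensor_dot_product = torch.mm(tensor_example_one, tensor_example_two)

# Rows of the left matrix are 1s, 2s and 3s, so result rows scale by 1, 2 and 3
first_row = tensor_dot_product[0]
print(torch.equal(tensor_dot_product[1], 2 * first_row))  # True
print(torch.equal(tensor_dot_product[2], 3 * first_row))  # True
```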
Perfect! We were able to use PyTorch’s torch.mm operation to do a dot product matrix multiplication.