Why doesn't my simple PyTorch network move to CUDA?
I built a simple network from a tutorial and I got this error:
RuntimeError: Expected object of type torch.cuda.FloatTensor but found
type torch.FloatTensor for argument #4 'mat1'
Any help? Thank you!
import torch
import torchvision

device = torch.device("cuda:0")

root = '.data/'
dataset = torchvision.datasets.MNIST(root, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.out = torch.nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.out(x)
        return x

net = Net()
net.to(device)

for i, (inputs, labels) in enumerate(dataloader):
    inputs.to(device)
    out = net(inputs)
1 Answer
TL;DR
This is the fix:

    inputs = inputs.to(device)

Why?!
There is a slight difference between torch.nn.Module.to() and torch.Tensor.to(): while Module.to() is an in-place operator, Tensor.to() is not. Therefore

    net.to(device)

changes net itself and moves it to device. On the other hand,

    inputs.to(device)

does not change inputs, but rather returns a copy of inputs that resides on device. To use that "on device" copy, you need to assign it to a variable, hence

    inputs = inputs.to(device)
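You can observe this in-place vs. out-of-place difference even without a GPU. The sketch below uses a dtype change instead of a device change so it runs on CPU only; the variable names are illustrative, but the .to() semantics are exactly the same as with device="cuda:0":

```python
import torch

t = torch.zeros(2)               # a float32 tensor
d = t.to(torch.float64)          # Tensor.to() returns a NEW tensor
print(t.dtype)                   # torch.float32 -- the original is unchanged
print(d.dtype)                   # torch.float64 -- only the returned copy changed

lin = torch.nn.Linear(2, 2)      # a float32 module
same = lin.to(torch.float64)     # Module.to() converts the module in place
print(same is lin)               # True -- it returns the module itself
print(lin.weight.dtype)          # torch.float64 -- parameters were converted
```

This is why `net.to(device)` on its own line works for a module, while a tensor always needs the `inputs = inputs.to(device)` reassignment.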
It works! Thank you so much! I spent several hours searching for an answer.
– Harkonnen
5 hours ago

@Harkonnen glad I could help. When reading the PyTorch docs you should pay attention to "in-place" operations: some methods are in-place, some are not, and some have "in-place" variants...
– Shai
5 hours ago