pytorchで再学習する場合に生じるエラー

機械学習をしていると、再学習したくなる状況がしばしばあります。torch.save() を使ってモデルを保存し、読み込んで再学習をした際に次のようなバグに遭遇しました。

Traceback (most recent call last):
  File "/Users/(省略)", line 68, in <module>
    optimizer.step()
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/optim/optimizer.py", line 280, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/optim/optimizer.py", line 33, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/optim/adam.py", line 141, in step
    adam(
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/optim/adam.py", line 281, in adam
    func(params,
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/optim/adam.py", line 344, in _single_tensor_adam
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

「全てのテンソルが同一のデバイスにないといけなのに、cuda:0とcpu上にあるよ」と怒られています。シンプルに 「.to(device) のし忘れか?」と思いましたが、torch.save() を使って再学習する場合にのみ生じるタイプの特殊なバグだったので、メモを残します。

バグが出るコード

例えば次のようなmnistのコードを実行し、model.ckpt の形で学習済みmodelやoptimizerの情報を保存します。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

# epoch = ini_epoch ~ fin_epoch まで学習
ini_epoch = 0
fin_epoch = 9

#----------------------------------------------------------
num_batch = 128
learning_rate = 1e-3
image_size = 28*28
device = "cuda:0"  # gpu を選択
#----------------------------------------------------------
transform = transforms.Compose([
    transforms.ToTensor()
    ])
train_dataset = datasets.MNIST(
    './data', 
    train=True, 
    download=True, 
    transform=transform 
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=num_batch,
    shuffle=True
)
#----------------------------------------------------------
# model の定義
class Net(nn.Module):
    def __init__(self, input_size, output_size):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, 100)
        self.fc2 = nn.Linear(100, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.sigmoid(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Net(image_size, 10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 

model.to(device)
model.train()

# 学習
for epoch in range(ini_epoch, fin_epoch + 1):
    loss_sum = 0
    for inputs, labels in train_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        inputs = inputs.view(-1, image_size)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss_sum += loss.item()
        loss.backward()
        optimizer.step()
    print(epoch, loss_sum / len(train_dataloader))

# modelなどの保存
torch.save(
    {  
        'epoch': fin_epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    },
    "model.ckpt"
)

保存したcheckpointファイルを読み込み、再学習をします。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

# epoch = ini_epoch ~ fin_epoch まで学習
ini_epoch = 10
fin_epoch = 19

(先ほどと同じなので省略)

model = Net(image_size, 10)#.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 

# モデルやoptimizerの読み込み
checkpoint = torch.load("model.ckpt")
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.load_state_dict(checkpoint["model_state_dict"])

model.to(device)
model.train()

# 再学習
(先ほどと同じなので省略)

このコードを実行すると、最初のバグが生じます。

原因

model をデバイスに送るタイミングの問題です。どうやら、optimizerを読み込む前にmodelをデバイスに送らなければならないようです。そこで、先ほどのコードの一部を次のように書き換えてください。

model.to(device)
checkpoint = torch.load("model.ckpt")
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.load_state_dict(checkpoint["model_state_dict"])

model.train()

これで問題なく動きます。

コメント

タイトルとURLをコピーしました