Pytorchで学習が進まない原因

pytorchを用いたdeep learningにおいて、lossが変化はするが減少傾向にならず、ランダムに変化するというバグに遭遇した。その原因についてメモ代わりに述べておく。

症状

それぞれの画像に対して、0から1の間の実数がラベル付けされているデータセットを考える。与えられた画像に対して、実数のラベルを予想するモデルを作るために、教師あり学習をおこなった。torchのバージョンは2.0.1。

epoch毎のlossの値を見たところ、一向に学習が進んでいなかった。しかしながら、lossの値は毎回変化するので、parameterの更新自体は行われている。このバグはloss functionを変えても同様に生じた。

原因

modelからのoutputと、labelのtorch.tensorのサイズが異なった。具体的には正解ラベルのデータをlabel、modelの出力をoutputとすると、

  • label.size() : (batch_size, )
  • output.size() : (batch_size, 1)

という違い。配列トータルのデータ数に違いはないけれども、配列の次元が異なっていた。

対処方法

labelをあらかじめ2次元の配列に変換したり、torch.flatten(output) するなどして、labelとoutputのサイズを一致させる。

コメント

タイトルとURLをコピーしました