めるぱんブログ

とある医大生の勉強の記録です

PythonとKerasによるディープラーニング 第5章②

先ほどはMNISTを用いましたが、次はKaggleで提供されているDogs vs. Catsというデータセットを用いて学習をしていきます。

 

この最初の画像データをディレクトリにコピーする部分がどうしてもできずに毎回 invalid argumentのエラーが出てしまいました。いろいろと調べて治そうと努力したのですが結局改善できず、手動でディレクトリにコピペしてしまいました。

まあここは本質的な部分ではないので良しとします。(どなたかわかる方がいれば教えてください)

 

CNNモデルは前回MNSITに使用したものを少し変えたものを使用し、コンパイルではbinary_crossentropyを損失関数として使用し、ほかはいつもと同じです。

 

データの前処理ではImageDataGeneratorというものを使って画像の読み込みからRGBピクセルにデコード、浮動小数点型のテンソルに変換、1/255にし[0, 1]に収めるということができます。

このジェネレータはデータを無限に生成するため、モデルにfitさせるときに今までと違いfit_generatorというものを使用します。これを用いことで、次のエポックに進むタイミングを決めることができます。

例えばトータルのサンプル数が合計2000の場合、最初のジェネレータのバッチを20に設定したときにはこのfit_generatorのsteps_per_epochを100で設定することで20×100で2000となる、といった感じです。

また、検証用のデータも無限に生成されるため、指定しておく必要があります。

 

このモデルで学習した結果、図のようになりました。

f:id:melpan:20200415150819p:plain

訓練データと検証データでの正解率

f:id:melpan:20200415150851p:plain

訓練データと検証データでの損失値

結果は正解率は70%くらいで頭打ちとなり、過学習になってしまいました。

今回は元のデータが2000しかなかったので、データの拡張を行います。これは、同じくImageDataGeneratorを用いることで、元の画像を回転させたり平行移動させたりすることで少しずつ違う画像を生成するという作業です。

ただし、検証データは水増しすることは禁忌なので注意が必要です。

これによりバッチを少し大きくし、モデルにフィットさせます。

また、モデルの方にもドロップアウトを入れることで過学習を抑制します。

その結果、図のようになりました。

f:id:melpan:20200415151722p:plain

訓練データと検証データでの正解率改

f:id:melpan:20200415151747p:plain

訓練データと検証データでの損失値改

80%以上の正解率になりました。損失値はあらぶっていますがなかなかの結果です。

また、今回のモデルはmodel.save()を用いて保存しておきます。

 

実際にCNNを試してみましたが、データの水増しの部分は非常に重要で今後絶対必要になると思いました。

あと実際に学習しているときなのですが、パソコンがめちゃくちゃ熱くなって少し危険な気がしています…時間も1時間くらいかかってしまうのでいいデスクトップのPCがほしくなりますね。