MNIST データセットの読み込み方

機械学習を学習する上で学習データを準備するのに苦労する場面がよくある。

今回は MNIST データセットと呼ばれる、機械学習ベンチマークでよく使われるデータセットの使い方をまとめる。

MNIST データセット

MNIST データセット は手書き数字文字データをまとめたデータセットであり、訓練用に 60,000 枚、テスト用に 10,000 枚の画像データ、そしてそれぞれの正解データ (画像がどの数字を表しているか) が用意されている。

データの形式は独自のバイナリで、ページの下のほうにフォーマットが解説されている。プログラムで扱う際はこのバイナリをパースして使う。

コード

ページの解説に従ってバイナリをパースするコードが以下。Python でバイナリをパースするのは struct モジュールを使った。

gist7b55be06fba08e00a51edf9e4f2eb207

読み込んだ画像の表示

画像は matplotlib モジュールで表示できる。表示コードと結果は以下の通り。

gistf9ec5c6229fa65402b8bbe012d23c887

f:id:ntsujio:20180211205121p:plain

画像は「5」のデータ。ちゃんと読み込めてるっぽい。

scikit-learn で読み込む

実は MNIST データセットは自分でコードを書かなくてもバイナリを読み込める。

機械学習ライブラリである scikit-learn を使うとこんな感じに書ける。

from sklearn import datasets

mnist = datasets.fetch_mldata('MNIST original', data_home='.')

img = mnist.data[0].reshape(28, 28)

plt.imshow(img, cmap='gray')
plt.show()

fetch_mldata メソッドのドキュメントを見ると、データは mldata.org から取得すると書いてある。

mldata.org は MNIST データセットだけでなく、様々なデータセットを公開している。また有用なデータセットをアップロードして貢献することもできる。

MNIST (Original) のページを見ると CC0 ライセンスと記載されており、学習にとても有用である。感謝。

まとめ

  • 機械学習の学習データには MNIST データセットが便利
  • MNIST データセットの読み込みには scikit-learn が便利
  • 便利なものを用意してくれている先人に感謝

参考