Nearest Neighbor 法で MNIST データセットの分類

はじめに

shop.ohmsha.co.jp

わかりやすいパターン認識の第 1 章で k-NN 法が説明されていたので、MNIST データセットを分類してみる。

コード

コードは以下の通り。特徴ベクトルとしてはピクセルごとの濃淡値をそのまま使い、距離尺度としてはユークリッド距離を用いた。

import os
from matplotlib import pyplot as plt
import numpy as np
from sklearn import datasets, model_selection

K = int(os.environ.get('NN_K', 5))
CLUSTER_SIZE = int(os.environ.get('NN_CLUSTER_SIZE', 10))


def distance(x, y):
    return np.sqrt(np.sum((x - y)**2))


def predict(x, prototypes):
    # calculate distances from x for each prototype
    distances = []
    for cls, data in prototypes.items():
        distances.extend([(int(cls), distance(x, y)) for y in data])

    distances.sort(key=lambda d: d[1])

    # count classes of top K
    class_count = [0] * 10
    for cls, _ in distances[:K]:
        class_count[cls] += 1

    # return the class having maximum class count
    return max(enumerate(class_count), key=lambda i: i[1])[0]


if __name__ == '__main__':
    mnist = datasets.fetch_mldata('MNIST original', data_home='.')

    x_train, x_test, y_train, y_test = model_selection.train_test_split(
        mnist.data, mnist.target, test_size=0.5, shuffle=True
    )

    # choose prototypes for each class
    prototypes = {}
    for i in range(10):
        data = x_train[y_train == i]
        prototypes[str(i)] = data[:CLUSTER_SIZE]

    # predict
    predicts = np.array([predict(x, prototypes) for x in x_test])
    accuracy = len(x_test[predicts == y_test]) / len(x_test)

    print(f"accuracy = {accuracy}")

結果

CLUSTER_SIZE を変化させて分類精度と処理時間をグラフにしてみた。

K は CLUSTER_SIZE = 1 のときは 1、それ以外の時は 10 に設定した。

f:id:ntsujio:20180213001728p:plain

CLUSTER_SIZE が 1000 のときは分類精度が 0.95 と、単純方法の割に予想以上に高い結果となった。

また処理時間は CLUSTER_SIZE に線形になった。最近傍を求めるためにクラスター内の全パターンとの距離を計算しているためであろう。

まとめ

  • k-NN で MNIST を分類してみた
  • 単純な方法でも以外と精度出た
  • データ量大事

参考