研究会

機械学習、データベース、分散システム、その他技術的なことを書く研究会です

パーセプトロンで Iris データセットの 2 クラス分類

はじめに

shop.ohmsha.co.jp

わかりやすいパターン認識の第 2 章でパーセプトロンが説明されていたので実装してみた。

学習データは scikit-learn が提供している Iris データセットを使った。

Iris (アヤメ) データセットは Setosa, Versicolour, Virginica の 3 品種 (クラス) のパターンがまとめられていて、特徴量としてがくの長さ、幅、同じく花弁の長さ、幅の 4 次元が選択されている。

試しに花弁の長さと幅を軸にとってデータを散布図に起こしてみた。

f:id:ntsujio:20180214011431p:plain

そしてそれぞれの品種の画像が以下。

Iris setosa - Wikipedia

f:id:ntsujio:20180214011224j:plain

Iris versicolor - Wikipedia

f:id:ntsujio:20180214010903j:plain

Iris virginica - Wikipedia

f:id:ntsujio:20180214010952j:plain

確かに Setosa, Versicolour, Virginica の順に花弁が大きい・・・気がする。

コード

パーセプトロンは線形分離可能なクラスしか分類できないらしいので、花弁の大きさで線形分離できそうな Setosa と Versicolour を分類してみた。

import os
from matplotlib import pyplot as plt
from matplotlib.font_manager import FontProperties
import numpy as np
from sklearn import datasets, model_selection


class Perceptron:
    def __init__(self, x_dim, rho=1e-3):
        self.w = np.random.randn(x_dim + 1)
        self.rho = rho

    def train(self, data, label):
        while True:
            # shuffle
            perm = np.random.permutation(len(data))
            data, label = data[perm], label[perm]

            classified = True

            for x, y in zip(list(data), list(label)):
                pred = self.predict(x)
                if pred != y:
                    classified = False

                    # update weight
                    x = np.array(list(x) + [1])
                    self.w = self.w - pred * self.rho * x

            if classified:
                break

    def predict(self, x):
        x = np.array(list(x) + [1])
        return 1 if np.dot(self.w, x) > 0 else -1


if __name__ == '__main__':
    dataset = datasets.load_iris()

    x_train, y_train = dataset.data, dataset.target

    perceptron = Perceptron(x_dim=2)

    # preprocess train data
    mask = np.bitwise_or(y_train == 0, y_train == 1)
    x_train = x_train[mask][:, 2:]
    y_train = y_train[mask]
    y_train = np.array([-1 if y == 0 else 1 for y in y_train])

    # train
    perceptron.train(x_train, y_train)

    # display
    fp = FontProperties(fname=r'C:\Windows\Fonts\meiryo.ttc', size=12)

    x_c0 = x_train[y_train == -1]
    x_c1 = x_train[y_train == 1]
    plt.scatter(x_c0[:, 0], x_c0[:, 1], label='Setosa')
    plt.scatter(x_c1[:, 0], x_c1[:, 1], label='Versicolour')

    w = perceptron.w
    y = lambda x: w[0] / -w[1] * x + w[2] / -w[1]
    x = np.arange(1, 5, 0.1)
    plt.plot(x, y(x))

    plt.xlabel('花弁の長さ', fontproperties=fp)
    plt.ylabel('花弁の幅', fontproperties=fp)
    plt.title('パーセプトロンによるアヤメ科の花の分類', fontproperties=fp)

    plt.legend()
    plt.show()

実行結果

こんな感じな図が描けた。ちゃんと分類できそうな境界が引けている。

f:id:ntsujio:20180214012353p:plain

ちなみに、Versicolour と Virginica だと線形分離できないので学習時に無限ループしてしまった。

まとめ

  • パーセプトロンで Iris データセットを分類してみた
  • Setosa と Versicolour は花弁の大きさで決定境界を引けた
  • Versicolour と Virginica は花弁の大きさでは決定境界を引けなかった

参考