パーセプトロンで Iris データセットの 2 クラス分類
はじめに
わかりやすいパターン認識の第 2 章でパーセプトロンが説明されていたので実装してみた。
学習データは scikit-learn が提供している Iris データセットを使った。
Iris (アヤメ) データセットは Setosa, Versicolour, Virginica の 3 品種 (クラス) のパターンがまとめられていて、特徴量としてがくの長さ、幅、同じく花弁の長さ、幅の 4 次元が選択されている。
試しに花弁の長さと幅を軸にとってデータを散布図に起こしてみた。
そしてそれぞれの品種の画像が以下。
確かに Setosa, Versicolour, Virginica の順に花弁が大きい・・・気がする。
コード
パーセプトロンは線形分離可能なクラスしか分類できないらしいので、花弁の大きさで線形分離できそうな Setosa と Versicolour を分類してみた。
import os from matplotlib import pyplot as plt from matplotlib.font_manager import FontProperties import numpy as np from sklearn import datasets, model_selection class Perceptron: def __init__(self, x_dim, rho=1e-3): self.w = np.random.randn(x_dim + 1) self.rho = rho def train(self, data, label): while True: # shuffle perm = np.random.permutation(len(data)) data, label = data[perm], label[perm] classified = True for x, y in zip(list(data), list(label)): pred = self.predict(x) if pred != y: classified = False # update weight x = np.array(list(x) + [1]) self.w = self.w - pred * self.rho * x if classified: break def predict(self, x): x = np.array(list(x) + [1]) return 1 if np.dot(self.w, x) > 0 else -1 if __name__ == '__main__': dataset = datasets.load_iris() x_train, y_train = dataset.data, dataset.target perceptron = Perceptron(x_dim=2) # preprocess train data mask = np.bitwise_or(y_train == 0, y_train == 1) x_train = x_train[mask][:, 2:] y_train = y_train[mask] y_train = np.array([-1 if y == 0 else 1 for y in y_train]) # train perceptron.train(x_train, y_train) # display fp = FontProperties(fname=r'C:\Windows\Fonts\meiryo.ttc', size=12) x_c0 = x_train[y_train == -1] x_c1 = x_train[y_train == 1] plt.scatter(x_c0[:, 0], x_c0[:, 1], label='Setosa') plt.scatter(x_c1[:, 0], x_c1[:, 1], label='Versicolour') w = perceptron.w y = lambda x: w[0] / -w[1] * x + w[2] / -w[1] x = np.arange(1, 5, 0.1) plt.plot(x, y(x)) plt.xlabel('花弁の長さ', fontproperties=fp) plt.ylabel('花弁の幅', fontproperties=fp) plt.title('パーセプトロンによるアヤメ科の花の分類', fontproperties=fp) plt.legend() plt.show()
実行結果
こんな感じな図が描けた。ちゃんと分類できそうな境界が引けている。
ちなみに、Versicolour と Virginica だと線形分離できないので学習時に無限ループしてしまった。
まとめ
- パーセプトロンで Iris データセットを分類してみた
- Setosa と Versicolour は花弁の大きさで決定境界を引けた
- Versicolour と Virginica は花弁の大きさでは決定境界を引けなかった