Kerasで画像認識 - MNIST編
Kerasを使った画像認識のプログラムです。有名なMNISTデータ(手書き数字)を使ったものです。
MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
from keras.models import Sequential from keras.layers import Dense, Activation from keras.utils import to_categorical from keras.datasets import mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() # 28x28 => 784 x_train = x_train.reshape(60000, 784) x_test = x_test.reshape(10000, 784) # one-hot ex: 3 => [0,0,0,1,0,0,0,0,0,0] y_train = to_categorical(y_train, 10) y_test = to_categorical(y_test, 10) model = Sequential() model.add(Dense(50, input_dim=784)) model.add(Activation('sigmoid')) model.add(Dense(20)) model.add(Activation('sigmoid')) model.add(Dense(10)) model.add(Activation('softmax')) model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy']) history = model.fit(x_train, y_train, batch_size=32, validation_data=(x_test, y_test))
コードを書いて実行してみましょう。機械学習の開発環境にはJupyter Notebookがオススメです。
実行結果は次のようになります。テストデータで91%の正答率(val_acc)です。
Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz 10862592/11490434 [===========================>..] - ETA: 0sTrain on 60000 samples, validate on 10000 samples Train on 60000 samples, validate on 10000 samples Epoch 1/10 60000/60000 [==============================] - 6s - loss: 1.8023 - acc: 0.5735 - val_loss: 1.3870 - val_acc: 0.7524 Epoch 2/10 60000/60000 [==============================] - 5s - loss: 1.0939 - acc: 0.8008 - val_loss: 0.8456 - val_acc: 0.8457 Epoch 3/10 60000/60000 [==============================] - 5s - loss: 0.7209 - acc: 0.8618 - val_loss: 0.6233 - val_acc: 0.8767 Epoch 4/10 60000/60000 [==============================] - 5s - loss: 0.5610 - acc: 0.8814 - val_loss: 0.5006 - val_acc: 0.8945 Epoch 5/10 60000/60000 [==============================] - 5s - loss: 0.4770 - acc: 0.8892 - val_loss: 0.4278 - val_acc: 0.8998 Epoch 6/10 60000/60000 [==============================] - 6s - loss: 0.4267 - acc: 0.8960 - val_loss: 0.3993 - val_acc: 0.9026 Epoch 7/10 60000/60000 [==============================] - 8s - loss: 0.4097 - acc: 0.8970 - val_loss: 0.3959 - val_acc: 0.8984 Epoch 8/10 60000/60000 [==============================] - 6s - loss: 0.3875 - acc: 0.9016 - val_loss: 0.3856 - val_acc: 0.9044 Epoch 9/10 60000/60000 [==============================] - 5s - loss: 0.3673 - acc: 0.9035 - val_loss: 0.3514 - val_acc: 0.9076 Epoch 10/10 60000/60000 [==============================] - 5s - loss: 0.3458 - acc: 0.9070 - val_loss: 0.3296 - val_acc: 0.9112
初回実行時はMNISTデータのダウンロードが発生します。そのあと10回の学習が進んでいるのがわかります。
プログラムの解説
プログラムの詳細を見てみましょう。Kerasを使えばMNISTデータもKerasのAPIでダウンロードできます。
(x_train, y_train), (x_test, y_test) = mnist.load_data() # 28x28 => 784 x_train = x_train.reshape(60000, 784) x_test = x_test.reshape(10000, 784) # one-hot ex: 3 => [0,0,0,1,0,0,0,0,0,0] y_train = to_categorical(y_train, 10) y_test = to_categorical(y_test, 10)
ここではダウンロード後のデータを全結合型のニューラルネットワークで処理できるようにデータを整形しています。
今回のプログラムは全結合型のニューラルネットワークです。KerasのDenseクラスで全結合レイヤーを作っています。
model.add(Dense(50, input_dim=784))
入力層のノード数はinput_dimで指定します。MNISTの画像データが28x28だから784になります。あと今回は0-9の10クラス分類なので出力層のノード数も10になります。
あとは活性化関数を指定して、レイヤーを並べています。中間層の活性化関数にはsigmoid関数、出力層の活性化関数にはsoftmax関数を指定しています。今回は多クラス分類なので出力層の活性化関数にはsoftmax関数を使っています。
モデルが完成したらコンパイルします。コンパイル時には損失関数とオプティマイザを指定します。損失関数には"categorical_crossentropy"、オプティマイザにはSGDを指定しています。metricsに指定した内容はエポック時に表示したい内容です。ここでは正答率(acc)を表示しています。
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
ちなみに多クラス分類ではなく2値分類(yes/noみたいな)の場合は損失関数(loss)にbinary_crossentropyを指定します。その場合出力層の活性化関数をsigmoidにもできます。
あとは学習開始です。
history = model.fit(x_train, y_train, batch_size=32, validation_data=(x_test, y_test)))
MNISTの訓練データは60000件(テストデータは10000件)あるので、32件ずつランダムに取り出して勾配を求めます。求めた勾配よって重みが各ノードの重み・バイアスが更新されます(SGD)。引数にvalidation_dataを指定することでエポックごとにテストデータで検証(ホールドアウト検証)してくれます。デフォルトで10エポック(同じことを10回)学習します。
学習のグラフ化
matplotlibを使って学習の様子をグラフにしてみましょう。
import matplotlib.pyplot as plt plt.ylim(0.0, 1) plt.plot(history.history['acc'], label="acc") plt.plot(history.history['val_acc'], label="val_acc") plt.legend() plt.show()
参考
オプティマイザについては以下のページが参考になります。