Murayama blog.

プログラミング教育なブログ

Kerasで画像認識 - MNIST編

Kerasを使った画像認識のプログラムです。有名なMNISTデータ(手書き数字)を使ったものです。

MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.utils import to_categorical
from keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 28x28 => 784
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)

# one-hot ex: 3 => [0,0,0,1,0,0,0,0,0,0]
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

model = Sequential()
model.add(Dense(50, input_dim=784))
model.add(Activation('sigmoid'))
model.add(Dense(20))
model.add(Activation('sigmoid'))
model.add(Dense(10))
model.add(Activation('softmax'))
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

history = model.fit(x_train, y_train, batch_size=32, validation_data=(x_test, y_test))

コードを書いて実行してみましょう。機械学習の開発環境にはJupyter Notebookがオススメです。

murayama.hatenablog.com

実行結果は次のようになります。テストデータで91%の正答率(val_acc)です。

Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
10862592/11490434 [===========================>..] - ETA: 0sTrain on 60000 samples, validate on 10000 samples
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 6s - loss: 1.8023 - acc: 0.5735 - val_loss: 1.3870 - val_acc: 0.7524
Epoch 2/10
60000/60000 [==============================] - 5s - loss: 1.0939 - acc: 0.8008 - val_loss: 0.8456 - val_acc: 0.8457
Epoch 3/10
60000/60000 [==============================] - 5s - loss: 0.7209 - acc: 0.8618 - val_loss: 0.6233 - val_acc: 0.8767
Epoch 4/10
60000/60000 [==============================] - 5s - loss: 0.5610 - acc: 0.8814 - val_loss: 0.5006 - val_acc: 0.8945
Epoch 5/10
60000/60000 [==============================] - 5s - loss: 0.4770 - acc: 0.8892 - val_loss: 0.4278 - val_acc: 0.8998
Epoch 6/10
60000/60000 [==============================] - 6s - loss: 0.4267 - acc: 0.8960 - val_loss: 0.3993 - val_acc: 0.9026
Epoch 7/10
60000/60000 [==============================] - 8s - loss: 0.4097 - acc: 0.8970 - val_loss: 0.3959 - val_acc: 0.8984
Epoch 8/10
60000/60000 [==============================] - 6s - loss: 0.3875 - acc: 0.9016 - val_loss: 0.3856 - val_acc: 0.9044
Epoch 9/10
60000/60000 [==============================] - 5s - loss: 0.3673 - acc: 0.9035 - val_loss: 0.3514 - val_acc: 0.9076
Epoch 10/10
60000/60000 [==============================] - 5s - loss: 0.3458 - acc: 0.9070 - val_loss: 0.3296 - val_acc: 0.9112

初回実行時はMNISTデータのダウンロードが発生します。そのあと10回の学習が進んでいるのがわかります。

プログラムの解説

プログラムの詳細を見てみましょう。Kerasを使えばMNISTデータもKerasのAPIでダウンロードできます。

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 28x28 => 784
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)

# one-hot ex: 3 => [0,0,0,1,0,0,0,0,0,0]
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

ここではダウンロード後のデータを全結合型のニューラルネットワークで処理できるようにデータを整形しています。

今回のプログラムは全結合型のニューラルネットワークです。KerasのDenseクラスで全結合レイヤーを作っています。

model.add(Dense(50, input_dim=784))

入力層のノード数はinput_dimで指定します。MNISTの画像データが28x28だから784になります。あと今回は0-9の10クラス分類なので出力層のノード数も10になります。

あとは活性化関数を指定して、レイヤーを並べています。中間層の活性化関数にはsigmoid関数、出力層の活性化関数にはsoftmax関数を指定しています。今回は多クラス分類なので出力層の活性化関数にはsoftmax関数を使っています。

モデルが完成したらコンパイルします。コンパイル時には損失関数とオプティマイザを指定します。損失関数には"categorical_crossentropy"、オプティマイザにはSGDを指定しています。metricsに指定した内容はエポック時に表示したい内容です。ここでは正答率(acc)を表示しています。

model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

ちなみに多クラス分類ではなく2値分類(yes/noみたいな)の場合は損失関数(loss)にbinary_crossentropyを指定します。その場合出力層の活性化関数をsigmoidにもできます。

あとは学習開始です。

history = model.fit(x_train, y_train, batch_size=32, validation_data=(x_test, y_test)))

MNISTの訓練データは60000件(テストデータは10000件)あるので、32件ずつランダムに取り出して勾配を求めます。求めた勾配よって重みが各ノードの重み・バイアスが更新されます(SGD)。引数にvalidation_dataを指定することでエポックごとにテストデータで検証(ホールドアウト検証)してくれます。デフォルトで10エポック(同じことを10回)学習します。

学習のグラフ化

matplotlibを使って学習の様子をグラフにしてみましょう。

import matplotlib.pyplot as plt

plt.ylim(0.0, 1)
plt.plot(history.history['acc'], label="acc")
plt.plot(history.history['val_acc'], label="val_acc")
plt.legend()

plt.show()

f:id:yamasahi:20171125154400p:plain

参考

オプティマイザについては以下のページが参考になります。

postd.cc