強化学習の始め方
普段やらない、強化学習について少し勉強したのでメモしておきます。
参考書籍
先に参考書籍の紹介です。
Pythonによる深層強化学習入門 ChainerとOpenAI Gymではじめる強化学習
- 作者: 牧野浩二,西崎博光
- 出版社/メーカー: オーム社
- 発売日: 2018/08/17
- メディア: 単行本(ソフトカバー)
- この商品を含むブログを見る
DeepLearningの経験がある人にオススメです。コードによる説明が多く、プログラマー向けの書籍という感じでした。一冊目にちょうど良かったです。個人的にはQ-Learningのサンプルがとてもわかりやすかったです。OpenAI Gymのサンプルも豊富で手を動かす勉強に向いています。後半はRaspberry PIやArduinoを使ったデモも載っています。
つくりながら学ぶ! 深層強化学習 ~PyTorchによる実践プログラミング~
- 作者: 株式会社電通国際情報サービス小川雄太郎
- 出版社/メーカー: マイナビ出版
- 発売日: 2018/06/28
- メディア: 単行本(ソフトカバー)
- この商品を含むブログを見る
こちらも良書です。とても説明が丁寧で読みやすいです。技術の背景や用語の説明もしっかりしているので勉強になります。順番的には2冊目にちょうど良い印象です。あとサンプルプログラムにはアニメーションなども実装されているので、直感的でわかりやすいです。著者の方はブログやQiita等での情報発信も凄くてファンになりました。こんな本書けるようになりたいです。
あともう一冊手元にあるのですが、まだしっかり読めていないので割愛です。
今回勉強するまで、強化学習についてはなんとなく「準備が大変そう」みたいに思っていたのですが、実際やってみるとそうでもなく、今は良い書籍がたくさんあって、ライブラリも豊富で、とても学びやすくなっている印象を受けました。
強化学習の始め方
1週間勉強した感想です。私の場合は上記の2冊が学び始めにちょうど良かったです。
これから強化学習を学ぶ方は、強化学習について全体像をザックリ見ておくと良いです。強化学習にもいろんなアルゴリズムがあります。以下の記事が詳しいです。
アルゴリズムを俯瞰してみると名前にインパクトのあるDQNとか有名ですね。私のように、これから強化学習について学び始める人にとってはQ-Learning、SARSA、方策勾配法あたりから手を付けてみるのが良さそうです。
何を題材にするか
教師あり学習、たとえば画像認識の場合だとMNISTデータセット、簡単な分類問題、クラスタリングの場合はirisデータセットみたいに、すぐに学習に使えるデータセットがあると便利です。
強化学習においてはデータセット、というより、もう少し大きな枠組みになるので「環境」といった方で良いでしょうか。OpenAI Gymを使うと倒立振子(CartPole)やスペースインベーダー、ブロック崩しなどを題材に強化学習を始めることができます。
OpenAI Gymの環境設定も難しくはありませんが、Pythonのライブラリ管理の知識が必要だったり、環境設定特有のトラブルがついてくるので、もっと手軽に始めれる題材ないのかなーと思っていたらやっぱりあるんですね。「Skinner箱」というのが強化学習の入り口のようです。
スキナー箱
マウス(ネズミ)が餌をとるまでの物語です。スキナー箱の中には1匹のマウスがおり、餌を獲得するための2つのスイッチがあります。
- 電源スイッチ
- 押すたびにON/OFFが切り替わる
- 餌スイッチ
- 電源スイッチがONのときに餌が出る
スキナー箱には2つの状態(State)と2つの行動(Action)があります。
初期状態において、マウスはどちらのスイッチを押したら餌が出るのかわかないため、ランダムにスイッチを押すことになります。運が良ければ「電源スイッチ」=>「餌スイッチ」と押すことで、最短2ステップで餌を獲得することができます。
マウスは繰り返し餌の獲得に取り組むことで、学習によってテーブルの値を更新していきます。
以降は、参考書籍のプログラムを参考にスキナー箱について、3つのアルゴリズムの解法をまとめました。
方策勾配法(Policy Gradient Method)
方策(Policy)とは、エージェントがどのように振る舞うかを決めるルールのことです。エージェントとは今回でいうとマウス(の意思決定する部分)のことです。方策は表形式で表現したり、関数で表現したりします。ここでは状態と行動の2x2の表形式で方策を管理します。
import numpy as np def to_pi_softmax(theta): pi = np.zeros((theta.shape)) exp_theta = np.exp(1.0 * theta) for i in range(theta.shape[0]): pi[i,:] = exp_theta[i,:] / np.sum(exp_theta[i,:]) return pi def get_action(pi, state): return np.random.choice([0, 1], p=pi[state]) def get_next_state(state, action): if state == 0 and action == 0: return 1 elif state == 0 and action == 1: return 0 elif state == 1 and action == 0: return 0 else: return 1 def challenge(pi): state = 0 history = [] while True: action = get_action(pi, state) history.append([state, action]) if state == 1 and action == 1: break state = get_next_state(state, action) return history def update_policy_gradient(theta, pi, history): delta_theta = theta.copy() t = len(history) for i in range(theta.shape[0]): for j in range(theta.shape[1]): n_i = len([sa for sa in history if sa[0] == i]) n_ij = len([sa for sa in history if sa == [i, j]]) delta_theta[i, j] = (n_ij - pi[i, j] * n_i) / t return theta + 0.25 * delta_theta theta = np.array([[1.0, 1.0], [1.0, 1.0]]) pi = to_pi_softmax(theta) print(pi) for i in range(1, 100): history = challenge(pi) new_theta = update_policy_gradient(theta, pi, history) new_pi = to_pi_softmax(new_theta) print(len(history), end=" ", flush=True) if i % 10 == 0: print("\n", new_pi) theta = new_theta pi = new_pi print("\n", pi)
方策はPolicyのPをギリシャ文字のπとして表現することが一般的なようです。上記のプログラムの場合、thetaが学習で更新されるパラメータで、方策テーブル(π)では状態ごとにsoftmax関数を使うことで、各行動を割合として管理しています。
実行結果
$ python skinner_policy_gradient.py [[0.5 0.5] [0.5 0.5]] 2 10 6 4 7 2 7 4 7 2 [[0.5055308 0.4944692 ] [0.39641196 0.60358804]] 9 9 2 4 11 7 4 4 2 5 [[0.57791613 0.42208387] [0.3823193 0.6176807 ]] 3 7 10 2 4 4 8 2 2 2 [[0.6332647 0.3667353 ] [0.30543971 0.69456029]] 2 5 2 7 9 8 6 3 13 2 [[0.68658989 0.31341011] [0.33609984 0.66390016]] 5 4 7 2 2 2 2 12 2 6 [[0.69865129 0.30134871] [0.28321905 0.71678095]] 2 2 5 3 2 3 3 2 5 2 [[0.72277102 0.27722898] [0.21997818 0.78002182]] 5 4 2 4 2 2 3 4 2 4 [[0.72245457 0.27754543] [0.2008336 0.7991664 ]] 4 2 2 4 2 2 2 2 2 2 [[0.82114666 0.17885334] [0.16544254 0.83455746]] 6 3 4 2 8 4 2 2 2 2 [[0.85727318 0.14272682] [0.18526107 0.81473893]] 2 2 2 2 2 2 4 2 4 [[0.86500809 0.13499191] [0.15425348 0.84574652]]
実行結果には方策テーブルとマウスの試行回数を10回ずつ出力しています。学習が進むに連れてマウスは試行回数2回で餌にたどり着けるようになります。方策テーブルも状態0においては、行動0をとるようになり、状態1においては行動1をとるように学習できています。あと方策勾配法は強化学習でよく聞く「報酬」とか出てこないんですね。
SARSA
おまけでSARSAに置き換えてみました。SARSAや後のQ-Learningは価値反復法(value iteration)に分類されるようです(この辺、ググるといろんな説明がある)。価値反復法では、報酬や価値(状態価値、行動価値)、マルコフ決定過程やベルマン方程式というキーワードが出てきます。また時間があるときにまとめるかも。。
import numpy as np def get_action(q_table, state, epsilon): if np.random.uniform(0, 1) <= epsilon: return np.random.choice([0, 1]) else: if q_table[state, 0] == q_table[state, 1]: return np.random.choice([0, 1]) return np.argmax(q_table[state,:]) def get_next_state_and_reward(state, action): if state == 0 and action == 0: return (1, 0) elif state == 0 and action == 1: return (0, 0) elif state == 1 and action == 0: return (0, 0) else: return (1, 1) def update_sarsa(state, action, reward, next_state, next_action, q_table): eta = 0.1 gamma = 0.9 if reward == 1: q_table[state,action] = q_table[state,action] + eta * (reward - q_table[state,action]) else: q_table[state,action] = q_table[state,action] + eta * (reward + gamma * q_table[next_state,next_action] - q_table[state,action]) return q_table def challenge(q_table, epsilon): state = 0 action = get_action(q_table, state, epsilon) history = [] while True: history.append([state, action]) [next_state, reward] = get_next_state_and_reward(state, action) next_action = get_action(q_table, next_state, epsilon) q_table = update_sarsa(state, action, reward, next_state, next_action, q_table) if reward == 1: break state = next_state action = next_action return (q_table, history) np.set_printoptions(precision=6, suppress=True) q_table = np.zeros((2, 2)) print(q_table) epsilon = 1.0 for i in range(1, 101): epsilon = epsilon * 0.9 [q_table, history] = challenge(q_table, epsilon) print(len(history), end=" ", flush=True) if i % 10 == 0: print("\n", q_table)
このプログラムでは行動価値関数を変数q_tableで管理しています。q_tableは2x2の状態と行動の表データで初期値を0としています。学習が進むにつれてq_tableの値が調整されていきます。
SARSAの名前の由来はSARSAの更新式に必要なState、Action、Reward、next-State、next-Actionの5つの頭文字です。ここではSARSAの更新式であるupdate_sarsa関数の引数もその順で定義しています。
SARSAでは基本的にはq_tableに従って、ある状態における行動を決定するわけですが、一定の割合でランダムな行動をとるようにしています。これはε-greedy法という考え方に従うもので、より良い行動を探すための仕組みです。強化学習の世界には「探索と利用のトレードオフ(exploitation-exploration trade-offs)」という言葉もあるようです。深いです。
ランダムに動作するためのepsilonの割合は繰り返し(エピソード)ごとに0.5を掛けるものが多くありましたが、ここでは学習の様子(失敗するケース)を強調するために0.9を掛けています。
実行結果
$ python skinner_sarsa.py [[0. 0.] [0. 0.]] 8 5 3 2 7 4 4 2 4 2 [[0.216381 0.010251] [0.020069 0.651322]] 2 2 2 2 2 2 2 2 2 2 [[0.54006 0.010251] [0.020069 0.878423]] 2 2 2 2 2 2 2 2 2 2 [[0.732106 0.010251] [0.020069 0.957609]] 2 2 2 2 3 2 2 2 2 2 [[0.826678 0.07931 ] [0.020069 0.985219]] 2 2 2 2 2 2 2 2 2 2 [[0.86928 0.07931 ] [0.020069 0.994846]] 2 2 2 2 2 2 2 2 2 2 [[0.887492 0.07931 ] [0.020069 0.998203]] 2 2 2 2 2 2 2 2 2 2 [[0.895012 0.07931 ] [0.020069 0.999373]] 2 2 2 2 2 2 2 2 2 2 [[0.898042 0.07931 ] [0.020069 0.999782]] 2 2 2 2 2 2 2 2 2 2 [[0.899241 0.07931 ] [0.020069 0.999924]] 2 2 2 2 2 2 2 2 2 2 [[0.899709 0.07931 ] [0.020069 0.999973]]
q_tableの初期値は0としています。学習が進むにつれて、q_table[0][0]やq_table[1][1]の値が大きくなっているのがわかります。余談ですが、方策勾配法の方策piと比較すると、q_tableの状態ごとの値は割合ではないので加算したら1になるわけではないようです。
また学習が進むにつれて、ランダムに動作する割合を示すepsilonが小さくなるので不規則な行動はとらなくなります。そのためq_tableに従って最小の2ステップで餌にたどり着くことができています。
Q-Learning
さいごにQ-Learningです。SARSAとよく似ていて、行動価値関数(変数q_table)の更新式が少し異なります。SARSAでは更新式にnext-State、next-Actionの2つが必要でしたが、Q-Learningではnext-Stateにおける行動(Action)の中から値の最大値のものを選択するようにします。
import numpy as np def get_action(q_table, state, epsilon): if np.random.uniform(0, 1) <= epsilon: return np.random.choice([0, 1]) else: if q_table[state, 0] == q_table[state, 1]: return np.random.choice([0, 1]) return np.argmax(q_table[state,:]) def get_next_state_and_reward(state, action): if state == 0 and action == 0: return (1, 0) elif state == 0 and action == 1: return (0, 0) elif state == 1 and action == 0: return (0, 0) else: return (1, 1) def update_q_learning(state, action, reward, next_state, q_table): eta = 0.1 gamma = 0.9 if reward == 1: q_table[state,action] = q_table[state,action] + eta * (reward - q_table[state,action]) else: q_table[state,action] = q_table[state,action] + eta * (reward + gamma * np.max(q_table[next_state,:]) - q_table[state,action]) return q_table def challenge(q_table, epsilon): state = 0 action = get_action(q_table, state, epsilon) history = [] while True: history.append([state, action]) [next_state, reward] = get_next_state_and_reward(state, action) q_table = update_q_learning(state, action, reward, next_state, q_table) if reward == 1: break state = next_state action = get_action(q_table, state, epsilon) return (q_table, history) np.set_printoptions(precision=6, suppress=True) q_table = np.zeros((2, 2)) print(q_table) epsilon = 1.0 for i in range(1, 101): epsilon = epsilon * 0.9 [q_table, history] = challenge(q_table, epsilon) print(len(history), end=" ", flush=True) if i % 10 == 0: print("\n", q_table)
実行結果
$ python skinner_q_learning.py [[0. 0.] [0. 0.]] 2 3 6 2 3 2 6 2 2 3 [[0.270356 0.047601] [0.01882 0.651322]] 3 3 2 2 2 2 2 2 2 2 [[0.55888 0.08763 ] [0.01882 0.878423]] 2 2 2 2 2 2 4 4 3 2 [[0.759274 0.145173] [0.136828 0.957609]] 2 2 2 2 2 2 2 2 2 2 [[0.836151 0.145173] [0.136828 0.985219]] 2 2 2 2 2 2 2 2 2 2 [[0.872583 0.145173] [0.136828 0.994846]] 2 2 2 2 2 2 2 2 2 2 [[0.888643 0.145173] [0.136828 0.998203]] 2 2 2 2 2 2 2 2 2 2 [[0.895414 0.145173] [0.136828 0.999373]] 2 2 2 2 2 2 2 2 2 2 [[0.898182 0.145173] [0.136828 0.999782]] 2 2 2 2 2 2 2 2 2 2 [[0.89929 0.145173] [0.136828 0.999924]] 2 2 2 2 2 2 2 2 2 2 [[0.899726 0.145173] [0.136828 0.999973]]
結果はSARSAのときと同じようにq_tableの値が更新されているのがわかります。
その次の勉強
とりあえずはOpenAI Gymの題材にチャレンジするのが良さそうです。有名な倒立振子(CartPole)については書籍やインターネット上でサンプルもたくさん紹介されています。他にもToy textなるものもありました。私もいくつか触ってみましたが、FrozenLake問題というのは勉強するのにちょうど良い感じがしました。
https://gym.openai.com/envs/#toy_text
他にも調べているとブロック崩しゲームも自力で解けるみたいです。GPUマシンもクラウドでどうにかなるし。
今後は自分で何か強化学習の題材を作ってみようと思っています。面白いかどうかは別として、マルバツゲームや五目並べなどにチャレンジしてみようかと思っています。Q-Learningだと状態や行動、報酬の設計をゼロからできるようになれば世界が広がりそうです。
あとはアルゴリズムについてはコードだけでなく数式による理解も大事ですね。この辺はコツコツと。それからDQNなどの深層学習に取り組むのも面白そうです。今まではKerasしか使ったことなかったですが最近はPyTorchが良いみたいです。
まとめ
普段やらないことをやってみました。勉強になりました。