強化学習 - Tic-Tac-Toe
強化学習 - Tic-Tac-Toe
三目並べ、マルバツゲーム、Tic-Tac-Toeというそうです。強化学習(Q-Learning)のまとめとしてチャレンジしてみました。Googleで「Tic-Tac-Toe」と検索すると三目並べで遊べます。
先に結果
ランダムな相手(後手)には80%近い確率で勝てるようになりました。でも実際に対戦してみると「ちょっとかしこいかな?」くらいの印象です。今日の勢いで作ったのでプログラムの細かいところに不備があるかも?しれません。。あまり参考にならないかも。
設定など
- アルゴリズム
- Q-Learning
- 報酬
- 勝ち:1
- 負け:-1
- 引き分け:0
- 引き分けも多いので、報酬としてプラスマイナスがあるのも良いのかも。
- 状態
- 3**9 = 19683とおり
- 行動
- 9マスあるので9とおり
- キーボード入力時は0〜8
- 既に入力済みの場所を選択した場合は、ランダムで配置するようにしています。
- このせいで勝率はやや落ちるかも。
- 9マスあるので9とおり
- パラメータなど
- 学習回数(エピソード数):2,000,000くらい(他のサイトを見るともっと少なく良さそう)
- 学習率:0.1
- gamma:0.9
所感
- マルバツゲーム本体のプログラムと、強化学習プログラムのインタフェース設計に少し迷った。
- OpenAI Gymのインタフェースに習って、マルバツゲームにstep関数を実装した。
- 学習後の勝率が60%で停滞した。
- εの更新式を間違っていたため、Qテーブルの値に依存した学習を繰り返していた模様
- そもそもマルバツゲーム本体のロジックに不備があったり。
- テストコードを書こう。
- Qテーブルの初期値も迷った。
- -1〜1のようなランダムな値も試したが、あまり効果はなさそうだったのでとりあえず0にした。
- 常に先手("x")をAIとした。
- 学習済みのAIを後手で起動すると精度が下がる。
作成したファイル
まずは強化学習用のプログラムです。
- tictactoe.py
- マルバツゲーム本体プログラム
- train_tictactoe.py
- Q-Learningで学習するプログラム
- ランダムな打ち手(後手)と繰り返し対決します。
- 学習後Qテーブルの内容をnumpyのファイルとして出力します。
以下は学習済みプログラムを動作させるためのプログラムです。
- play_test_tictactoe.py
- 学習済みのQテーブルを使って検証するプログラム
- play_tictactoe.py
- キーボード入力で対決できるプログラム
tictactoe.py
- マルバツゲーム本体プログラム
- 先手、後手ともにランダムで動作します。
import numpy as np class TicTacToe: def __init__(self, printable=False): self.reset(printable) def show(self): if self.printable: print(self.board[0], "|", self.board[1], "|" ,self.board[2]) print("----------" ) print(self.board[3], "|", self.board[4], "|" ,self.board[5]) print("----------" ) print(self.board[6], "|", self.board[7], "|" ,self.board[8]) def reset(self, printable=None) : self.board = [" ", " ", " ", " ", " ", " ", " ", " ", " "] self.player1 = "x" self.player2 = "o" self.player = self.player1 self.done = False if printable is not None: self.printable = printable return self.getObservation() def put(self, number) : if number < 0 or 8 < number: return False if self.board[number] != " ": return False self.board[number] = self.player return True def judge(self): if self.win() or self.draw(): self.done = True return self.done def win(self): return self.player == self.board[0] and self.player == self.board[1] and self.player == self.board[2] \ or self.player == self.board[3] and self.player == self.board[4] and self.player == self.board[5] \ or self.player == self.board[6] and self.player == self.board[7] and self.player == self.board[8] \ or self.player == self.board[0] and self.player == self.board[3] and self.player == self.board[6] \ or self.player == self.board[1] and self.player == self.board[4] and self.player == self.board[7] \ or self.player == self.board[2] and self.player == self.board[5] and self.player == self.board[8] \ or self.player == self.board[0] and self.player == self.board[4] and self.player == self.board[8] \ or self.player == self.board[2] and self.player == self.board[4] and self.player == self.board[6] def draw(self): return (" " in self.board) == False def changePlayer(self): if self.player == self.player1: self.player = self.player2 else: self.player = self.player1 def printStart(self): if self.printable: print("Tic Tac Toe Start!", flush=True) def printEnd(self): if self.printable: print("Tic Tac Toe End!", flush=True) def printChoice(self): if self.printable: print("Choice! (0,1,2,3,4,5,6,7,8)", flush=True) def printChoiceInvalid(self): if self.printable: print("Your choice is invalid", flush=True) def printWin(self): if self.printable: print(self.player, "Win!", flush=True) def printDraw(self): if self.printable: print("Draw!", flush=True) def printSpace(self): if self.printable: print(flush=True) def getRandomAction(self): return np.random.randint(0, 9) def start(self): self.printStart() while True: action = self.getRandomAction() (_, _, done, _) = self.step(action) self.show() self.printSpace() if done: self.show() if self.win(): self.printWin() else: self.printDraw() break self.printEnd() def step(self, action): while True: if self.put(action) : break else: action = self.getRandomAction() reward = 0 if self.judge() : if self.win(): if self.player == self.player1: reward = 1 else: reward = -1 else: reward = 0 else: self.changePlayer() reward = 0 return (self.getObservation(), reward, self.done, {}) def getObservation(self): return (self.board, self.player, self.player1, self.player2) if __name__ == "__main__": ttt = TicTacToe(True) ttt.start()
train_tictactoe.py
- 強化学習用のプログラムです。
import numpy as np from tictactoe import TicTacToe import numpy as np def get_action(q_table, state, epsilon): if np.random.uniform(0, 1) <= epsilon: return np.random.randint(0, 9) else: a = np.where(q_table[state] == q_table[state].max())[0] return np.random.choice(a) def board_to_state(board): state = 0 for i in range(0, 9): if board[i] == 'o': state = state + 3**i * 2 elif board[i] == 'x': state = state + 3**i * 1 else: state = state + 3**i * 0 return state def update_q_learning(state, ation, reward, next_state, q_table): eta = 0.1 gamma = 0.9 if reward != 0: q_table[state,action] = q_table[state,action] + \ eta * (reward - q_table[state,action]) else: q_table[state,action] = q_table[state,action] + \ eta * (reward + gamma * np.max(q_table[state,:]) - q_table[state,action]) return q_table np.set_printoptions(precision=6, suppress=True) prefix = "q_table_data_" episode = 100_000 # episode = 2_000_000 # q_table = np.random.uniform(low=-1, high=1, size=(3**9, 9)) q_table = np.zeros((3**9, 9)) # q_table = np.load('q_table_dataa_5000000.npy') ttt = TicTacToe() threshold = 10000 initial_epsilon = 0.5 win = 0 draw = 0 for i in range(1, episode): epsilon = initial_epsilon * (episode - i) / episode my_turn = True observation = ttt.reset() while True: state = board_to_state(observation[0]) action = None if my_turn: action = get_action(q_table, state, epsilon) else: action = ttt.getRandomAction() (observation, reward, done, _) = ttt.step(action) if my_turn: next_state = board_to_state(observation[0]) q_table = update_q_learning(state, action, reward, next_state, q_table) if done: if ttt.win(): if ttt.player == ttt.player1: win = win + 1 else: draw = draw + 1 break my_turn = not my_turn if i % threshold == 0: lose = threshold - win - draw print("episode", i, "/", episode, "win", win, "draw", draw, "lose", lose) win = 0 draw = 0 print(q_table) np.save(prefix + str(episode), q_table)
play_test_tictactoe.py
- 学習結果を確認するプログラムです。
- 学習済みのQテーブルをロードして、10000回の試行結果を出力します。
from tictactoe import TicTacToe import numpy as np def get_action(q_table, state): a = np.where(q_table[state] == q_table[state].max())[0] return np.random.choice(a) def board_to_state(board): state = 0 for i in range(0, 9): if board[i] == 'o': state = state + 3**i * 2 elif board[i] == 'x': state = state + 3**i * 1 else: state = state + 3**i * 0 return state max_episode = 10000 q_table = np.load('q_table_data_5000000.npy') ttt = TicTacToe(False) win = 0 draw = 0 ttt.printStart() for i in range(1, max_episode): my_turn = True observation = ttt.reset() while True: state = board_to_state(observation[0]) action = None if my_turn: action = get_action(q_table, state) # action = np.random.randint(0, 9) else: ttt.printChoice() action = np.random.randint(0, 9) (observation, reward, done, _) = ttt.step(action) ttt.show() ttt.printSpace() ttt.printSpace() if done: if ttt.win(): ttt.printWin() if ttt.player == ttt.player1: win = win + 1 else: ttt.printDraw() draw = draw + 1 break my_turn = not my_turn print("episode", max_episode, "win", win, "draw", draw, "lose", max_episode - win - draw)
play_tictactoe.py
- ユーザと対戦用のプログラムです。
from tictactoe import TicTacToe import numpy as np def get_action(q_table, state): a = np.where(q_table[state] == q_table[state].max())[0] return np.random.choice(a) def board_to_state(board): state = 0 for i in range(0, 9): if board[i] == 'o': state = state + 3**i * 2 elif board[i] == 'x': state = state + 3**i * 1 else: state = state + 3**i * 0 return state q_table = np.load('q_table_5000000.npy') ttt = TicTacToe(True) win = 0 draw = 0 ttt.printStart() my_turn = True observation = ttt.reset() while True: state = board_to_state(observation[0]) action = None if my_turn: action = get_action(q_table, state) # action = np.random.randint(0, 9) else: ttt.printChoice() action = int(input().strip()) # action = np.random.randint(0, 9) (observation, reward, done, _) = ttt.step(action) ttt.show() ttt.printSpace() ttt.printSpace() if done: if ttt.win(): ttt.printWin() if ttt.player == ttt.player1: win = win + 1 else: ttt.printDraw() draw = draw + 1 break my_turn = not my_turn
参考
イプシロンの更新ロジックを参考にさせてもらいました。
改めて読むとわかりやすいです。既に入力済みの場所を選択した場合の制御も入れると勝率上がりそう。エピソード数は10000回で良さそう。