Murayama blog.

プログラミング教育なブログ

強化学習 - Tic-Tac-Toe

強化学習 - Tic-Tac-Toe

三目並べ、マルバツゲーム、Tic-Tac-Toeというそうです。強化学習(Q-Learning)のまとめとしてチャレンジしてみました。Googleで「Tic-Tac-Toe」と検索すると三目並べで遊べます。

先に結果

ランダムな相手(後手)には80%近い確率で勝てるようになりました。でも実際に対戦してみると「ちょっとかしこいかな?」くらいの印象です。今日の勢いで作ったのでプログラムの細かいところに不備があるかも?しれません。。あまり参考にならないかも。

設定など

  • アルゴリズム
    • Q-Learning
  • 報酬
    • 勝ち:1
    • 負け:-1
    • 引き分け:0
      • 引き分けも多いので、報酬としてプラスマイナスがあるのも良いのかも。
  • 状態
    • 3**9 = 19683とおり
  • 行動
    • 9マスあるので9とおり
      • キーボード入力時は0〜8
      • 既に入力済みの場所を選択した場合は、ランダムで配置するようにしています。
        • このせいで勝率はやや落ちるかも。
  • パラメータなど
    • 学習回数(エピソード数):2,000,000くらい(他のサイトを見るともっと少なく良さそう)
    • 学習率:0.1
    • gamma:0.9

所感

  • マルバツゲーム本体のプログラムと、強化学習プログラムのインタフェース設計に少し迷った。
    • OpenAI Gymのインタフェースに習って、マルバツゲームにstep関数を実装した。
  • 学習後の勝率が60%で停滞した。
    • εの更新式を間違っていたため、Qテーブルの値に依存した学習を繰り返していた模様
  • そもそもマルバツゲーム本体のロジックに不備があったり。
    • テストコードを書こう。
  • Qテーブルの初期値も迷った。
    • -1〜1のようなランダムな値も試したが、あまり効果はなさそうだったのでとりあえず0にした。
  • 常に先手("x")をAIとした。
    • 学習済みのAIを後手で起動すると精度が下がる。

作成したファイル

まずは強化学習用のプログラムです。

  • tictactoe.py
    • マルバツゲーム本体プログラム
  • train_tictactoe.py
    • Q-Learningで学習するプログラム
    • ランダムな打ち手(後手)と繰り返し対決します。
    • 学習後Qテーブルの内容をnumpyのファイルとして出力します。

以下は学習済みプログラムを動作させるためのプログラムです。

  • play_test_tictactoe.py
    • 学習済みのQテーブルを使って検証するプログラム
  • play_tictactoe.py
    • キーボード入力で対決できるプログラム

tictactoe.py

  • マルバツゲーム本体プログラム
  • 先手、後手ともにランダムで動作します。
import numpy as np

class TicTacToe:
    def __init__(self, printable=False):
        self.reset(printable)

    def show(self):
        if self.printable:
            print(self.board[0], "|", self.board[1], "|" ,self.board[2])
            print("----------" )
            print(self.board[3], "|", self.board[4], "|" ,self.board[5])
            print("----------" )
            print(self.board[6], "|", self.board[7], "|" ,self.board[8])

    def reset(self, printable=None) :
        self.board = [" ", " ", " ", " ", " ", " ", " ", " ", " "]
        self.player1 = "x"
        self.player2 = "o"
        self.player = self.player1
        self.done = False
        if printable is not None:
            self.printable = printable
        return self.getObservation()

    def put(self, number) :
        if number < 0 or 8 < number:
            return False
        if self.board[number] != " ":
            return False
        self.board[number] = self.player
        return True

    def judge(self):
        if self.win() or self.draw():
            self.done = True
        return self.done

    def win(self):
        return self.player == self.board[0] and self.player == self.board[1] and self.player == self.board[2] \
                or self.player == self.board[3] and self.player == self.board[4] and self.player == self.board[5] \
                or self.player == self.board[6] and self.player == self.board[7] and self.player == self.board[8] \
                or self.player == self.board[0] and self.player == self.board[3] and self.player == self.board[6] \
                or self.player == self.board[1] and self.player == self.board[4] and self.player == self.board[7] \
                or self.player == self.board[2] and self.player == self.board[5] and self.player == self.board[8] \
                or self.player == self.board[0] and self.player == self.board[4] and self.player == self.board[8] \
                or self.player == self.board[2] and self.player == self.board[4] and self.player == self.board[6]

    def draw(self):
        return (" " in self.board) == False

    def changePlayer(self):
        if self.player == self.player1:
            self.player = self.player2
        else:
            self.player = self.player1

    def printStart(self):
        if self.printable:
            print("Tic Tac Toe Start!", flush=True)

    def printEnd(self):
        if self.printable:
            print("Tic Tac Toe End!", flush=True)

    def printChoice(self):
        if self.printable:
            print("Choice! (0,1,2,3,4,5,6,7,8)", flush=True)

    def printChoiceInvalid(self):
        if self.printable:
            print("Your choice is invalid", flush=True)

    def printWin(self):
        if self.printable:
            print(self.player, "Win!", flush=True)

    def printDraw(self):
        if self.printable:
            print("Draw!", flush=True)

    def printSpace(self):
        if self.printable:
            print(flush=True)

    def getRandomAction(self):
        return np.random.randint(0, 9)

    def start(self):
        self.printStart()
        while True:
            action = self.getRandomAction()
            (_, _, done, _) = self.step(action)
            self.show()
            self.printSpace()
            if done:
                self.show()
                if self.win():
                    self.printWin()
                else:
                    self.printDraw()
                break
        self.printEnd()

    def step(self, action):
        while True:
            if self.put(action) :
                break
            else:
                action = self.getRandomAction()
        reward = 0
        if self.judge() :
            if self.win():
                if self.player == self.player1:
                    reward = 1
                else:
                    reward = -1
            else:
                reward = 0
        else:
            self.changePlayer()
            reward = 0
        return (self.getObservation(), reward, self.done, {})

    def getObservation(self):
        return (self.board, self.player, self.player1, self.player2)

if __name__ == "__main__":
    ttt = TicTacToe(True)
    ttt.start()

train_tictactoe.py

import numpy as np
from tictactoe import TicTacToe

import numpy as np

def get_action(q_table, state, epsilon):
    if np.random.uniform(0, 1) <= epsilon:
        return np.random.randint(0, 9)
    else:
        a = np.where(q_table[state] == q_table[state].max())[0]
        return np.random.choice(a)

def board_to_state(board):
    state = 0
    for i in range(0, 9):
        if board[i] == 'o':
            state = state + 3**i * 2
        elif board[i] == 'x':
            state = state + 3**i * 1
        else:
            state = state + 3**i * 0
    return state

def update_q_learning(state, ation, reward, next_state, q_table):
    eta = 0.1
    gamma = 0.9
    if reward != 0:
        q_table[state,action] = q_table[state,action] + \
                                eta * (reward - q_table[state,action])
    else:
        q_table[state,action] = q_table[state,action] + \
                                eta * (reward + gamma * np.max(q_table[state,:]) - q_table[state,action])
    return q_table

np.set_printoptions(precision=6, suppress=True)

prefix = "q_table_data_"
episode = 100_000
# episode = 2_000_000

# q_table = np.random.uniform(low=-1, high=1, size=(3**9, 9))
q_table = np.zeros((3**9, 9))
# q_table = np.load('q_table_dataa_5000000.npy')

ttt = TicTacToe()
threshold = 10000
initial_epsilon = 0.5
win = 0
draw = 0
for i in range(1, episode):
    epsilon = initial_epsilon * (episode - i) / episode
    my_turn = True
    observation = ttt.reset()
    while True:
        state = board_to_state(observation[0])
        action = None
        if my_turn:
            action = get_action(q_table, state, epsilon)
        else:
            action = ttt.getRandomAction()

        (observation, reward, done, _) = ttt.step(action)

        if my_turn:
            next_state = board_to_state(observation[0])
            q_table = update_q_learning(state, action, reward, next_state, q_table)

        if done:
            if ttt.win():
                if ttt.player == ttt.player1:
                    win = win + 1
            else:
                draw = draw + 1
            break

        my_turn = not my_turn
    if i % threshold == 0:
        lose = threshold - win - draw
        print("episode", i, "/", episode, "win", win, "draw", draw, "lose", lose)
        win = 0
        draw = 0

print(q_table)
np.save(prefix + str(episode), q_table)

play_test_tictactoe.py

  • 学習結果を確認するプログラムです。
  • 学習済みのQテーブルをロードして、10000回の試行結果を出力します。
from tictactoe import TicTacToe

import numpy as np

def get_action(q_table, state):
    a = np.where(q_table[state] == q_table[state].max())[0]
    return np.random.choice(a)

def board_to_state(board):
    state = 0
    for i in range(0, 9):
        if board[i] == 'o':
            state = state + 3**i * 2
        elif board[i] == 'x':
            state = state + 3**i * 1
        else:
            state = state + 3**i * 0
    return state

max_episode = 10000
q_table = np.load('q_table_data_5000000.npy')
ttt = TicTacToe(False)
win = 0
draw = 0
ttt.printStart()
for i in range(1, max_episode):
    my_turn = True
    observation = ttt.reset()
    while True:
        state = board_to_state(observation[0])
        action = None
        if my_turn:
            action = get_action(q_table, state)
            # action = np.random.randint(0, 9)
        else:
            ttt.printChoice()
            action = np.random.randint(0, 9)
        (observation, reward, done, _) = ttt.step(action)
        ttt.show()
        ttt.printSpace()
        ttt.printSpace()
        if done:
            if ttt.win():
                ttt.printWin()
                if ttt.player == ttt.player1:
                    win = win + 1
            else:
                ttt.printDraw()
                draw = draw + 1
            break
        my_turn = not my_turn
print("episode", max_episode, "win", win, "draw", draw, "lose", max_episode - win - draw)

play_tictactoe.py

  • ユーザと対戦用のプログラムです。
from tictactoe import TicTacToe

import numpy as np

def get_action(q_table, state):
    a = np.where(q_table[state] == q_table[state].max())[0]
    return np.random.choice(a)

def board_to_state(board):
    state = 0
    for i in range(0, 9):
        if board[i] == 'o':
            state = state + 3**i * 2
        elif board[i] == 'x':
            state = state + 3**i * 1
        else:
            state = state + 3**i * 0
    return state

q_table = np.load('q_table_5000000.npy')

ttt = TicTacToe(True)
win = 0
draw = 0
ttt.printStart()
my_turn = True
observation = ttt.reset()
while True:
    state = board_to_state(observation[0])
    action = None
    if my_turn:
        action = get_action(q_table, state)
        # action = np.random.randint(0, 9)
    else:
        ttt.printChoice()
        action = int(input().strip())
        # action = np.random.randint(0, 9)
    (observation, reward, done, _) = ttt.step(action)
    ttt.show()
    ttt.printSpace()
    ttt.printSpace()
    if done:
        if ttt.win():
            ttt.printWin()
            if ttt.player == ttt.player1:
                win = win + 1
        else:
            ttt.printDraw()
            draw = draw + 1
        break
    my_turn = not my_turn

参考

qiita.com

イプシロンの更新ロジックを参考にさせてもらいました。

data.gunosy.io

改めて読むとわかりやすいです。既に入力済みの場所を選択した場合の制御も入れると勝率上がりそう。エピソード数は10000回で良さそう。