Mastering Gomoku with AlphaZero: A Step-by-Step Guide
In this article, we explore game-playing artificial intelligence (AI) through a custom implementation of AlphaZero for the game of Gomoku.
Gomoku, also known as Five in a Row, is a two-player strategy board game in which the players take turns placing stones, and the first to align five of their own stones horizontally, vertically, or diagonally wins. The implementation described here plays on an 11x11 board.
The article covers the essential building blocks of the AlphaZero algorithm: the policy-value neural network, the Monte Carlo Tree Search (MCTS) algorithm, and the self-play training pipeline. We walk through each component in detail and explain the principles and techniques behind it.
By the end of the article, you should understand how the AlphaZero framework is applied to Gomoku and how the same approach can be adapted to other board games or more complex environments.
Game.py
import numpy as np
import tkinter
class Board(object):
    def __init__(self, **kwargs):
        self.width = int(kwargs.get('width', 8))
        self.height = int(kwargs.get('height', 8))
        self.states = {}  # board states, key: move as location on the board, value: player as piece type
        self.n_in_row = int(kwargs.get('n_in_row', 5))  # number of pieces in a row needed to win
        self.players = [1, 2]  # player1 and player2
The __init__ method initializes an instance of a game board with some optional keyword arguments.
width: The width of the board, default is 8.
height: The height of the board, default is 8.
states: A dictionary to store the board states, where the key represents the move (as a location on the board) and the value represents the player (as piece type).
n_in_row: The number of pieces required in a row to win the game, default is 5.
players: A list of players, player1 and player2, represented by integers 1 and 2.
    def init_board(self, start_player=0):
        if self.width < self.n_in_row or self.height < self.n_in_row:
            raise Exception('board width and height cannot be less than %d' % self.n_in_row)
        self.current_player = self.players[start_player]  # start player
        self.availables = list(range(self.width * self.height))  # available moves
        self.states = {}  # board states, key: move as location on the board, value: player as piece type
        self.last_move = -1
The init_board method initializes the game board with a given start player.
start_player (optional): Index of the starting player (0 or 1) from the players list.
First, it checks if the board dimensions (width and height) are greater than or equal to the n_in_row required to win the game. If not, it raises an exception.
It sets the current_player as the starting player based on the given index.
It initializes the availables list with all possible moves, represented as linear indices in the range from 0 to width * height - 1.
It initializes an empty dictionary states to store the board states, where the key represents the move (as a location on the board), and the value represents the player (as piece type).
It sets the last_move to -1, indicating that no move has been made yet.
    def move_to_location(self, move):
        """
        3*3 board's moves like:
        6 7 8
        3 4 5
        0 1 2
        and move 5's location is (1,2)
        """
        h = move // self.width
        w = move % self.width
        return [h, w]
The move_to_location method converts a given move index to a corresponding row and column position (location) on the board.
move: The move index to be converted into a location.
The method performs the following steps:
Calculates the row h by performing integer division (//) of the given move by the board width.
Calculates the column w by calculating the remainder (%) of the given move divided by the board width.
Returns the row h and column w as a list [h, w], representing the location on the board.
For example, on a 3x3 board, if the given move is 5, the location will be [1, 2].
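To see why move 5 lands at [1, 2], the short sketch below reproduces the move_to_location arithmetic with Python's built-in divmod and prints the move indices of a 3x3 board in the same bottom-up layout as the docstring. The helper names move_to_location and layout here are hypothetical standalone functions; only the rule row = move // width, column = move % width comes from the code above.

```python
def move_to_location(move, width):
    """Hypothetical standalone version of Board.move_to_location: returns [row, column]."""
    h, w = divmod(move, width)  # same as (move // width, move % width)
    return [h, w]


def layout(width, height):
    """Return the rows of move indices, top row first, as drawn in the docstring."""
    rows = [[h * width + w for w in range(width)] for h in range(height)]
    return rows[::-1]  # row 0 is drawn at the bottom


for row in layout(3, 3):
    print(" ".join(str(m) for m in row))
# prints:
# 6 7 8
# 3 4 5
# 0 1 2
print(move_to_location(5, 3))  # [1, 2]
```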
    def location_to_move(self, location):
        if len(location) != 2:
            return -1
        h = location[0]
        w = location[1]
        move = h * self.width + w
        if move not in range(self.width * self.height):
            return -1
        return move
The location_to_move method converts a given row and column position (location) on the board to a corresponding move index.
location: A list containing the row and column position [h, w] to be converted into a move index.
The method performs the following steps:
Checks if the length of the given location list is not equal to 2, then returns -1 as an invalid input.
Extracts the row h and column w from the location list.
Calculates the move index by multiplying the row h with the board width and adding the column w.
Checks if the calculated move index is not within the valid range of possible moves (0 to self.width * self.height - 1), then returns -1 as an invalid move.
Returns the calculated move index.
For example, on a 3x3 board, if the given location is [1, 2], the move index will be 5.
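Because location_to_move only checks that the flattened index is in range, an out-of-range column can slip through: on a 3x3 board, [0, 3] is accepted and mapped to move 3, which is actually location [1, 0]. The sketch below is a hypothetical stricter variant (strict_location_to_move is not part of the original code) that validates the row and column separately, plus a check that every move round-trips through the row/column rule used by move_to_location.

```python
def strict_location_to_move(location, width, height):
    """Hypothetical stricter variant: returns -1 unless 0 <= h < height and 0 <= w < width."""
    if len(location) != 2:
        return -1
    h, w = location
    if not (0 <= h < height and 0 <= w < width):
        return -1
    return h * width + w


def round_trip_ok(width, height):
    """Check that every move maps to a location and back to the same move."""
    for move in range(width * height):
        h, w = divmod(move, width)  # the move_to_location rule
        if strict_location_to_move([h, w], width, height) != move:
            return False
    return True


print(strict_location_to_move([1, 2], 3, 3))  # 5
print(strict_location_to_move([0, 3], 3, 3))  # -1 (the original would return 3)
print(round_trip_ok(11, 11))                  # True
```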
    def current_state(self):
        """return the board state from the perspective of the current player
        shape: 4*width*height"""
        square_state = np.zeros((4, self.width, self.height))
        if self.states:
            moves, players = np.array(list(zip(*self.states.items())))
            move_curr = moves[players == self.current_player]
            move_oppo = moves[players != self.current_player]
            # note: indexing the column with "% self.height" assumes a square board (width == height)
            square_state[0][move_curr // self.width, move_curr % self.height] = 1.0
            square_state[1][move_oppo // self.width, move_oppo % self.height] = 1.0
            square_state[2][self.last_move // self.width, self.last_move % self.height] = 1.0  # last move indication
        if len(self.states) % 2 == 0:
            square_state[3][:, :] = 1.0  # the player who moved first is to play
        return square_state[:, ::-1, :]
The current_state method returns the current board state from the perspective of the current player. The returned state is a 3D numpy array with a shape of (4, board width, board height).
The 3D array contains four 2D arrays, each representing different aspects of the board state:
The first 2D array (square_state[0]) represents the current player’s pieces on the board.
The second 2D array (square_state[1]) represents the opponent player’s pieces on the board.
The third 2D array (square_state[2]) indicates the last move made on the board.
The fourth 2D array (square_state[3]) is filled with 1.0 when an even number of stones is on the board, i.e. when it is the turn of the player who moved first in this game; otherwise, it remains filled with 0.0.
The method performs the following steps:
Initialize a zero-filled numpy array square_state with the shape (4, board width, board height).
Check if there are any states in the board. If so, proceed with the following steps.
Extract moves and players from the current board states.
Separate the current player’s moves and the opponent player’s moves.
Update the first 2D array of square_state with the current player’s moves.
Update the second 2D array of square_state with the opponent player’s moves.
Update the third 2D array of square_state with the last move.
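If the number of stones on the board is even, it is the first-moving player's turn, so fill the fourth 2D array of square_state with 1.0.
Return the square_state numpy array with axis 1 (the row axis) reversed, so that row 0 appears at the bottom, matching the layout shown in the move_to_location docstring. Note that the indexing assumes a square board, as used in this article.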
The returned state representation can be used as input for a neural network to make predictions or evaluate the current board state.
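The following standalone sketch rebuilds the four feature planes for a small board from a plain dictionary of moves, mirroring the steps listed above. The function encode_state and its arguments are hypothetical names for illustration, and like the original it assumes a square board.

```python
import numpy as np


def encode_state(states, current_player, last_move, size):
    """Hypothetical re-implementation of Board.current_state for a size x size board.

    states maps move index -> player (1 or 2).
    """
    planes = np.zeros((4, size, size))
    for move, player in states.items():
        h, w = divmod(move, size)
        if player == current_player:
            planes[0, h, w] = 1.0  # current player's stones
        else:
            planes[1, h, w] = 1.0  # opponent's stones
    if states:
        h, w = divmod(last_move, size)
        planes[2, h, w] = 1.0      # last move
    if len(states) % 2 == 0:
        planes[3, :, :] = 1.0      # the first-moving player is to play
    return planes[:, ::-1, :]      # flip rows so row 0 is at the bottom


# Player 1 played move 0, then player 2 played move 4; player 1 is to move.
state = encode_state({0: 1, 4: 2}, current_player=1, last_move=4, size=3)
print(state.shape)  # (4, 3, 3)
print(state[0])     # player 1's stone at move 0 appears in the bottom-left corner
```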
    def do_move(self, move):
        self.states[move] = self.current_player
        self.availables.remove(move)
        self.current_player = self.players[0] if self.current_player == self.players[1] else self.players[1]
        self.last_move = move
The do_move method updates the board state after a player makes a move. Here’s how it works:
Update the self.states dictionary by adding the move as a key and the current player as its value. This means placing the current player’s piece at the specified move location on the board.
Remove the move from the list of available moves (self.availables) because it has now been played.
Switch the current player to the other player. If the current player is player 1, set the current player to player 2, and vice versa.
Update self.last_move with the current move, which will now be considered as the last move made on the board.
This method is used to apply a move to the board and update the game state accordingly, including the current player, available moves, and the last move.
    def has_a_winner(self):
        width = self.width
        height = self.height
        states = self.states
        n = self.n_in_row
        moved = list(set(range(width * height)) - set(self.availables))
        if len(moved) < self.n_in_row * 2 - 1:
            return False, -1
        for m in moved:
            h = m // width
            w = m % width
            player = states[m]
            # row: columns w .. w + n - 1
            if (w in range(width - n + 1) and
                    len(set(states.get(i, -1) for i in range(m, m + n))) == 1):
                return True, player
            # column: step one row (width) at a time
            if (h in range(height - n + 1) and
                    len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1):
                return True, player
            # diagonal: step width + 1 (next row, next column)
            if (w in range(width - n + 1) and h in range(height - n + 1) and
                    len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1):
                return True, player
            # anti-diagonal: step width - 1 (next row, previous column)
            if (w in range(n - 1, width) and h in range(height - n + 1) and
                    len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1):
                return True, player
        return False, -1
The has_a_winner method checks whether there is a winner on the current board state.
It first gets a list of all the moves that have been made on the board. A winner needs n_in_row stones, and because the players alternate, the opponent has placed at least n_in_row - 1 stones by then; so if fewer than 2 * n_in_row - 1 moves have been made, it returns that there is no winner.
Otherwise, it checks whether there is a winner by iterating over all the moves made so far (moved). It gets the position of the move (h and w), and the player who made the move (player).
Then, starting from each move, it checks whether the n positions along the row (increasing column), the column (increasing row), the diagonal (increasing row and column), or the anti-diagonal (increasing row, decreasing column) are all occupied by the same player. The range checks on w and h ensure that a line does not wrap around the edge of the board. If this holds in any of the four directions, it returns True and the winning player.
If none of the four directions has a winner, it returns False and -1.
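The four directional checks can be reproduced outside the class. The sketch below is a hypothetical standalone function (find_winner is not part of the original code) that walks from every occupied square in the same four directions, but uses explicit row and column bounds checks in place of the original's range() membership tests and flattened index steps.

```python
def find_winner(states, width, height, n):
    """Hypothetical standalone win check.

    states maps move index (h * width + w) -> player. Returns the winning
    player, or -1 if nobody has n in a row.
    """
    if len(states) < 2 * n - 1:  # the winner needs n stones, the opponent has n - 1
        return -1
    # (dh, dw): row, column, diagonal, anti-diagonal
    directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
    for move, player in states.items():
        h, w = divmod(move, width)
        for dh, dw in directions:
            cells = [(h + k * dh, w + k * dw) for k in range(n)]
            if all(0 <= ch < height and 0 <= cw < width and
                   states.get(ch * width + cw) == player
                   for ch, cw in cells):
                return player
    return -1


# Player 1 on row 0, columns 0-4; player 2 on row 1, columns 0-3 (9 stones in total).
stones = {m: 1 for m in range(5)}
stones.update({m: 2 for m in range(11, 15)})
print(find_winner(stones, 11, 11, 5))  # 1
```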
    def game_end(self):
        """Check whether the game is ended or not"""
        win, winner = self.has_a_winner()
        if win:
            return True, winner
        elif not len(self.availables):
            return True, -1
        return False, -1
The game_end method checks if the game has ended, either because one of the players has won or because the board is full and there is no winner. It returns a boolean indicating whether the game has ended, and an integer representing the winning player. If there is no winner, the winning player is represented by -1.
    def get_current_player(self):
        return self.current_player
The get_current_player function is a getter method that returns the player who will make the next move in the game.
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.pixel_x = 30 + 30 * self.x
        self.pixel_y = 30 + 30 * self.y
This class represents a point on the game board. It has two attributes, x and y, which represent the coordinates of the point on the board, and two more, pixel_x and pixel_y, which give the point's position in pixels on the canvas. The pixel coordinates are computed from x and y with grid points spaced 30 pixels apart and a 30-pixel margin (pixel = 30 + 30 * index). The constructor takes the x and y coordinates as parameters and initializes all four attributes accordingly.
class Game(object):
    """
    game server
    """
    def __init__(self, board, **kwargs):
        self.board = board
This is the constructor of the Game class. It takes a board parameter, which is an instance of the Board class, and also accepts keyword arguments, which are currently unused.
    def click1(self, event):  # click1 because keyword repetition
        current_player = self.board.get_current_player()
        if current_player == 1:
            # snap the click to the nearest grid point (points sit at 30 + 30 * index pixels)
            i = (event.x) // 30
            j = (event.y) // 30
            ri = (event.x) % 30
            rj = (event.y) % 30
            i = i - 1 if ri < 15 else i
            j = j - 1 if rj < 15 else j
            move = self.board.location_to_move((i, j))
            if move in self.board.availables:
                p = self.chess_board_points[i][j]
                self.cv.create_oval(p.pixel_x - 10, p.pixel_y - 10, p.pixel_x + 10, p.pixel_y + 10, fill='black')
                self.board.do_move(move)
The click1 method handles mouse clicks on the game board. If the current player is player 1 (the code assumes the human is player 1), it snaps the click position to the nearest grid point and checks whether the corresponding move is available. If it is, the method draws a black stone at that point and updates the internal state of the board to reflect the new move.
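The snapping logic in click1 rounds a pixel coordinate to the nearest grid index, given that grid points sit at 30 + 30 * index pixels. The sketch below uses hypothetical helper names to compare the original integer arithmetic with a direct rounding formula over a range of non-negative pixel values.

```python
def snap_original(pixel):
    """The click1 arithmetic: integer-divide by 30, then step back if in the first half of a cell."""
    i = pixel // 30
    r = pixel % 30
    return i - 1 if r < 15 else i


def snap_rounded(pixel):
    """Nearest grid index for points at 30 + 30 * index, rounding halves up."""
    return (pixel - 30 + 15) // 30


print(snap_original(30), snap_original(44), snap_original(45), snap_original(70))  # 0 0 1 1
print(all(snap_original(p) == snap_rounded(p) for p in range(0, 400)))  # True
```

Both forms can return -1 for clicks in the left or top margin; in click1 that case is filtered out because location_to_move then produces an index that is not in board.availables.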
    def run(self):
        current_player = self.board.get_current_player()
        end, winner = self.board.game_end()
        if current_player == 2 and not end:
            player_in_turn = self.players[current_player]
            move = player_in_turn.get_action(self.board)
            self.board.do_move(move)
            i, j = self.board.move_to_location(move)
            p = self.chess_board_points[i][j]
            self.cv.create_oval(p.pixel_x - 10, p.pixel_y - 10, p.pixel_x + 10, p.pixel_y + 10, fill='white')
        end, winner = self.board.game_end()
        if end:
            if winner != -1:
                self.cv.create_text(self.board.width*15+15, self.board.height*30+30, text="Game over. Winner is {}".format(self.players[winner]))
                self.cv.unbind('<Button-1>')
            else:
                self.cv.create_text(self.board.width*15+15, self.board.height*30+30, text="Game end. Tie")
            return winner
        else:
            self.cv.after(100, self.run)
The run method drives the GUI game loop. It gets the current player and checks whether the game has ended; if it is player 2's turn (the code assumes player 2 is the computer), it asks that player for an action, applies the move to the board, and draws a white stone. It then checks again whether the game has ended. If so, it writes the winner, or a tie message, on the canvas and, when there is a winner, unbinds the mouse-click handler. Otherwise, it schedules itself to run again after 100 ms using the canvas's after method.
    def graphic(self, board, player1, player2):
        """
        Draw the board and show game info
        """
        width = board.width
        height = board.height
        p1, p2 = self.board.players
        player1.set_player_ind(p1)
        player2.set_player_ind(p2)
        self.players = {p1: player1, p2: player2}
        window = tkinter.Tk()
        self.cv = tkinter.Canvas(window, height=height*30+60, width=width*30 + 30, bg='white')
        self.chess_board_points = [[None for i in range(height)] for j in range(width)]
        for i in range(width):
            for j in range(height):
                self.chess_board_points[i][j] = Point(i, j)
        for i in range(width):  # vertical lines: from the first to the last point of column i
            self.cv.create_line(self.chess_board_points[i][0].pixel_x, self.chess_board_points[i][0].pixel_y, self.chess_board_points[i][height-1].pixel_x, self.chess_board_points[i][height-1].pixel_y)
        for j in range(height):  # horizontal lines: from the first to the last point of row j
            self.cv.create_line(self.chess_board_points[0][j].pixel_x, self.chess_board_points[0][j].pixel_y, self.chess_board_points[width-1][j].pixel_x, self.chess_board_points[width-1][j].pixel_y)
        self.button = tkinter.Button(window, text="start game!", command=self.run)
        self.cv.bind('<Button-1>', self.click1)
        self.cv.pack()
        self.button.pack()
        window.mainloop()
The graphic method builds the tkinter GUI for the Gomoku game. It takes the board and the two player objects as arguments, assigns each player its index (1 or 2), and creates a window with a canvas on which it draws the grid lines of the board. It then adds a single “start game!” button that launches the run loop, binds left mouse clicks on the canvas to click1 so the human can place stones, and finally calls window.mainloop() to keep the window open and process events until it is closed.
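The grid drawn by graphic follows directly from the Point spacing: the vertical line for column i runs from the first to the last row, and the horizontal line for row j from the first to the last column. The hypothetical helper below computes those endpoints without tkinter, which makes it easy to check that the last index is height - 1 for vertical lines and width - 1 for horizontal lines (the two coincide only on square boards).

```python
def grid_lines(width, height, spacing=30, margin=30):
    """Hypothetical helper: (x0, y0, x1, y1) for every grid line, in Canvas.create_line order."""
    def px(index):
        return margin + spacing * index  # same formula as Point.pixel_x / pixel_y

    vertical = [(px(i), px(0), px(i), px(height - 1)) for i in range(width)]
    horizontal = [(px(0), px(j), px(width - 1), px(j)) for j in range(height)]
    return vertical, horizontal


vertical, horizontal = grid_lines(11, 11)
print(len(vertical), len(horizontal))  # 11 11
print(vertical[0])                     # (30, 30, 30, 330)
```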
    def start_play(self, player1, player2, start_player=0, is_shown=1):
        """
        start a game between two players
        """
        if start_player not in (0, 1):
            raise Exception('start_player should be 0 (player1 first) or 1 (player2 first)')
        self.board.init_board(start_player)
        if is_shown:
            # GUI mode: the tkinter event loop drives the game
            self.graphic(self.board, player1, player2)
        else:
            # headless mode: players alternate until the game ends
            p1, p2 = self.board.players
            player1.set_player_ind(p1)
            player2.set_player_ind(p2)
            players = {p1: player1, p2: player2}
            while True:
                current_player = self.board.get_current_player()
                print(current_player)
                player_in_turn = players[current_player]
                move = player_in_turn.get_action(self.board)
                self.board.do_move(move)
                end, winner = self.board.game_end()
                if end:
                    return winner
The start_play method starts a game between two player objects, player1 and player2 (for example, a Human and an MCTSPlayer); each must provide set_player_ind and get_action. The start_player argument (0 or 1) selects which of them moves first. If is_shown is truthy, it calls graphic to display the board and hands control to the tkinter event loop; otherwise, it plays the game in the console without visualization, alternating moves until the game ends, and returns the winner (1 or 2, or -1 for a tie).
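start_play relies on only two methods of each player object: set_player_ind and get_action(board). As an illustration, here is a hypothetical RandomPlayer (not part of the original code) that satisfies this interface by choosing uniformly from board.availables, together with a minimal stand-in board so that the sketch runs on its own.

```python
import random


class RandomPlayer(object):
    """Hypothetical player that chooses a random legal move."""

    def __init__(self, seed=None):
        self.player = None
        self.rng = random.Random(seed)

    def set_player_ind(self, p):
        self.player = p

    def get_action(self, board):
        return self.rng.choice(board.availables)

    def __str__(self):
        return "Random {}".format(self.player)


class StubBoard(object):
    """Minimal stand-in exposing only the attribute RandomPlayer reads."""

    def __init__(self, availables):
        self.availables = availables


player = RandomPlayer(seed=0)
player.set_player_ind(2)
move = player.get_action(StubBoard([3, 7, 42]))
print(player, move in (3, 7, 42))  # Random 2 True
```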
    def start_self_play(self, player, is_shown=0, temp=1e-3):
        """ start a self-play game using a MCTS player, reuse the search tree
        store the self-play data: (state, mcts_probs, z)
        """
        self.board.init_board()
        p1, p2 = self.board.players
        states, mcts_probs, current_players = [], [], []
        while True:
            move, move_probs = player.get_action(self.board, temp=temp, return_prob=1)
            # store the data
            states.append(self.board.current_state())
            mcts_probs.append(move_probs)
            current_players.append(self.board.current_player)
            # perform a move
            self.board.do_move(move)
            end, winner = self.board.game_end()
            if end:
                # winner from the perspective of the current player of each state
                winners_z = np.zeros(len(current_players))
                if winner != -1:
                    winners_z[np.array(current_players) == winner] = 1.0
                    winners_z[np.array(current_players) != winner] = -1.0
                # reset MCTS root node
                player.reset_player()
                if is_shown:
                    if winner != -1:
                        print("Game end. Winner is player:", winner)
                    else:
                        print("Game end. Tie")
                return winner, zip(states, mcts_probs, winners_z)
This code defines the start_self_play method of the Game class, which plays a game in which an MCTS agent, guided by the policy-value network, plays against itself. The game data, consisting of the state, the MCTS move probabilities, and the outcome, is stored for use in training the neural network.
The method initializes the game board, then loops through the game, with each iteration representing one move. For each move, the agent's get_action method runs a guided search of the game tree and, because return_prob=1, returns both the chosen move and the move probabilities derived from the search; temp is passed through as the temperature for those probabilities. The current state, the move probabilities, and the player to move are recorded before the move is applied, and the loop continues until the game is over.
After the game is over, the outcome is stored in winners_z, a 1D array with one entry per recorded state: 1.0 if the player to move in that state went on to win, -1.0 if that player lost, and 0.0 for every state if the game was a tie. Finally, the agent's search tree is reset, and the method returns the winner together with an iterator of (state, mcts_probs, z) tuples. If is_shown is truthy, the outcome of the game is printed to the console.
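The outcome labels can be checked in isolation. This sketch repeats the winners_z computation from start_self_play inside a hypothetical helper, outcome_labels, using a made-up five-move game in which the players alternated 1, 2, 1, 2, 1 and player 1 won.

```python
import numpy as np


def outcome_labels(current_players, winner):
    """Same logic as start_self_play: z = +1 on the winner's turns, -1 on the loser's, 0 for a tie."""
    winners_z = np.zeros(len(current_players))
    if winner != -1:
        players = np.array(current_players)
        winners_z[players == winner] = 1.0
        winners_z[players != winner] = -1.0
    return winners_z


print(outcome_labels([1, 2, 1, 2, 1], winner=1).tolist())  # [1.0, -1.0, 1.0, -1.0, 1.0]
print(outcome_labels([1, 2, 1, 2], winner=-1).tolist())    # [0.0, 0.0, 0.0, 0.0]
```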
Human_play.py
from game import Board, Game
from tf_policy_value_net import PolicyValueNet
from mcts_alphaZero import MCTSPlayer
class Human(object):
    """
    human player
    """
    def __init__(self):
        self.player = None
This code defines the __init__ method of the Human class, which represents a human player and is used for playing the game manually. The method sets the player attribute to None; it is later set to the player's index (1 or 2) through set_player_ind when the game starts.
    def set_player_ind(self, p):
        self.player = p
The set_player_ind method sets the player’s index to p. This is used to specify whether a player is player 1 or player 2 in the game.
    def get_action(self, board):
        try:
            location = input("Your move: ")
            if isinstance(location, str):
                location = [int(n, 10) for n in location.split(",")]  # for python3
            move = board.location_to_move(location)
        except Exception:
            move = -1
        if move == -1 or move not in board.availables:
            print("invalid move")
            move = self.get_action(board)
        return move
This code defines the get_action method for the Human class, which prompts the user for a move in the console and returns it as a move index. The user types a location in the format “row,column” (for example, 2,3), which is split on the comma and converted to a list of integers with a list comprehension. The location is then converted to a move with the location_to_move method of the board. If the input cannot be parsed or the resulting move is not available on the board, the method prints “invalid move” and calls itself recursively until a valid move is entered. This console input is only used when the game runs without the GUI; in GUI mode, the human places stones by clicking on the canvas.
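Because get_action retries by calling itself, a long run of invalid inputs keeps growing the call stack. A loop-based alternative is sketched below. parse_move and ask_move are hypothetical names, the input function is injected so the sketch can be tested without a keyboard, and, as in the original, the text is parsed as “row,column”.

```python
def parse_move(text, width, height, availables):
    """Hypothetical parser: 'row,column' -> move index, or -1 if invalid or occupied."""
    try:
        h, w = [int(n, 10) for n in text.split(",")]
    except ValueError:  # wrong number of fields or non-integer text
        return -1
    if not (0 <= h < height and 0 <= w < width):
        return -1
    move = h * width + w
    return move if move in availables else -1


def ask_move(width, height, availables, read=input):
    """Keep prompting until a legal move is entered (a loop instead of recursion)."""
    while True:
        move = parse_move(read("Your move: "), width, height, availables)
        if move != -1:
            return move
        print("invalid move")


replies = iter(["abc", "2,3"])
print(ask_move(11, 11, set(range(121)), read=lambda prompt: next(replies)))  # 25
```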
    def __str__(self):
        return "Human {}".format(self.player)
This method returns a string representation of a Human object, which includes the player’s index number (e.g. “Human 1” or “Human 2”).
def run():
    n_row = 5
    width, height = 11, 11
    try:
        board = Board(width=width, height=height, n_in_row=n_row)
        game = Game(board)
        ################ human VS AI ###################
        best_policy = PolicyValueNet(width, height, n_row)
        mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=400)  # set larger n_playout for better performance
        human = Human()
        # set start_player=0 for human first
        game.start_play(human, mcts_player, start_player=1, is_shown=1)
    except KeyboardInterrupt:
        print('\n\rquit')
This code sets up a game of Gomoku on an 11x11 board in which a player must connect 5 stones in a row to win. A human plays against an AI that uses Monte Carlo Tree Search guided by the policy-value network: the AI player is created by initializing a PolicyValueNet and passing its policy_value_fn to an MCTSPlayer along with c_puct and n_playout values (according to the code comment, a larger n_playout gives better play, at the cost of more search time). The human player is created using the Human class. The game is then started with the GUI enabled (is_shown=1); because start_player=1, player 2 (the AI) moves first, and the human can move first instead by setting start_player=0. A keyboard interrupt (Ctrl+C) quits the program.
if __name__ == '__main__':
    run()
The __name__ variable in Python holds the name of the current module. When the script is executed directly, the __name__ variable is set to ‘__main__’. Therefore, this code block is checking if the script is being executed directly or being imported as a module. If it’s being executed directly, it calls the run() function, which starts a game of Gomoku between a human player and an AI player that uses Monte Carlo Tree Search (MCTS) to choose its moves.
tf_policy_value_net.py
import tensorflow as tf  # written against the TensorFlow 1.x API (tf.Session, tf.placeholder, tf.layers)
import os
class PolicyValueNet():
    """policy-value network """
    def __init__(self, board_width, board_height, n_in_row):
        tf.reset_default_graph()
        self.board_width = board_width
        self.board_height = board_height
        self.model_file = './model/tf_policy_{}_{}_{}_model'.format(board_width, board_height, n_in_row)
        self.sess = tf.Session()
        self.l2_const = 1e-4  # coef of l2 penalty
        self._create_policy_value_net()
        self._loss_train_op()
        self.saver = tf.train.Saver()
        self.restore_model()
The __init__ method initializes the PolicyValueNet class, which is responsible for creating and training the policy-value network for the game. Here’s an overview of what the method does:
Reset the TensorFlow default graph to ensure no interference with previously created graphs.
Store the board width and height, and build the model file path, which encodes the width, height, and n_in_row (n_in_row itself is not stored on the object).
Create a TensorFlow session.
Set the L2 regularization coefficient.
Call the _create_policy_value_net() method to create the policy-value network architecture.
Call the _loss_train_op() method to define the loss function and training operation.
Create a TensorFlow saver object to save and restore the model.
Call the restore_model() method to load the pre-trained model if available, or initialize the model with random weights if not.
    def _create_policy_value_net(self):
        """create the policy value network """
        with tf.name_scope("inputs"):
            self.state_input = tf.placeholder(tf.float32, shape=[None, 4, self.board_width, self.board_height], name="state")
            # self.state = tf.transpose(self.state_input, [0, 2, 3, 1])
            self.winner = tf.placeholder(tf.float32, shape=[None], name="winner")
            self.winner_reshape = tf.reshape(self.winner, [-1, 1])
            self.mcts_probs = tf.placeholder(tf.float32, shape=[None, self.board_width*self.board_height], name="mcts_probs")
        # conv layers
        conv1 = tf.layers.conv2d(self.state_input, filters=32, kernel_size=3,
                                 strides=1, padding="SAME", data_format='channels_first',
                                 activation=tf.nn.relu, name="conv1")
        conv2 = tf.layers.conv2d(conv1, filters=64, kernel_size=3,
                                 strides=1, padding="SAME", data_format='channels_first',
                                 activation=tf.nn.relu, name="conv2")
        conv3 = tf.layers.conv2d(conv2, filters=128, kernel_size=3,
                                 strides=1, padding="SAME", data_format='channels_first',
                                 activation=tf.nn.relu, name="conv3")
        # action policy layers
        policy_net = tf.layers.conv2d(conv3, filters=4, kernel_size=1,
                                      strides=1, padding="SAME", data_format='channels_first',
                                      activation=tf.nn.relu, name="policy_net")
        policy_net_flat = tf.reshape(policy_net, shape=[-1, 4*self.board_width*self.board_height])
        self.policy_net_out = tf.layers.dense(policy_net_flat, self.board_width*self.board_height, name="output")
        self.action_probs = tf.nn.softmax(self.policy_net_out, name="policy_net_proba")
        # state value layers
        value_net = tf.layers.conv2d(conv3, filters=2, kernel_size=1, data_format='channels_first',
                                     name='value_conv', activation=tf.nn.relu)
        value_net = tf.layers.dense(tf.contrib.layers.flatten(value_net), 64, activation=tf.nn.relu)
        self.value = tf.layers.dense(value_net, units=1, activation=tf.nn.tanh)
The _create_policy_value_net method defines the architecture of the policy-value network using TensorFlow. The method proceeds as follows:
Define input placeholders for the state, winner, and MCTS probabilities.
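Create three convolutional layers (conv1, conv2, and conv3) with 32, 64, and 128 filters, respectively. These layers use a 3x3 kernel, a stride of 1, “SAME” padding, and ReLU activations. The input uses the channels_first (NCHW) layout, which TensorFlow 1.x convolutions generally support only on GPU; the commented-out tf.transpose line shows how the input could be converted to channels_last instead.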
Define the action policy layers:
a. Create a policy_net convolutional layer with 4 filters, a 1x1 kernel size, and ReLU activation.
b. Flatten the policy_net output and connect it to a dense layer with board_width * board_height units, one logit per board position.
c. Apply the softmax function to the policy_net output to get action probabilities.
Define the state value layers:
a. Create a value_net convolutional layer with 2 filters and a 1x1 kernel size, using ReLU activation.
b. Flatten the value_net and connect it to a dense layer with 64 units and a ReLU activation function.
c. Connect the previous layer to another dense layer with a single output unit and a tanh activation function to obtain the value prediction.
This architecture combines the policy and value networks into a single model, which outputs both action probabilities and state value predictions.
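As a worked example of the size of this architecture, the sketch below counts trainable parameters layer by layer for the 11x11 board used in this article. It assumes the standard counting for tf.layers (a convolution has in_channels * filters * k * k weights plus one bias per filter; a dense layer has inputs * units weights plus one bias per unit), and the helper names are hypothetical.

```python
def conv_params(in_ch, filters, k):
    return in_ch * filters * k * k + filters  # weights + one bias per filter


def dense_params(inputs, units):
    return inputs * units + units  # weights + one bias per unit


def count_params(width, height):
    cells = width * height
    layers = {
        "conv1": conv_params(4, 32, 3),
        "conv2": conv_params(32, 64, 3),
        "conv3": conv_params(64, 128, 3),
        "policy_conv": conv_params(128, 4, 1),
        "policy_dense": dense_params(4 * cells, cells),
        "value_conv": conv_params(128, 2, 1),
        "value_dense": dense_params(2 * cells, 64),
        "value_out": dense_params(64, 1),
    }
    return layers, sum(layers.values())


layers, total = count_params(11, 11)
for name, n in layers.items():
    print(name, n)
print("total", total)  # total 168612
```

For this board, conv3 (73,856 parameters) and the dense layer of the policy head (58,685) together account for most of the network.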
    def _loss_train_op(self):
        """
        Three loss terms:
        loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
        """
        l2_penalty = 0
        for v in tf.trainable_variables():
            if 'bias' not in v.name.lower():
                l2_penalty += tf.nn.l2_loss(v)
        value_loss = tf.reduce_mean(tf.square(self.winner_reshape - self.value))
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.policy_net_out, labels=self.mcts_probs)
        policy_loss = tf.reduce_mean(cross_entropy)
        self.loss = value_loss + policy_loss + self.l2_const*l2_penalty
        # policy cross-entropy, reused as a monitoring metric only
        self.entropy = policy_loss
        # get the train op
        self.learning_rate = tf.placeholder(tf.float32)
        optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        self.training_op = optimizer.minimize(self.loss)
The _loss_train_op method defines the loss function and the training operation for the policy-value network. The loss function consists of three components:
Value loss: The mean squared error between the predicted value self.value and the actual game outcome z, i.e., (z - v)².
Policy loss: The cross-entropy between the MCTS probabilities self.mcts_probs and the policy of the network (computed from the logits self.policy_net_out), i.e., -pi^T * log(p).
L2 regularization: The L2 penalty on the weights of the network (excluding biases), scaled by the constant self.l2_const, i.e., c||theta||². Because tf.nn.l2_loss computes half the sum of squares, the effective coefficient is c/2.
The total loss is the sum of these three components. Additionally, the policy loss is stored in self.entropy for monitoring; strictly speaking, this is the cross-entropy between the search probabilities and the policy of the network, not the entropy of the network's own policy.
The method also defines a placeholder for the learning rate and creates an Adam optimizer with the specified learning rate. Finally, the training operation is defined as the minimization of the total loss using the Adam optimizer.
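To make the three loss terms concrete, here is a NumPy sketch that evaluates them for a single made-up training example on a three-move board. It mirrors the TensorFlow graph (squared value error, softmax cross-entropy on the policy logits, and the half sum of squares of tf.nn.l2_loss scaled by l2_const) and, for comparison, also computes the entropy of the network's own policy, which is the quantity the name self.entropy suggests. The function names and all numbers are hypothetical.

```python
import numpy as np


def softmax(logits):
    e = np.exp(logits - np.max(logits))
    return e / e.sum()


def alphazero_loss(z, v, pi, logits, weights, l2_const=1e-4):
    """Hypothetical NumPy version of the loss built in _loss_train_op (one example)."""
    p = softmax(logits)
    value_loss = (z - v) ** 2
    policy_loss = -np.sum(pi * np.log(p))           # cross-entropy H(pi, p)
    l2 = sum(np.sum(w ** 2) / 2 for w in weights)   # tf.nn.l2_loss is sum(w**2) / 2
    return value_loss + policy_loss + l2_const * l2, policy_loss


def policy_entropy(logits):
    p = softmax(logits)
    return -np.sum(p * np.log(p))                   # H(p): entropy of the network's policy


pi = np.array([0.7, 0.2, 0.1])      # MCTS visit distribution
logits = np.array([2.0, 1.0, 0.0])  # network policy logits
weights = [np.ones((2, 2))]         # stand-in weight tensor
loss, cross_entropy = alphazero_loss(z=1.0, v=0.5, pi=pi, logits=logits, weights=weights)
print(round(loss, 4), round(cross_entropy, 4), round(policy_entropy(logits), 4))
```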
    def get_policy_value(self, state_batch):
        # get action probs and state score value
        action_probs, value = self.sess.run([self.action_probs, self.value],
                                            feed_dict={self.state_input: state_batch})
        return action_probs, value
The get_policy_value method takes a batch of input states (state_batch) and returns the action probabilities and state value predictions for each state. It does this by feeding the input states into the TensorFlow session and executing the self.action_probs and self.value tensors, which represent the predicted action probabilities and state values, respectively. The method returns these predictions as two separate arrays: action_probs and value.
    def policy_value_fn(self, board):
        """
        input: board
        output: a list of (action, probability) tuples for each available action and the score of the board state
        """
        legal_positions = board.availables
        current_state = board.current_state()
        act_probs, value = self.sess.run([self.action_probs, self.value],
                                         feed_dict={self.state_input: current_state.reshape(-1, 4, self.board_width, self.board_height)})
        act_probs = zip(legal_positions, act_probs.flatten()[legal_positions])
        return act_probs, value[0][0]
The policy_value_fn method takes a board as input and returns (action, probability) pairs for every available action together with the value of the board state. It first retrieves the legal positions and the current state from the board, reshapes the state into a batch of one, and runs the self.action_probs and self.value tensors in the TensorFlow session. It then pairs each legal position with its probability and returns these pairs along with the scalar value. Note that in Python 3 zip returns a one-shot iterator rather than a list, and that the legal-move probabilities are not renormalized, so they sum to at most 1.
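The masking step in policy_value_fn can be illustrated with NumPy alone. In this sketch, legal_priors is a hypothetical helper, the probability vector is made up, and the optional renormalization is an assumption added for illustration: the original passes the unnormalized legal-move probabilities straight on.

```python
import numpy as np


def legal_priors(action_probs, legal_positions, renormalize=False):
    """Pair each legal move with its network probability, as policy_value_fn does."""
    probs = action_probs.flatten()[legal_positions]  # fancy indexing picks the legal entries
    if renormalize:                                  # hypothetical extra step, not in the original
        probs = probs / probs.sum()
    return list(zip(legal_positions, probs))         # list() so the pairs can be reused


action_probs = np.array([[0.1, 0.4, 0.2, 0.3]])  # network output for a 2x2 board (batch of 1)
legal = [0, 2, 3]                                 # move 1 is already occupied
print(legal_priors(action_probs, legal))
print(legal_priors(action_probs, legal, renormalize=True))
```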
    def train_step(self, state_batch, mcts_probs_batch, winner_batch, lr):
        feed_dict = {self.state_input: state_batch,
                     self.mcts_probs: mcts_probs_batch,
                     self.winner: winner_batch,
                     self.learning_rate: lr}
        loss, entropy, _ = self.sess.run([self.loss, self.entropy, self.training_op],
                                         feed_dict=feed_dict)
        return loss, entropy
The train_step method takes a batch of input states, MCTS probabilities, winner labels, and a learning rate, and performs a single training step on the policy-value network. The method creates a feed dictionary containing the provided inputs and runs self.loss, self.entropy, and self.training_op in the TensorFlow session. The method returns the computed loss and entropy values.
    def restore_model(self):
        if os.path.exists(self.model_file + '.meta'):
            self.saver.restore(self.sess, self.model_file)
        else:
            self.sess.run(tf.global_variables_initializer())
The restore_model method checks whether a saved checkpoint exists by looking for the .meta file of the model. If it does, the saver restores the saved weights into the session; otherwise, all variables are initialized with tf.global_variables_initializer(), so the network starts from freshly initialized weights.
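train_step expects three parallel batches. One common way to produce them from self-play data is to keep (state, mcts_probs, z) tuples in a bounded buffer and sample a random minibatch. The sketch below shows that step with a hypothetical buffer size, batch size, helper name (sample_batch), and random stand-in data, and checks that the resulting arrays have the shapes the placeholders declare for an 11x11 board.

```python
import random
from collections import deque

import numpy as np

WIDTH = HEIGHT = 11
data_buffer = deque(maxlen=10000)  # hypothetical buffer size; the oldest samples drop out first

# Stand-in self-play data: (state, mcts_probs, z) tuples like those returned by start_self_play.
rng = np.random.default_rng(0)
for _ in range(600):
    state = rng.integers(0, 2, size=(4, WIDTH, HEIGHT)).astype(np.float64)
    probs = rng.random(WIDTH * HEIGHT)
    probs /= probs.sum()
    z = rng.choice([-1.0, 0.0, 1.0])
    data_buffer.append((state, probs, z))


def sample_batch(buffer, batch_size=512):
    """Draw a random minibatch and split it into the three arrays train_step takes."""
    mini_batch = random.sample(list(buffer), batch_size)
    state_batch = np.array([d[0] for d in mini_batch])
    mcts_probs_batch = np.array([d[1] for d in mini_batch])
    winner_batch = np.array([d[2] for d in mini_batch])
    return state_batch, mcts_probs_batch, winner_batch


states, probs, winners = sample_batch(data_buffer)
print(states.shape, probs.shape, winners.shape)  # (512, 4, 11, 11) (512, 121) (512,)
```

With a PolicyValueNet instance, these arrays would then be passed as train_step(states, probs, winners, lr), with whatever learning rate the training pipeline chooses.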
mcts_alphaZero.py
import numpy as np
import copy
def softmax(x):
    probs = np.exp(x - np.max(x))
    probs /= np.sum(probs)
    return probs
The softmax function is a mathematical function used to convert a vector of arbitrary real values into a vector of probabilities. The output vector has the same dimension as the input vector, and each element represents the probability of the corresponding input value. The function subtracts the maximum element of the input vector from all elements to avoid numerical instability and then exponentiates and normalizes the values. It is commonly used as an activation function in machine learning models, such as neural networks, to generate probabilities of different classes.
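The subtraction of the maximum matters in practice. The sketch below compares a naive softmax (a hypothetical naive_softmax, for contrast only) with the stable version above on large inputs: the naive one overflows to inf and produces NaN, while the stable one is unaffected, because adding the same constant to every input does not change the softmax.

```python
import numpy as np


def softmax(x):
    probs = np.exp(x - np.max(x))  # shift so that the largest exponent is exp(0) = 1
    probs /= np.sum(probs)
    return probs


def naive_softmax(x):
    e = np.exp(x)  # overflows for large x
    return e / np.sum(e)


x = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):  # silence the expected overflow warnings
    print(naive_softmax(x))                          # [nan nan nan]
print(softmax(x))                                    # same as softmax([0, 1, 2])
print(np.allclose(softmax(x), softmax(x - 1000.0)))  # True
```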
class TreeNode(object):
    """A node in the MCTS tree. Each node keeps track of its own value Q, prior probability P, and
    its visit-count-adjusted prior score u.
    """
    def __init__(self, parent, prior_p):
        self._parent = parent
        self._children = {}  # a map from action to TreeNode
        self._n_visits = 0
        self._Q = 0
        self._u = 0
        self._P = prior_p
This code defines the TreeNode class, which represents a node in the Monte Carlo Tree Search (MCTS) tree. The constructor stores a reference to the parent node, creates an empty dictionary of children (mapping actions to TreeNode objects), and initializes four statistics: _n_visits (the number of times the node has been visited during the search), _Q (the mean value of all evaluations backed up through this node, initially 0), _u (the exploration bonus, which favors less-visited actions), and _P (the prior probability of the action leading to this node, as supplied by the policy function). _parent and _children form the tree structure, while _n_visits, _Q, _u, and _P are combined to score the node when the search decides which child to explore next.
    def expand(self, action_priors):
        """Expand tree by creating new children.
        action_priors -- output from policy function - a list of tuples of actions
            and their prior probability according to the policy function.
        """
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)
The expand method grows the search tree by creating new child nodes. It takes a list of (action, prior probability) tuples, the output of the policy function for the current state. For each tuple, it checks whether a child for that action already exists; if not, it creates a new TreeNode with the current node as its parent and the given probability as its prior, and stores it under that action.
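To make the `if action not in self._children` guard concrete, here is a trimmed, self-contained copy of TreeNode containing only the constructor and expand (indentation restored). The actions, priors, and the manually set visit count are made up for illustration.

```python
class TreeNode(object):
    """Trimmed copy of the article's node: only what expand() needs."""

    def __init__(self, parent, prior_p):
        self._parent = parent
        self._children = {}  # action -> TreeNode
        self._n_visits = 0
        self._Q = 0
        self._u = 0
        self._P = prior_p

    def expand(self, action_priors):
        for action, prob in action_priors:
            # Existing children (and their statistics) are never overwritten.
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)


root = TreeNode(None, 1.0)
root.expand([(0, 0.6), (1, 0.4)])
root._children[0]._n_visits = 5  # pretend child 0 was already searched
# A second expansion with different priors only adds the new action 2.
root.expand([(0, 0.1), (1, 0.2), (2, 0.7)])
```

A second call with different priors adds the new action 2 but leaves the existing children, and any statistics they have accumulated, untouched.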
    def select(self, c_puct):
        """Select action among children that gives maximum action value, Q plus bonus u(P).
        Returns:
        A tuple of (action, next_node)
        """
        return max(self._children.items(), key=lambda act_node: act_node[1].get_value(c_puct))
The select method returns the child with the highest score among all children of the current node, where the score is Q + u as computed by get_value. This is a form of the PUCT rule used by AlphaZero, a variant of UCT in which the exploration term is weighted by the prior probability P. Its single parameter, c_puct, is a constant that balances exploration against exploitation. The method returns a tuple containing the action that leads to the selected child and the child node itself.
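The sketch below applies the same max-over-children pattern to two hand-made children to show how c_puct shifts the choice. Child and puct are hypothetical stand-ins for TreeNode and get_value, and the priors, visit counts, and Q values are invented.

```python
import math


class Child(object):
    """Hypothetical stand-in holding the statistics get_value() reads."""

    def __init__(self, prior, n_visits, q):
        self.P, self.n_visits, self.Q = prior, n_visits, q


def puct(child, parent_visits, c_puct):
    # Q + c_puct * P * sqrt(N_parent) / (1 + N_child), as in get_value()
    return child.Q + c_puct * child.P * math.sqrt(parent_visits) / (1 + child.n_visits)


def select(children, parent_visits, c_puct):
    # Same pattern as TreeNode.select: max over (action, node) pairs by score.
    return max(children.items(),
               key=lambda act_node: puct(act_node[1], parent_visits, c_puct))


children = {
    "a": Child(prior=0.2, n_visits=10, q=0.5),  # well explored, good value
    "b": Child(prior=0.7, n_visits=0, q=0.0),   # unexplored, high prior
}
action_low, _ = select(children, parent_visits=10, c_puct=0.1)
action_high, _ = select(children, parent_visits=10, c_puct=5.0)
```

With a parent visit count of 10, a small c_puct (0.1) lets the well-explored child "a" win on its higher Q (about 0.506 vs 0.221), while a large c_puct (5.0) favors the unexplored, high-prior child "b" (about 11.07 vs 0.79).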
    def update(self, leaf_value):
        """Update node values from leaf evaluation.
        """
        # Count visit.
        self._n_visits += 1
        # Update Q, a running average of values for all visits.
        self._Q += 1.0*(leaf_value - self._Q) / self._n_visits
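The update method first increments the node's visit count by 1 and then updates Q incrementally with Q += (leaf_value - Q) / n_visits. After n visits, Q is therefore the arithmetic mean of all n leaf values backed up through this node, not merely the average of the newest value and the previous Q.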
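The following sketch replays the update rule outside the class to confirm that the incremental formula yields the plain average of every value seen. running_mean is a hypothetical helper, and the leaf values are arbitrary.

```python
def running_mean(values):
    """Replays TreeNode.update(): increment n, then Q += (v - Q) / n."""
    n_visits, q = 0, 0.0
    for v in values:
        n_visits += 1
        q += 1.0 * (v - q) / n_visits
    return n_visits, q


leaf_values = [1.0, -1.0, 0.5, 0.5]
n, q = running_mean(leaf_values)
```

For these four values the final Q matches the mean, 0.25, up to floating-point rounding; the incremental form gives the same average without storing the history of values.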
    def update_recursive(self, leaf_value):
        """Like a call to update(), but applied recursively for all ancestors.
        """
        # If it is not root, this node's parent should be updated first.
        if self._parent:
            self._parent.update_recursive(-leaf_value)
        self.update(leaf_value)
The update_recursive method applies update to the current node and all of its ancestors. If the node has a parent, it first calls the parent's update_recursive with the negated leaf_value, and only then updates itself with the original leaf_value. The recursion therefore reaches the root first, and the updates are applied from the root back down to the current node. The sign flips at each level because consecutive levels of the tree correspond to moves by alternating players, so a value that is good for one player is equally bad for the other. This is how the evaluation of a leaf position is backed up along the search path.
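The sign alternation is easiest to see on a three-node path. The sketch below is a trimmed copy of TreeNode with only the fields and methods that update_recursive uses (children and priors are omitted), and the leaf value of +1.0 is an arbitrary example.

```python
class TreeNode(object):
    """Trimmed copy with just the bookkeeping update_recursive() touches."""

    def __init__(self, parent):
        self._parent = parent
        self._n_visits = 0
        self._Q = 0.0

    def update(self, leaf_value):
        self._n_visits += 1
        self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits

    def update_recursive(self, leaf_value):
        # Parent first, with the sign flipped for the other player.
        if self._parent:
            self._parent.update_recursive(-leaf_value)
        self.update(leaf_value)


root = TreeNode(None)
child = TreeNode(root)
grandchild = TreeNode(child)
grandchild.update_recursive(1.0)
```

After the call, root and grandchild each hold Q = 1.0, child holds Q = -1.0, and every node has one visit: the value alternates in sign at each level because each level belongs to the opposing player.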
    def get_value(self, c_puct):
        """Calculate and return the value for this node: a combination of leaf evaluations, Q, and
        this node's prior adjusted for its visit count, u
        c_puct -- a number in (0, inf) controlling the relative impact of values, Q, and
            prior probability, P, on this node's score.
        """
        self._u = c_puct * self._P * np.sqrt(self._parent._n_visits) / (1 + self._n_visits)
        return self._Q + self._u
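The get_value method returns the score used by select: Q + u, where Q is the node's mean value and u is an exploration bonus computed as u = c_puct * P * sqrt(N_parent) / (1 + N), with P the node's prior, N_parent the parent's visit count, and N the node's own visit count. The bonus is large for moves with a high prior and few visits, and it shrinks as the node is visited more. The parameter c_puct, a positive constant, controls the trade-off between exploration and exploitation: larger values weight the prior-driven bonus more heavily, while smaller values make the search rely more on the observed values Q. Because the formula reads the parent's visit count, get_value is only meaningful for non-root nodes, which is how select uses it.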
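A small worked example of the exploration bonus, using hypothetical numbers (c_puct = 5, P = 0.2, 100 parent visits). exploration_bonus is a stand-in for the u term of get_value, evaluated as the child's own visit count grows.

```python
import math


def exploration_bonus(c_puct, prior, parent_visits, child_visits):
    # u = c_puct * P * sqrt(N_parent) / (1 + N_child), as in get_value()
    return c_puct * prior * math.sqrt(parent_visits) / (1 + child_visits)


# c_puct = 5, P = 0.2, 100 parent visits, child visited 0, 1, 9 and 99 times.
bonuses = [exploration_bonus(5, 0.2, 100, n) for n in (0, 1, 9, 99)]
```

The bonus drops from 10.0 for an unvisited child to 5.0, 1.0, and 0.1 after 1, 9, and 99 visits, so a move's prior matters most early in the search and its observed Q takes over as visits accumulate.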
    def is_leaf(self):
        """Check if leaf node (i.e. no nodes below this have been expanded).
        """
        return self._children == {}
This method checks whether the current node is a leaf node. In the Monte Carlo Tree Search tree, a leaf node is one that has not been expanded yet, meaning no child nodes have been added to it. The method returns True if the node is a leaf node, and False otherwise.
    def is_root(self):
        return self._parent is None
The is_root method is used to check whether the node is the root node or not in the tree. It returns a boolean value of True if the node is the root node, and False otherwise.