Training an Agent to play Pong using Reinforcement Learning
Using video frames as the state representation, Reinforcement Learning is used to teach an agent to play Pong
- 1. Introduction
- 2. Solution Proposal
- 3. Implementation of the Solution
The Atari 2600 is a home video game console from Atari, Inc., released in 1977. It is credited with popularizing microprocessor-based hardware and games stored on ROM cartridges, instead of dedicated hardware with games physically built into the unit.
Atari games have become popular benchmarks for AI systems, particularly reinforcement learning. OpenAI Gym has a large number of these games available. Internally it uses the Stella Atari Emulator.
This blog post is about training an RL agent to play the Atari game Pong. Pong is a two-dimensional sports game that simulates table tennis. The player controls an in-game paddle by moving it vertically along the left or right side of the screen. Two players can compete against each other, each controlling their own paddle. Players use the paddles to hit a ball back and forth. In the Atari 2600 version of Pong, a computer player (controlled by the 2600) is the opposing player.
2. Solution Proposal
As mentioned, we will use Reinforcement Learning (RL) to solve this problem. The agent is the entity that will be trained to provide inputs to the environment in the form of actions. The environment consists of the game board, the opposing player's paddle, and the agent's paddle. The state of the environment will be derived from one or more images from the visual emulator. There will be a single state variable in the form of a cuboid with 3 dimensions: one for the height of the image, one for the width, and a third for either the number of color channels or the number of stacked black-and-white frames.
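As a concrete illustration of that state cuboid (a NumPy sketch, assuming the 84x84 frame size and 4-frame stack that standard Atari preprocessing produces):

```python
import numpy as np

# Four consecutive 84x84 grayscale frames with pixel values 0-255,
# as a typical Atari preprocessing + frame-stacking pipeline would emit.
frames = [np.random.randint(0, 256, size=(84, 84), dtype=np.uint8) for _ in range(4)]

# Stack along a third axis to form the state cuboid: height x width x frames.
state = np.stack(frames, axis=-1)
print(state.shape)  # (84, 84, 4)
```

Stacking several recent black-and-white frames (rather than one color frame) gives the network motion information, such as the direction the ball is travelling.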
3. Implementation of the Solution
To implement the environment we use the OpenAI Gym tools, and we use Google's TF-Agents Python library to implement the agent. Because the actions are discrete, we will use a Deep Q-Network (DQN) as the function approximator for the agent.
This implementation allows the agent to give simple categorical commands to its own paddle. The goal is to get the opponent to miss as many shots as possible - each miss scores a point for the agent.
The code will run on the Google Colab platform. To start with, we install the tf-agents python package.
!pip install tf-agents
import numpy as np
import sklearn
import tensorflow as tf
from tensorflow import keras
import matplotlib as mpl
import PIL.Image
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from functools import partial
from gym.wrappers import TimeLimit
from tf_agents.environments import suite_gym, suite_atari
from tf_agents.environments.atari_preprocessing import AtariPreprocessing
from tf_agents.environments.atari_wrappers import FrameStack4
from tf_agents.utils import common
from tf_agents.metrics import tf_metrics
from tf_agents.metrics import py_metrics
from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver
from tf_agents.eval.metric_utils import log_metrics
from tf_agents.policies.random_tf_policy import RandomTFPolicy
from tf_agents.trajectories.trajectory import to_transition
from tf_agents.utils.common import function
from tf_agents.agents.dqn.dqn_agent import DqnAgent
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.networks.q_network import QNetwork
from tf_agents.environments import tf_py_environment
import os
# from IPython.core.debugger import set_trace
import pickle
#repeatable runs:
np.random.seed(777)
tf.random.set_seed(777)
max_episode_frames = 108_000
env_name = "PongNoFrameskip-v4"
env = suite_atari.load(
env_name,
max_episode_steps=max_episode_frames // 4, # 4 ALE frames per environment step
gym_env_wrappers=[AtariPreprocessing, FrameStack4])
env.reset()
PIL.Image.fromarray(env.render())
Next we set up two environments - one for training and the other for evaluation - and then wrap both for TF-Agents:
train_py_env = suite_atari.load( #for training
env_name,
max_episode_steps=max_episode_frames // 4,
gym_env_wrappers=[AtariPreprocessing, FrameStack4])
eval_py_env = suite_atari.load( #for evaluation
env_name,
max_episode_steps=max_episode_frames // 4,
gym_env_wrappers=[AtariPreprocessing, FrameStack4])
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
3.2.1 Network (DQN)
The QNetwork makes use of a preprocessing layer that casts the pixel values to 32-bit floats and divides them by 255. This normalizes the pixel values to the range [0, 1], which helps the neural network train.
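A minimal NumPy sketch of that preprocessing step, independent of the Lambda layer defined below:

```python
import numpy as np

obs = np.array([[0, 51, 255]], dtype=np.uint8)  # example raw pixel values
normalized = obs.astype(np.float32) / 255.0     # cast to float and scale into [0, 1]
print(normalized)                               # values now lie between 0.0 and 1.0
```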
The parameter conv_layer_params defines the CNN that forms the main part of the function-approximation mechanism. It consists of 3 layers, each represented by a length-three tuple indicating (filters, kernel_size, stride). For example, the first layer has 32 filters, an 8x8 kernel, and a stride of 4.
The fc_layer_params parameter defines the final fully-connected layer of the network, which has 512 neurons.
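Assuming 'valid' padding and an 84x84 input (the standard Atari preprocessing size), we can check the spatial output size of each conv layer with the usual formula (input - kernel) // stride + 1:

```python
def conv_out(size, kernel, stride):
    # output size of a 'valid' convolution along one spatial dimension
    return (size - kernel) // stride + 1

conv_layer_params = [(32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)]
size, sizes = 84, []
for filters, (kernel, _), stride in conv_layer_params:
    size = conv_out(size, kernel, stride)
    sizes.append(size)
    print(f"{size}x{size}x{filters}")
# prints 20x20x32, 9x9x64, 7x7x64; the final 7*7*64 = 3136 features
# are flattened and fed into the 512-unit dense layer.
```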
# Q-Network, DQN
preprocessing_layer = keras.layers.Lambda(lambda obs: tf.cast(obs, np.float32)/255.)
# list of convolution layers parameters: Each item is a length-three tuple
# indicating (filters, kernel_size, stride):
conv_layer_params=[(32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)]
fc_layer_params=[512]
q_net = QNetwork(
input_tensor_spec= train_env.observation_spec(),
action_spec= train_env.action_spec(),
preprocessing_layers= preprocessing_layer,
conv_layer_params= conv_layer_params,
fc_layer_params= fc_layer_params)
LEARNING_RATE = 0.25e-3
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=LEARNING_RATE)
COLLECT_STEPS_PER_ITERATION = 4
epsilon_fn = keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=1.0,
decay_steps=150_000//COLLECT_STEPS_PER_ITERATION,
end_learning_rate=0.01)
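PolynomialDecay with its default power of 1 is simply a linear ramp; here the "learning rate" slots are repurposed to schedule the exploration rate. A pure-Python sketch of the resulting schedule (my own helper, not the TF class):

```python
COLLECT_STEPS_PER_ITERATION = 4  # as defined above

def epsilon(step, initial=1.0, final=0.01,
            decay_steps=150_000 // COLLECT_STEPS_PER_ITERATION):
    # linear interpolation from initial to final over decay_steps, then held at final
    frac = min(step, decay_steps) / decay_steps
    return initial + (final - initial) * frac

print(epsilon(0), epsilon(37_500))  # decays from 1.0 down to 0.01
```

The agent therefore starts out acting almost entirely at random and gradually shifts toward exploiting its learned Q-values.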
Now we instantiate and initialize the agent:
global_step = tf.compat.v1.train.get_or_create_global_step(); print(global_step) #.
TARGET_UPDATE_PERIOD = 2000 # training steps between target Q-network updates
agent = DqnAgent(
time_step_spec= train_env.time_step_spec(),
action_spec= train_env.action_spec(),
q_network= q_net,
optimizer= optimizer,
target_update_period= TARGET_UPDATE_PERIOD,
td_errors_loss_fn= common.element_wise_squared_loss,
gamma= 0.99,
train_step_counter= global_step, #.
epsilon_greedy= lambda: epsilon_fn(global_step)) #.
agent.initialize()
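The epsilon_greedy argument controls the classic exploration rule the collect policy follows. A conceptual stdlib-only sketch (with hypothetical Q-values, not TF-Agents code):

```python
import random

def epsilon_greedy_action(q_values, eps):
    # with probability eps explore (random action); otherwise exploit the best Q-value
    if random.random() < eps:
        return random.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])

q = [0.1, 0.7, 0.3]  # hypothetical Q-values for a 3-action space
print(epsilon_greedy_action(q, eps=0.0))  # 1: with eps=0 it always picks the best action
```

As the schedule above drives eps from 1.0 down to 0.01, the agent transitions from exploring to mostly exploiting.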
3.2.4 Metrics for Evaluation
We need a way to measure the effectiveness of a model trained with reinforcement learning. The loss of the internal Q-network is not a good measure, because it only indicates how closely the Q-network fits the collected data, not how effective the DQN is at maximizing rewards. The method we use here is the average reward received over several episodes. We also make use of TF-Agents' built-in metrics:
train_metrics = [
tf_metrics.NumberOfEpisodes(),
tf_metrics.EnvironmentSteps(),
tf_metrics.AverageReturnMetric(),
tf_metrics.AverageEpisodeLengthMetric(),
]
def compute_avg_return(env, pol, num_episodes=10):
print(f"... computing avg return with num_episodes={num_episodes}")
total_return = 0.0
for _ in range(num_episodes):
tstep = env.reset()
episode_return = 0.0
while not tstep.is_last():
pstep = pol.action(tstep)
tstep = env.step(pstep.action)
episode_return += tstep.reward
total_return += episode_return
avg_return = total_return / num_episodes
return avg_return.numpy()[0]
3.2.5.1 Replay buffer
The DQN works by training a neural network to predict the Q-values for every state of the environment. Because the DQN has no pre-existing training data, this data has to be accumulated as the agent and environment evolve through steps. The accumulated experience is stored in the replay buffer, which is a First-In-First-Out (FIFO) buffer: only the most recent data is kept, and older data rolls off the queue as new data accumulates.
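The FIFO behaviour can be sketched with a bounded deque (a toy stand-in for TFUniformReplayBuffer, not its real API):

```python
import random
from collections import deque

class SimpleReplayBuffer:
    def __init__(self, max_length):
        self.buffer = deque(maxlen=max_length)  # oldest items roll off automatically

    def add(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        # uniform random sampling, like the "Uniform" in TFUniformReplayBuffer
        return random.sample(list(self.buffer), batch_size)

buf = SimpleReplayBuffer(max_length=3)
for t in range(5):
    buf.add(t)
print(list(buf.buffer))  # [2, 3, 4] - transitions 0 and 1 have rolled off
```

Sampling uniformly from this buffer decorrelates consecutive transitions, which stabilizes Q-network training.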
REPLAY_BUFFER_MAX_LENGTH = 100_000 #.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
data_spec= agent.collect_data_spec,
batch_size= train_env.batch_size,
max_length= REPLAY_BUFFER_MAX_LENGTH)
replay_buffer_observer = replay_buffer.add_batch
Training cannot begin on an empty replay buffer. Next we fill the replay_buffer with random data, but first we set up a quick class to show progress:
class Progress: #class to show progress
def __init__(self, total):
self.counter = 0
self.total = total
def __call__(self, trajectory):
if not trajectory.is_boundary():
self.counter += 1
if self.counter % 100 == 0:
print("\r{}/{}".format(self.counter, self.total), end="")
Now we run the init_driver to collect INITIAL_COLLECT_STEPS of random experience, priming the replay_buffer. Note that we only do this when starting a new global training session; if we resume an existing training session, this cell should be commented out:
# INITIAL_COLLECT_STEPS = 20_000
# initial_collect_policy = RandomTFPolicy(train_env.time_step_spec(),
# train_env.action_spec())
# init_driver = DynamicStepDriver(
# env= train_env,
# policy= initial_collect_policy,
# observers= [replay_buffer.add_batch, Progress(20_000)],
# num_steps= INITIAL_COLLECT_STEPS
# )
# final_time_step, final_policy_state = init_driver.run()
BATCH_SIZE = 64
dataset = replay_buffer.as_dataset(
sample_batch_size= BATCH_SIZE,
num_steps= 2,
num_parallel_calls= 3).prefetch(3)
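The dataset samples pairs of adjacent steps (num_steps=2) because that is exactly what a one-step TD update needs: a transition and its successor. A sketch of the target computation the agent's loss is built on (hypothetical numbers, gamma=0.99 as configured above):

```python
def td_target(reward, next_q_values, gamma=0.99, done=False):
    # one-step TD target: r + gamma * max_a' Q(s', a'); no bootstrap at episode end
    if done:
        return reward
    return reward + gamma * max(next_q_values)

print(td_target(1.0, [0.2, 0.5, 0.1]))  # 1.0 + 0.99 * 0.5, roughly 1.495
```

The DQN is then trained to push Q(s, a) toward this target, using the element-wise squared loss specified when the agent was constructed.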
collect_driver = DynamicStepDriver(
env= train_env,
policy= agent.collect_policy,
observers= [replay_buffer_observer] + train_metrics,
num_steps= COLLECT_STEPS_PER_ITERATION)
3.2.6 Mechanism to continue training over multiple sessions
Running on Google Colab, as we are, we don't have enough time in a single session to train sufficiently; Colab will terminate our process if we let it run for too long. We need a mechanism to resume training from a previous checkpoint so that we can pick up where we left off. This way we can splice together many small training sessions without losing previous progress. First, we need to set up a Checkpointer and a PolicySaver.
3.2.6.1 Setup Checkpointer and PolicySaver
In addition to setting up the Checkpointer and PolicySaver, we will also set up two pickle files: one to store accumulated returns and another to store accumulated losses:
# Setup Checkpointer and PolicySaver
model = 'policy1'
base_dir = '.' # base storage directory; use a mounted Google Drive path to persist across sessions
from pathlib import Path
# Checkpointer
checkpoint_dir = Path(f'{base_dir}/{model}/checkpoint'); #print(checkpoint_dir)
train_checkpointer = common.Checkpointer(
ckpt_dir=checkpoint_dir,
max_to_keep=1,
agent=agent,
policy=agent.policy,
replay_buffer=replay_buffer,
global_step=global_step #.
)
# PolicySaver
from tf_agents.policies import policy_saver
policy_dir = Path(f'{base_dir}/{model}/policy'); #print(policy_dir)
policy_saver = policy_saver.PolicySaver(policy=agent.policy)
# Returns
returns_pkl_file = Path(f'{base_dir}/{model}/returns.pkl'); #print(returns_pkl_file)
# Losses
losses_pkl_file = Path(f'{base_dir}/{model}/losses.pkl'); #print(losses_pkl_file);
# Restore checkpoint
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
tf.print(global_step)
# Restore policy
# saved_policy = tf.compat.v2.saved_model.load(str(policy_dir))
# Restore returns
try:
with open(returns_pkl_file, "rb") as fp:
returns = pickle.load(fp); print(returns)
except FileNotFoundError:
print("... file does not exist yet")
# Restore losses
try:
with open(losses_pkl_file, "rb") as fp:
losses = pickle.load(fp); print(losses)
except FileNotFoundError:
print("... file does not exist yet")
# initialize returns and losses
try:
returns
except NameError:
returns = []
print("... initialized returns global variable")
try:
losses
except NameError:
losses = []
print("... initialized losses global variable")
# convert the main functions to TF functions
collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)
def train_agent(n_iterations, loss_interval, log_interval, eval_interval, n_eval_episodes):
time_step = None
policy_state = agent.collect_policy.get_initial_state(train_env.batch_size)
iterator = iter(dataset)
#eval agent's policy once before NEW training:
if len(losses) == 0:
avg_return = compute_avg_return(eval_env, agent.policy, n_eval_episodes) #comment out to save time with short trial runs
# avg_return = train_metrics[2].result().numpy() #comment out to save time with short trial runs
returns.append(avg_return)
for i in range(n_iterations):
#collect and save to replay_buffer:
time_step, policy_state = collect_driver.run(time_step, policy_state)
#sample batch from replay_buffer and update agent's network:
trajectories, buffer_info = next(iterator)
train_loss = agent.train(trajectories).loss
# losses.append(train_loss.numpy())
step = agent.train_step_counter.numpy(); #print(f'i={i}, step={step}: loss={train_loss}')
if step % loss_interval == 0: losses.append(train_loss.numpy())
if step % log_interval == 0: print(f'i={i+1}, step={step}: loss={train_loss:.5f}')
if step % eval_interval == 0:
avg_return = compute_avg_return(eval_env, agent.policy, n_eval_episodes)
# avg_return = train_metrics[2].result().numpy()
print(f'step={step}: Average Return={avg_return}')
returns.append(avg_return)
print('returns:', returns)
#persist:
with open(returns_pkl_file, "wb") as fp: pickle.dump(returns, fp)
with open(losses_pkl_file, "wb") as fp: pickle.dump(losses, fp)
train_checkpointer.save(global_step)
policy_saver.save(str(policy_dir))
N_ITERATIONS = 9_000 #20_000 #per training session
LOSS_INTERVAL = 1
LOG_INTERVAL = 500
EVAL_INTERVAL = 5_000
NUM_EVAL_EPISODES = 5
train_agent(
n_iterations=N_ITERATIONS,
loss_interval=LOSS_INTERVAL,
log_interval=LOG_INTERVAL,
eval_interval=EVAL_INTERVAL,
n_eval_episodes=NUM_EVAL_EPISODES);
print(global_step)
evals = range(0, global_step.numpy()+1, EVAL_INTERVAL); print(f"evals: {evals}, \nlist(evals): {list(evals)}, \nreturns: {returns}")
plt.figure(figsize=(20,10))
plt.plot(list(evals), returns)
plt.scatter(list(evals), returns)
plt.xlabel('Iterations')
plt.ylabel('Average Return')
# plt.ylim(top=50)
every = 1000*LOSS_INTERVAL
its = range(0, global_step.numpy(), every); print(f"its: {its}, \nlist(its): {list(its)}, \nlosses: {losses}")
plt.figure(figsize=(20,10))
plt.plot(its, losses[::every])
plt.scatter(its, losses[::every])
plt.xlabel('Iterations')
plt.ylabel('Loss')
def update_scene(num, frames, patch):
patch.set_data(frames[num])
return patch,
def plot_animation(frames, repeat=False, interval=40):
fig = plt.figure()
patch = plt.imshow(frames[0])
plt.axis('off')
anim = animation.FuncAnimation(
fig, update_scene, fargs=(frames, patch),
frames=len(frames), repeat=repeat, interval=interval)
plt.close()
return anim
env = eval_env
frames = []
def save_frames(trajectory):
global frames
frames.append(env.pyenv.envs[0].render(mode="rgb_array"))
prev_lives = env.pyenv.envs[0].ale.lives()
def reset_and_fire_on_life_lost(trajectory):
global prev_lives
lives = env.pyenv.envs[0].ale.lives()
if prev_lives != lives:
env.reset()
env.pyenv.envs[0].step(1)
prev_lives = lives
# access the deployment policy
watch_driver = DynamicStepDriver(
env,
agent.policy,
observers=[save_frames, reset_and_fire_on_life_lost, Progress(1000)],
num_steps=1000)
# num_steps=2000) #.
final_time_step, final_policy_state = watch_driver.run()
plot_animation(frames)
# image_path = os.path.join("PongPlayer_train.gif")
image_path = os.path.join("PongPlayer_eval.gif")
frame_images = [PIL.Image.fromarray(frame) for frame in frames[:500]]
frame_images[0].save(image_path, format='GIF',
append_images=frame_images[1:],
save_all=True,
duration=60,
# duration=120, #.
loop=0)
ls
pwd
# Set up a virtual display for rendering OpenAI gym environments.
!sudo apt-get install -y xvfb ffmpeg
!pip install -q 'gym==0.10.11'
!pip install -q 'imageio==2.4.0'
!pip install -q PILLOW
!pip install -q 'pyglet==1.3.2'
!pip install -q pyvirtualdisplay
import base64
import imageio
import IPython
import pyvirtualdisplay
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
def embed_mp4(filename):
"""Embeds an mp4 file in the notebook."""
video = open(filename,'rb').read()
b64 = base64.b64encode(video)
tag = '''
<video width="640" height="480" controls>
<source src="data:video/mp4;base64,{0}" type="video/mp4">
Your browser does not support the video tag.
</video>'''.format(b64.decode())
return IPython.display.HTML(tag)
def create_policy_eval_video(policy, filename, num_episodes=5, fps=30):
filename = filename + ".mp4"
with imageio.get_writer(filename, fps=fps) as video:
for _ in range(num_episodes):
tstep = eval_env.reset()
video.append_data(eval_py_env.render())
while not tstep.is_last():
pstep = policy.action(tstep)
tstep = eval_env.step(pstep.action)
video.append_data(eval_py_env.render())
return embed_mp4(filename)
create_policy_eval_video(agent.policy, "trained-agent")
ls