Commit 1b617916 by 20201219013

DeepRL. todo: config the env and refactor

parent f4152bc1
MIT License
Copyright (c) 2019 Viet Nguyen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
@author: Viet Nguyen <nhviet1009@gmail.com>
"""
import torch.nn as nn
class DeepQNetwork(nn.Module):
def __init__(self):
super(DeepQNetwork, self).__init__()
self.conv1 = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))
self.fc1 = nn.Sequential(nn.Linear(7 * 7 * 64, 512), nn.ReLU(inplace=True))
self.fc2 = nn.Linear(512, 2)
self._create_weights()
def _create_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
nn.init.uniform_(m.weight, -0.01, 0.01)
nn.init.constant_(m.bias, 0)
def forward(self, input):
output = self.conv1(input)
output = self.conv2(output)
output = self.conv3(output)
output = output.view(output.size(0), -1)
output = self.fc1(output)
output = self.fc2(output)
return output
"""
@author: Viet Nguyen <nhviet1009@gmail.com>
"""
from itertools import cycle
from numpy.random import randint
from pygame import Rect, init, time, display
from pygame.event import pump
from pygame.image import load
from pygame.surfarray import array3d, pixels_alpha
from pygame.transform import rotate
import numpy as np
class FlappyBird(object):
init()
fps_clock = time.Clock()
screen_width = 288
screen_height = 512
screen = display.set_mode((screen_width, screen_height))
display.set_caption('Deep Q-Network Flappy Bird')
base_image = load('assets/sprites/base.png').convert_alpha()
background_image = load('assets/sprites/background-black.png').convert()
pipe_images = [rotate(load('assets/sprites/pipe-green.png').convert_alpha(), 180),
load('assets/sprites/pipe-green.png').convert_alpha()]
bird_images = [load('assets/sprites/redbird-upflap.png').convert_alpha(),
load('assets/sprites/redbird-midflap.png').convert_alpha(),
load('assets/sprites/redbird-downflap.png').convert_alpha()]
# number_images = [load('assets/sprites/{}.png'.format(i)).convert_alpha() for i in range(10)]
bird_hitmask = [pixels_alpha(image).astype(bool) for image in bird_images]
pipe_hitmask = [pixels_alpha(image).astype(bool) for image in pipe_images]
fps = 30
pipe_gap_size = 100
pipe_velocity_x = -4
# parameters for bird
min_velocity_y = -8
max_velocity_y = 10
downward_speed = 1
upward_speed = -9
bird_index_generator = cycle([0, 1, 2, 1])
def __init__(self):
self.iter = self.bird_index = self.score = 0
self.bird_width = self.bird_images[0].get_width()
self.bird_height = self.bird_images[0].get_height()
self.pipe_width = self.pipe_images[0].get_width()
self.pipe_height = self.pipe_images[0].get_height()
self.bird_x = int(self.screen_width / 5)
self.bird_y = int((self.screen_height - self.bird_height) / 2)
self.base_x = 0
self.base_y = self.screen_height * 0.79
self.base_shift = self.base_image.get_width() - self.background_image.get_width()
pipes = [self.generate_pipe(), self.generate_pipe()]
pipes[0]["x_upper"] = pipes[0]["x_lower"] = self.screen_width
pipes[1]["x_upper"] = pipes[1]["x_lower"] = self.screen_width * 1.5
self.pipes = pipes
self.current_velocity_y = 0
self.is_flapped = False
def generate_pipe(self):
x = self.screen_width + 10
gap_y = randint(2, 10) * 10 + int(self.base_y / 5)
return {"x_upper": x, "y_upper": gap_y - self.pipe_height, "x_lower": x, "y_lower": gap_y + self.pipe_gap_size}
def is_collided(self):
# Check if the bird touch ground
if self.bird_height + self.bird_y + 1 >= self.base_y:
return True
bird_bbox = Rect(self.bird_x, self.bird_y, self.bird_width, self.bird_height)
pipe_boxes = []
for pipe in self.pipes:
pipe_boxes.append(Rect(pipe["x_upper"], pipe["y_upper"], self.pipe_width, self.pipe_height))
pipe_boxes.append(Rect(pipe["x_lower"], pipe["y_lower"], self.pipe_width, self.pipe_height))
# Check if the bird's bounding box overlaps to the bounding box of any pipe
if bird_bbox.collidelist(pipe_boxes) == -1:
return False
for i in range(2):
cropped_bbox = bird_bbox.clip(pipe_boxes[i])
min_x1 = cropped_bbox.x - bird_bbox.x
min_y1 = cropped_bbox.y - bird_bbox.y
min_x2 = cropped_bbox.x - pipe_boxes[i].x
min_y2 = cropped_bbox.y - pipe_boxes[i].y
if np.any(self.bird_hitmask[self.bird_index][min_x1:min_x1 + cropped_bbox.width,
min_y1:min_y1 + cropped_bbox.height] * self.pipe_hitmask[i][min_x2:min_x2 + cropped_bbox.width,
min_y2:min_y2 + cropped_bbox.height]):
return True
return False
def next_frame(self, action):
pump()
reward = 0.1
terminal = False
# Check input action
if action == 1:
self.current_velocity_y = self.upward_speed
self.is_flapped = True
# Update score
bird_center_x = self.bird_x + self.bird_width / 2
for pipe in self.pipes:
pipe_center_x = pipe["x_upper"] + self.pipe_width / 2
if pipe_center_x < bird_center_x < pipe_center_x + 5:
self.score += 1
reward = 1
break
# Update index and iteration
if (self.iter + 1) % 3 == 0:
self.bird_index = next(self.bird_index_generator)
self.iter = 0
self.base_x = -((-self.base_x + 100) % self.base_shift)
# Update bird's position
if self.current_velocity_y < self.max_velocity_y and not self.is_flapped:
self.current_velocity_y += self.downward_speed
if self.is_flapped:
self.is_flapped = False
self.bird_y += min(self.current_velocity_y, self.bird_y - self.current_velocity_y - self.bird_height)
if self.bird_y < 0:
self.bird_y = 0
# Update pipes' position
for pipe in self.pipes:
pipe["x_upper"] += self.pipe_velocity_x
pipe["x_lower"] += self.pipe_velocity_x
# Update pipes
if 0 < self.pipes[0]["x_lower"] < 5:
self.pipes.append(self.generate_pipe())
if self.pipes[0]["x_lower"] < -self.pipe_width:
del self.pipes[0]
if self.is_collided():
terminal = True
reward = -1
self.__init__()
# Draw everything
self.screen.blit(self.background_image, (0, 0))
self.screen.blit(self.base_image, (self.base_x, self.base_y))
self.screen.blit(self.bird_images[self.bird_index], (self.bird_x, self.bird_y))
for pipe in self.pipes:
self.screen.blit(self.pipe_images[0], (pipe["x_upper"], pipe["y_upper"]))
self.screen.blit(self.pipe_images[1], (pipe["x_lower"], pipe["y_lower"]))
image = array3d(display.get_surface())
display.update()
self.fps_clock.tick(self.fps)
return image, reward, terminal
"""
@author: Viet Nguyen <nhviet1009@gmail.com>
"""
import cv2
import numpy as np
def pre_processing(image, width, height):
image = cv2.cvtColor(cv2.resize(image, (width, height)), cv2.COLOR_BGR2GRAY)
_, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
return image[None, :, :].astype(np.float32)
"""
@author: Viet Nguyen <nhviet1009@gmail.com>
"""
import argparse
import torch
from src.deep_q_network import DeepQNetwork
from src.flappy_bird import FlappyBird
from src.utils import pre_processing
def get_args():
parser = argparse.ArgumentParser(
"""Implementation of Deep Q Network to play Flappy Bird""")
parser.add_argument("--image_size", type=int, default=84, help="The common width and height for all images")
parser.add_argument("--saved_path", type=str, default="trained_models")
args = parser.parse_args()
return args
def test(opt):
if torch.cuda.is_available():
torch.cuda.manual_seed(123)
else:
torch.manual_seed(123)
if torch.cuda.is_available():
model = torch.load("{}/flappy_bird".format(opt.saved_path))
else:
model = torch.load("{}/flappy_bird".format(opt.saved_path), map_location=lambda storage, loc: storage)
model.eval()
game_state = FlappyBird()
image, reward, terminal = game_state.next_frame(0)
image = pre_processing(image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size, opt.image_size)
image = torch.from_numpy(image)
if torch.cuda.is_available():
model.cuda()
image = image.cuda()
state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]
while True:
prediction = model(state)[0]
action = torch.argmax(prediction).item()
next_image, reward, terminal = game_state.next_frame(action)
next_image = pre_processing(next_image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size,
opt.image_size)
next_image = torch.from_numpy(next_image)
if torch.cuda.is_available():
next_image = next_image.cuda()
next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]
state = next_state
if __name__ == "__main__":
opt = get_args()
test(opt)
"""
@author: Viet Nguyen <nhviet1009@gmail.com>
"""
import argparse
import os
import shutil
from random import random, randint, sample
import numpy as np
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from src.deep_q_network import DeepQNetwork
from src.flappy_bird import FlappyBird
from src.utils import pre_processing
def get_args():
parser = argparse.ArgumentParser(
"""Implementation of Deep Q Network to play Flappy Bird""")
parser.add_argument("--image_size", type=int, default=84, help="The common width and height for all images")
parser.add_argument("--batch_size", type=int, default=32, help="The number of images per batch")
parser.add_argument("--optimizer", type=str, choices=["sgd", "adam"], default="adam")
parser.add_argument("--lr", type=float, default=1e-6)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--initial_epsilon", type=float, default=0.1)
parser.add_argument("--final_epsilon", type=float, default=1e-4)
parser.add_argument("--num_iters", type=int, default=2000000)
parser.add_argument("--replay_memory_size", type=int, default=50000,
help="Number of epoches between testing phases")
parser.add_argument("--log_path", type=str, default="tensorboard")
parser.add_argument("--saved_path", type=str, default="trained_models")
args = parser.parse_args()
return args
def train(opt):
if torch.cuda.is_available():
torch.cuda.manual_seed(123)
else:
torch.manual_seed(123)
model = DeepQNetwork()
if os.path.isdir(opt.log_path):
shutil.rmtree(opt.log_path)
os.makedirs(opt.log_path)
writer = SummaryWriter(opt.log_path)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
criterion = nn.MSELoss()
game_state = FlappyBird()
image, reward, terminal = game_state.next_frame(0)
image = pre_processing(image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size, opt.image_size)
image = torch.from_numpy(image)
if torch.cuda.is_available():
model.cuda()
image = image.cuda()
state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]
replay_memory = []
iter = 0
while iter < opt.num_iters:
prediction = model(state)[0]
# Exploration or exploitation
epsilon = opt.final_epsilon + (
(opt.num_iters - iter) * (opt.initial_epsilon - opt.final_epsilon) / opt.num_iters)
u = random()
random_action = u <= epsilon
if random_action:
print("Perform a random action")
action = randint(0, 1)
else:
action = torch.argmax(prediction)[0]
next_image, reward, terminal = game_state.next_frame(action)
next_image = pre_processing(next_image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size,
opt.image_size)
next_image = torch.from_numpy(next_image)
if torch.cuda.is_available():
next_image = next_image.cuda()
next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]
replay_memory.append([state, action, reward, next_state, terminal])
if len(replay_memory) > opt.replay_memory_size:
del replay_memory[0]
batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(*batch)
state_batch = torch.cat(tuple(state for state in state_batch))
action_batch = torch.from_numpy(
np.array([[1, 0] if action == 0 else [0, 1] for action in action_batch], dtype=np.float32))
reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])
next_state_batch = torch.cat(tuple(state for state in next_state_batch))
if torch.cuda.is_available():
state_batch = state_batch.cuda()
action_batch = action_batch.cuda()
reward_batch = reward_batch.cuda()
next_state_batch = next_state_batch.cuda()
current_prediction_batch = model(state_batch)
next_prediction_batch = model(next_state_batch)
y_batch = torch.cat(
tuple(reward if terminal else reward + opt.gamma * torch.max(prediction) for reward, terminal, prediction in
zip(reward_batch, terminal_batch, next_prediction_batch)))
q_value = torch.sum(current_prediction_batch * action_batch, dim=1)
optimizer.zero_grad()
# y_batch = y_batch.detach()
loss = criterion(q_value, y_batch)
loss.backward()
optimizer.step()
state = next_state
iter += 1
print("Iteration: {}/{}, Action: {}, Loss: {}, Epsilon {}, Reward: {}, Q-value: {}".format(
iter + 1,
opt.num_iters,
action,
loss,
epsilon, reward, torch.max(prediction)))
writer.add_scalar('Train/Loss', loss, iter)
writer.add_scalar('Train/Epsilon', epsilon, iter)
writer.add_scalar('Train/Reward', reward, iter)
writer.add_scalar('Train/Q-value', torch.max(prediction), iter)
if (iter+1) % 1000000 == 0:
torch.save(model, "{}/flappy_bird_{}".format(opt.saved_path, iter+1))
torch.save(model, "{}/flappy_bird".format(opt.saved_path))
if __name__ == "__main__":
opt = get_args()
train(opt)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment