Commit c613e703 by 20210801063

Upload New File

parent 020f734a
from MDP import MDP
import numpy as np
import copy
import matplotlib.pyplot as plt
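# The MDP class comes from the accompanying MDP module, which is not part of
# this commit. From its usage below, it is assumed to expose:
#   MDP.S                                     -- iterable of states, here (x, y) grid cells
#   MDP.A                                     -- list of action names
#   MDP.probNextStates(state, action)         -- dict mapping next state -> transition probability
#   MDP.getRewards(state, action, nextState)  -- immediate reward (float)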
def plot_value_and_policy(values, policy):
    data = np.zeros((5, 5))

    plt.figure(figsize=(12, 4))

    # Left panel: state values as an annotated heatmap.
    plt.subplot(1, 2, 1)
    plt.title('Value')
    for y in range(data.shape[0]):
        for x in range(data.shape[1]):
            data[y][x] = values[(x, y)]
            plt.text(x + 0.5, y + 0.5, '%.4f' % data[y, x],
                     horizontalalignment='center', verticalalignment='center')
    heatmap = plt.pcolor(data)
    plt.gca().invert_yaxis()
    plt.colorbar(heatmap)

    # Right panel: greedy actions drawn as arrows (text for SHOOT).
    plt.subplot(1, 2, 2)
    plt.title('Policy')
    for y in range(5):
        for x in range(5):
            for action in policy[(x, y)]:
                if action == 'DRIBBLE_UP':
                    plt.annotate('', (x + 0.5, y), (x + 0.5, y + 0.5), arrowprops={'width': 0.1})
                if action == 'DRIBBLE_DOWN':
                    plt.annotate('', (x + 0.5, y + 1), (x + 0.5, y + 0.5), arrowprops={'width': 0.1})
                if action == 'DRIBBLE_RIGHT':
                    plt.annotate('', (x + 1, y + 0.5), (x + 0.5, y + 0.5), arrowprops={'width': 0.1})
                if action == 'DRIBBLE_LEFT':
                    plt.annotate('', (x, y + 0.5), (x + 0.5, y + 0.5), arrowprops={'width': 0.1})
                if action == 'SHOOT':
                    plt.text(x + 0.5, y + 0.5, action,
                             horizontalalignment='center', verticalalignment='center')
    heatmap = plt.pcolor(data)
    plt.gca().invert_yaxis()
    plt.colorbar(heatmap)

    plt.show()
class BellmanDPSolver(object):
    def __init__(self, discountRate=0.9):
        self.MDP = MDP()
        self.discountRate = discountRate
        self.initVs()

    def initVs(self):
        # Start from zero values and a placeholder policy (all actions tied)
        # for every state.
        self.V = {}
        self.policy = {}
        for state in self.MDP.S:
            self.V[state] = 0
            self.policy[state] = list(self.MDP.A)

    def BellmanUpdate(self):
        # One synchronous sweep of the Bellman optimality backup:
        #   V(s) <- max_a sum_s' P(s'|s,a) * (R(s,a,s') + gamma * V(s'))
        copy_V = copy.deepcopy(self.V)
        action_list = ["DRIBBLE_UP", "DRIBBLE_DOWN", "DRIBBLE_LEFT", "DRIBBLE_RIGHT", "SHOOT"]
        for state in self.MDP.S:
            next_policy = []
            nextValue = None
            for action in action_list:
                temp = 0
                for nextState, prob in self.MDP.probNextStates(state, action).items():
                    temp += prob * (self.MDP.getRewards(state, action, nextState)
                                    + self.discountRate * copy_V[nextState])
                # Keep every action that ties for the best value; replace the
                # whole set when a strictly better action is found. Comparing
                # against None (rather than truthiness) avoids mishandling a
                # legitimate value of 0.
                if nextValue is None or temp > nextValue:
                    next_policy = [action]
                    nextValue = temp
                elif temp == nextValue:
                    next_policy.append(action)
            self.V[state] = nextValue
            self.policy[state] = next_policy
        return self.V, self.policy
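# Optional extension (not in the original file): rather than running a fixed
# number of sweeps, value iteration is commonly stopped once the largest value
# change in a sweep falls below a tolerance. A minimal sketch, assuming the
# BellmanDPSolver interface above; run_until_converged is a hypothetical
# helper, not part of the MDP module.
def run_until_converged(solver, tol=1e-6, max_sweeps=10000):
    """Repeat Bellman sweeps until values change by less than tol."""
    values, policy = solver.V, solver.policy
    for _ in range(max_sweeps):
        old_V = dict(solver.V)  # snapshot values before the sweep
        values, policy = solver.BellmanUpdate()
        if max(abs(values[s] - old_V[s]) for s in values) < tol:
            break
    return values, policy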
if __name__ == '__main__':
    solution = BellmanDPSolver()
    num_iterations = 1000  # renamed from `iter`, which shadows the builtin
    for i in range(num_iterations):
        values, policy = solution.BellmanUpdate()
    print("Values : ", values)
    print("Policy : ", policy)
    print("Iterations : ", num_iterations)
    plot_value_and_policy(values, policy)