Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
R
RL_Homework
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
20210801063
RL_Homework
Commits
b05d1457
Commit
b05d1457
authored
Dec 05, 2021
by
20210801063
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Delete BellmanDPBase.py
parent
c613e703
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
98 deletions
+0
-98
BellmanDPBase.py
+0
-98
No files found.
BellmanDPBase.py
deleted
100644 → 0
View file @
c613e703
from
MDP
import
MDP
import
numpy
as
np
import
copy
import
matplotlib.pyplot
as
plt
import
numpy
as
np
def
plot_value_and_policy
(
values
,
policy
):
data
=
np
.
zeros
((
5
,
5
))
plt
.
figure
(
figsize
=
(
12
,
4
))
plt
.
subplot
(
1
,
2
,
1
)
plt
.
title
(
'Value'
)
for
y
in
range
(
data
.
shape
[
0
]):
for
x
in
range
(
data
.
shape
[
1
]):
data
[
y
][
x
]
=
values
[(
x
,
y
)]
plt
.
text
(
x
+
0.5
,
y
+
0.5
,
'
%.4
f'
%
data
[
y
,
x
],
horizontalalignment
=
'center'
,
verticalalignment
=
'center'
,
)
heatmap
=
plt
.
pcolor
(
data
)
plt
.
gca
()
.
invert_yaxis
()
plt
.
colorbar
(
heatmap
)
plt
.
subplot
(
1
,
2
,
2
)
plt
.
title
(
'Policy'
)
for
y
in
range
(
5
):
for
x
in
range
(
5
):
for
action
in
policy
[(
x
,
y
)]:
if
action
==
'DRIBBLE_UP'
:
plt
.
annotate
(
''
,
(
x
+
0.5
,
y
),
(
x
+
0.5
,
y
+
0.5
),
arrowprops
=
{
'width'
:
0.1
})
if
action
==
'DRIBBLE_DOWN'
:
plt
.
annotate
(
''
,
(
x
+
0.5
,
y
+
1
),
(
x
+
0.5
,
y
+
0.5
),
arrowprops
=
{
'width'
:
0.1
})
if
action
==
'DRIBBLE_RIGHT'
:
plt
.
annotate
(
''
,
(
x
+
1
,
y
+
0.5
),
(
x
+
0.5
,
y
+
0.5
),
arrowprops
=
{
'width'
:
0.1
})
if
action
==
'DRIBBLE_LEFT'
:
plt
.
annotate
(
''
,
(
x
,
y
+
0.5
),
(
x
+
0.5
,
y
+
0.5
),
arrowprops
=
{
'width'
:
0.1
})
if
action
==
'SHOOT'
:
plt
.
text
(
x
+
0.5
,
y
+
0.5
,
action
,
horizontalalignment
=
'center'
,
verticalalignment
=
'center'
,
)
heatmap
=
plt
.
pcolor
(
data
)
plt
.
gca
()
.
invert_yaxis
()
plt
.
colorbar
(
heatmap
)
plt
.
show
()
class
BellmanDPSolver
(
object
):
def
__init__
(
self
,
discountRate
=
0.9
):
self
.
MDP
=
MDP
()
self
.
discountRate
=
discountRate
self
.
initVs
()
def
initVs
(
self
):
self
.
V
=
{}
self
.
policy
=
{}
for
state
in
self
.
MDP
.
S
:
self
.
V
[
state
]
=
0
self
.
policy
[
state
]
=
np
.
array
([
0.5
]
*
len
(
self
.
MDP
.
A
))
def
BellmanUpdate
(
self
):
try
:
copy_V
=
copy
.
deepcopy
(
self
.
V
)
for
state
in
self
.
MDP
.
S
:
next_policy
=
[]
nextValue
=
None
action_list
=
[
"DRIBBLE_UP"
,
"DRIBBLE_DOWN"
,
"DRIBBLE_LEFT"
,
"DRIBBLE_RIGHT"
,
"SHOOT"
];
for
i
in
range
(
5
):
action
=
action_list
[
i
];
temp
=
0
for
nextState
,
prob
in
self
.
MDP
.
probNextStates
(
state
,
action
)
.
items
():
temp
+=
prob
*
(
self
.
MDP
.
getRewards
(
state
,
action
,
nextState
)
+
self
.
discountRate
*
copy_V
[
nextState
])
if
not
nextValue
or
nextValue
==
temp
:
next_policy
.
append
(
action
)
nextValue
=
temp
elif
nextValue
>
temp
:
continue
else
:
next_policy
=
[
action
]
nextValue
=
temp
self
.
V
[
state
]
=
nextValue
self
.
policy
[
state
]
=
next_policy
return
self
.
V
,
self
.
policy
except
:
raise
NotImplementedError
if
__name__
==
'__main__'
:
solution
=
BellmanDPSolver
()
iter
=
1000
for
i
in
range
(
iter
):
values
,
policy
=
solution
.
BellmanUpdate
()
print
(
"Values : "
,
values
)
print
(
"Policy : "
,
policy
)
print
(
iter
)
plot_value_and_policy
(
values
,
policy
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment