# coding=utf-8
import numpy as np
import matplotlib.pyplot as plt

from MDP import MDP


def plot_value_and_policy(values, policy):
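    """Plot the state values as a 5x5 heatmap (left) and the greedy policy
    (right), drawing an arrow for each DRIBBLE_* action and text for SHOOT.

    `values` maps (x, y) grid positions to floats; `policy` maps positions to
    lists of action names.
    """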
    data = np.zeros((5, 5))

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.title('Value')
    for y in range(data.shape[0]):
        for x in range(data.shape[1]):
            data[y][x] = values[(x, y)]
            plt.text(x + 0.5, y + 0.5, '%.4f' % data[y, x], horizontalalignment='center', verticalalignment='center')

    heatmap = plt.pcolor(data)
    plt.gca().invert_yaxis()
    plt.colorbar(heatmap)

    plt.subplot(1, 2, 2)
    plt.title('Policy')
    for y in range(5):
        for x in range(5):
            for action in policy[(x, y)]:
                if action == 'DRIBBLE_UP':
                    plt.annotate('', (x + 0.5, y), (x + 0.5, y + 0.5), arrowprops={'width': 0.1})
                if action == 'DRIBBLE_DOWN':
                    plt.annotate('', (x + 0.5, y + 1), (x + 0.5, y + 0.5), arrowprops={'width': 0.1})
                if action == 'DRIBBLE_RIGHT':
                    plt.annotate('', (x + 1, y + 0.5), (x + 0.5, y + 0.5), arrowprops={'width': 0.1})
                if action == 'DRIBBLE_LEFT':
                    plt.annotate('', (x, y + 0.5), (x + 0.5, y + 0.5), arrowprops={'width': 0.1})
                if action == 'SHOOT':
                    plt.text(x + 0.5, y + 0.5, action, horizontalalignment='center', verticalalignment='center')

    heatmap = plt.pcolor(data)
    plt.gca().invert_yaxis()
    plt.colorbar(heatmap)
    plt.show()


class BellmanDPSolver(object):
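    """Solve the MDP defined in MDP.py by synchronous value iteration:
    repeated calls to BellmanUpdate apply Bellman optimality backups and keep
    a greedy policy derived from the current value estimates."""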
    def __init__(self, discountRate=0.9):
        self.MDP = MDP()
        self.discountRate = discountRate
        self.initVs()

    def initVs(self):
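        """Initialize all state values to 0 and all policies to a uniform
        distribution over the available actions."""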
        self.V = {}
        self.policy = {}
        for state in self.MDP.S:
            self.V[state] = 0
            self.policy[state] = np.full(len(self.MDP.A), 1.0 / len(self.MDP.A))  # uniform over actions

    def BellmanUpdate(self):
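        """One synchronous sweep of Bellman optimality backups.

        For every state s this computes
            V(s) <- max_a sum_{s'} P(s' | s, a) * (R(s, a, s') + gamma * V(s'))
        and then derives a greedy policy that spreads probability uniformly
        over all maximizing actions.

        Returns the updated value function and, per state, the list of greedy
        action names.
        """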
        # There are 27 states in total; each sweep updates V[state] for all of them.

        updatedStateValue = self.V.copy()

        for state in self.MDP.S:
            tmp_V = [0.0] * len(self.MDP.A)  # one-step lookahead value for each action
            for idx, action in enumerate(self.MDP.A):
                transitions = self.MDP.probNextStates(state, action)
                # The value is computed the same way as the action-values below, because
                # transitions are only available per action; there is no more direct
                # state-transition matrix.
                for newState, prob in transitions.items():
                    reward = self.MDP.getRewards(state, action, newState)
                    tmp_V[idx] += prob * (reward + self.discountRate * self.V[newState])

            updatedStateValue[state] = np.max(tmp_V)
        self.V = updatedStateValue

        policy = {}
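        # Greedy policy improvement: recompute the one-step lookahead value of every
        # action and spread probability uniformly over the maximizing actions.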
        for state in self.MDP.S:
            action_value = np.zeros(len(self.MDP.A))
            for idx, action in enumerate(self.MDP.A):
                transitions = self.MDP.probNextStates(state, action)
                for new_state, prob in transitions.items():
                    reward = self.MDP.getRewards(state, action, new_state)
                    action_value[idx] += prob * (reward + self.discountRate * self.V[new_state])
            self.policy[state] = np.zeros(len(self.MDP.A))  # reset the stored action probabilities
            max_actions = np.argwhere(action_value == np.amax(action_value)).flatten().tolist()
            max_actions = np.sort(max_actions)
            prob_action = 1.0 / len(max_actions)  # purely greedy: split probability evenly over the maximizers (no epsilon-greedy)
            self.policy[state][max_actions] = prob_action  # equal probability for every tied best action
            policy[state] = np.array(self.MDP.A)[max_actions].tolist()  # map action indices back to action names
        return self.V, policy
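

# A minimal alternative driver, sketched here rather than taken from the original
# code: it repeats BellmanUpdate until the largest change in any state value drops
# below a tolerance, instead of running a fixed number of sweeps.
# `run_until_converged`, `theta`, and `max_iters` are illustrative names only.
def run_until_converged(solver, theta=1e-6, max_iters=1000):
    values, policy = solver.BellmanUpdate()
    for _ in range(max_iters):
        old_values = dict(values)
        values, policy = solver.BellmanUpdate()
        # Stop once the value function is numerically stable.
        delta = max(abs(values[s] - old_values[s]) for s in values)
        if delta < theta:
            break
    return values, policy
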

if __name__ == '__main__':
    solution = BellmanDPSolver()
    for i in range(100):
        values, policy = solution.BellmanUpdate()
    print("Values : ", values)
    print("Policy : ", policy)
    plot_value_and_policy(values, policy)