Commit 2ec84a35 by 20200318029

homework7

parent a37d8c4c
...@@ -182,9 +182,9 @@ class LinUCB(object): ...@@ -182,9 +182,9 @@ class LinUCB(object):
@return: @return:
action - true observed action for context action - true observed action for context
""" """
def get_reward(self, arm, action, reward):
    """Return the observed reward for the chosen arm under replay evaluation.

    The logged dataset only reveals the reward for the action that was
    actually taken, so the bandit is credited only when its chosen arm
    matches the logged action.

    @param arm: arm index selected by the bandit
    @param action: true (logged) action for this context
    @param reward: observed reward associated with the logged action
    @return: `reward` if the bandit's arm matches the logged action, else 0
    """
    # NOTE(review): source was a garbled side-by-side diff; this is the
    # reconstructed post-commit version (reward passed through instead of
    # the hard-coded 1 used before the commit).
    if arm == action:
        return reward
    return 0
""" """
...@@ -226,6 +226,7 @@ class bandit_evaluator(object): ...@@ -226,6 +226,7 @@ class bandit_evaluator(object):
self.bandit = None self.bandit = None
self.cum_rewards = 0 self.cum_rewards = 0
self.ctr_history = [] self.ctr_history = []
self.T = 1e-5
""" """
calc_ctr: calc_ctr:
...@@ -239,20 +240,21 @@ class bandit_evaluator(object): ...@@ -239,20 +240,21 @@ class bandit_evaluator(object):
@return: @return:
ctr - cumulative take-rate ctr - cumulative take-rate
""" """
def calc_ctr(self, x, action, reward, t):
    """Update and return the cumulative take-rate (CTR) of the bandit.

    Implements replay evaluation: only steps where the bandit's prediction
    matches the logged action contribute reward, and the CTR is the total
    matched reward divided by the number of matched steps (self.T).

    @param x: context for step t
    @param action: true (logged) action for this context
    @param reward: observed reward for the logged action
    @param t: current time step; must be > 0
    @return: ctr - cumulative take-rate after this step
    """
    assert t > 0
    pred_act = self.bandit.predict(x)
    if pred_act == action:
        self.cum_rewards += reward
        # NOTE(review): the diff's indentation is ambiguous — self.T is
        # assumed to count only matched steps (consistent with dividing by
        # it below); confirm against the original notebook.
        self.T += 1
    # self.T is initialized to 1e-5, which guards against division by zero
    # before the first matched prediction.
    ctr = self.cum_rewards / self.T
    self.ctr_history.append(ctr)
    return ctr
# In[20]: # In[20]:
from utils import getData, getContext, getAction from utils import getData, getContext, getAction, getReward
""" """
...@@ -279,16 +281,17 @@ def train(file, steps, alpha, nArms, d): ...@@ -279,16 +281,17 @@ def train(file, steps, alpha, nArms, d):
for t in range(steps): for t in range(steps):
x = getContext(data, t) x = getContext(data, t)
action = getAction(data, t) action = getAction(data, t)
reward = getReward(data, t)
arm = bandit.predict(x) arm = bandit.predict(x)
reward = bandit.get_reward(arm, action) reward_ = bandit.get_reward(arm, action, reward)
bandit.arms[arm].update_arm(reward, x) bandit.arms[arm].update_arm(reward_, x)
if t > 0: # explore various alpha update methods to improve CTR if t > 0: # explore various alpha update methods to improve CTR
# bandit.arms[arm].update_alpha(method=2) # or method=2 # bandit.arms[arm].update_alpha(method=2) # or method=2
bandit.arms[arm].update_alpha(3, t) bandit.arms[arm].update_alpha(3, t)
if t > 0: # evaluate current bandit algorithm if t > 0: # evaluate current bandit algorithm
ctr = evaluator.calc_ctr(x, action, t) ctr = evaluator.calc_ctr(x, action, reward, t)
if t % 100 == 0: if t % 100 == 0:
print("Step:", t, end="") print("Step:", t, end="")
print(" | CTR: {0:.02f}%".format(ctr)) print(" | CTR: {0:.02f}%".format(ctr))
...@@ -299,7 +302,7 @@ def train(file, steps, alpha, nArms, d): ...@@ -299,7 +302,7 @@ def train(file, steps, alpha, nArms, d):
# In[21]: # In[21]:
# Hyperparameters for the LinUCB training run (post-commit values).
# NOTE(review): reconstructed from a garbled side-by-side diff; the commit
# changed the dataset file from "classification.txt" to "dataset.txt".
file = "dataset.txt"  # path to the logged-interaction dataset
steps = 10000         # number of replay steps to run
alpha = .1            # initial UCB exploration coefficient
nArms = 10            # number of bandit arms (actions)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment