Off-Policy Monte Carlo Prediction with Importance Sampling
Off Policy Learning
Because of the exploration-exploitation trade-off, the agent must sometimes take a sub-optimal exploratory action, which may earn it less reward. One way to explore is to use an epsilon-greedy policy, where the agent takes a non-greedy action with a small probability.
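As a minimal sketch of the epsilon-greedy idea, the function below turns a vector of action values into action probabilities. The name `epsilon_greedy_probs` and the example Q-values are illustrative assumptions, not part of the original notebook. In this common formulation the exploration mass `epsilon` is spread over all actions, so the greedy action can also be picked during an exploratory draw.

```python
import numpy as np

def epsilon_greedy_probs(q_values, epsilon=0.1):
    """Return action probabilities for an epsilon-greedy policy (illustrative helper)."""
    q_values = np.asarray(q_values, dtype=float)
    n_actions = len(q_values)
    # Every action receives an equal share of the exploration mass epsilon.
    probs = np.full(n_actions, epsilon / n_actions)
    # The greedy action receives the remaining 1 - epsilon on top of its share.
    probs[np.argmax(q_values)] += 1.0 - epsilon
    return probs

# Hypothetical action values for a single state with three actions.
probs = epsilon_greedy_probs([0.2, 0.5, 0.1], epsilon=0.1)
action = np.random.choice(len(probs), p=probs)
print(probs, action)
```

With a small `epsilon`, almost all probability sits on the greedy action, yet every action keeps a non-zero chance of being tried.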
In on-policy learning, evaluation and improvement are done on the same policy that is used to select actions.
In off-policy learning, evaluation and improvement are done on a policy different from the one used to select actions. The policy being learned is "off" the policy used for action selection while gathering episodes.
- Target Policy $\pi(a \mid s)$: The value function being learned is that of $\pi$. We want the target policy to be the optimal policy $\pi_*$. The target policy is used for action selection after the learning process is complete (deployment).
- Behavior Policy $b(a \mid s)$: The behavior policy is used for action selection while gathering episodes to train the agent. It is generally an exploratory policy.
Importance Sampling
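We have a random variable $x \sim b$ sampled from the behavior policy distribution $b$. We want to estimate the expected value of $x$ with respect to the target distribution $\pi$, i.e. $\mathbb{E}_\pi[X]$. The plain sample average gives the expected value of $x$ under $b$, not under $\pi$. To correct for this, rewrite the expectation (assuming $b(x) > 0$ wherever $\pi(x) > 0$):
$$\mathbb{E}_\pi[X] = \sum_x x\,\pi(x) = \sum_x x\,\pi(x)\,\frac{b(x)}{b(x)} = \sum_x x\,\rho(x)\,b(x)$$
where $\rho(x) = \frac{\pi(x)}{b(x)}$ is called the importance sampling ratio.
Let $x\,\rho(x)$ be a new random variable. Then $\mathbb{E}_\pi[X] = \sum_x x\,\rho(x)\,b(x) = \mathbb{E}_b[X\rho(X)]$. Now we have an expectation under $b$ instead of $\pi$.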
How do we estimate expectation from the data?
Compute the weighted sample average, with the importance sampling ratios as the weights:
$$\mathbb{E}_\pi[X] \approx \frac{1}{n}\sum_{i=1}^{n} x_i\,\rho(x_i), \qquad x_i \sim b$$
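The estimate can be checked numerically with a small sketch. The three-outcome distributions `b` and `pi` below are made-up values chosen for illustration; samples are drawn from `b` only, and the ratio-weighted average is compared with the exact expectation under `pi`.

```python
import numpy as np

rng = np.random.default_rng(0)

values = np.array([1.0, 2.0, 3.0])   # possible outcomes x
b = np.array([0.5, 0.3, 0.2])        # behavior distribution b(x) (assumed)
pi = np.array([0.1, 0.2, 0.7])       # target distribution pi(x) (assumed)

# Draw samples from b only; pi is never sampled from.
idx = rng.choice(len(values), size=100_000, p=b)
x = values[idx]
rho = pi[idx] / b[idx]               # importance sampling ratio of each sample

naive_estimate = x.mean()            # plain average: estimates E_b[X]
is_estimate = (rho * x).mean()       # ratio-weighted average: estimates E_pi[X]

true_under_b = float(values @ b)     # exact E_b[X]
true_under_pi = float(values @ pi)   # exact E_pi[X]
print(naive_estimate, true_under_b)
print(is_estimate, true_under_pi)
```

The plain average should land near the expectation under `b`, while the weighted average should land near the expectation under `pi`, even though no sample was ever drawn from `pi`.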
Off Policy Monte Carlo Prediction with Importance Sampling
In Monte Carlo prediction, we estimate the value of each state by computing a sample average over returns starting from that state:
$$v_\pi(s) = \mathbb{E}_\pi[G_t \mid S_t = s]$$
In off-policy learning, we are trying to estimate the value under the target policy $\pi$ using returns obtained by following the behavior policy $b$. We need to find the value of $\rho$ for each of the sampled returns, so that $v_\pi(s) = \mathbb{E}_b[\rho\,G_t \mid S_t = s]$. Here $\rho$ is the probability of the trajectory under $\pi$ divided by the probability of the trajectory under $b$.
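The probability of a trajectory under a policy $\pi$ can be given by:
$$\Pr\{A_t, S_{t+1}, A_{t+1}, \dots, S_T \mid S_t, A_{t:T-1} \sim \pi\}$$
where all actions are taken under the policy $\pi$.
As this is a Markov process, we can split the probability terms into:
$$\pi(A_t \mid S_t)\,p(S_{t+1} \mid S_t, A_t)\,\pi(A_{t+1} \mid S_{t+1}) \cdots p(S_T \mid S_{T-1}, A_{T-1}) = \prod_{k=t}^{T-1} \pi(A_k \mid S_k)\,p(S_{k+1} \mid S_k, A_k)$$
The same expression holds for $b$. The transition probabilities $p$ are identical in both and cancel in the ratio, leaving:
$$\rho_{t:T-1} = \prod_{k=t}^{T-1} \frac{\pi(A_k \mid S_k)}{b(A_k \mid S_k)}$$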
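Because the environment dynamics cancel, the ratio for a trajectory needs only the two policies. The sketch below computes it as a running product of per-step ratios. The function name `importance_sampling_ratio` and the two toy policies are assumptions for illustration; each policy is a callable that maps a state to a vector of action probabilities, the same convention the training code in this article uses.

```python
import numpy as np

def importance_sampling_ratio(states, actions, target_probs, behavior_probs):
    """Product over the trajectory of pi(A_k|S_k) / b(A_k|S_k) (illustrative helper)."""
    rho = 1.0
    for s, a in zip(states, actions):
        rho *= target_probs(s)[a] / behavior_probs(s)[a]
    return rho

# Toy policies over two actions; both ignore the state for simplicity.
target = lambda s: np.array([0.9, 0.1])
behavior = lambda s: np.array([0.5, 0.5])

# A hypothetical three-step trajectory.
states = [0, 1, 2]
actions = [0, 0, 1]
rho = importance_sampling_ratio(states, actions, target, behavior)
print(rho)  # (0.9/0.5) * (0.9/0.5) * (0.1/0.5)
```

Note that the ratio can grow or shrink quickly with trajectory length, which is one reason the weighted form of importance sampling is often preferred in practice.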
Incremental Implementation of Off-policy MC
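Suppose we have a sequence of returns $G_1, G_2, \dots, G_{n-1}$, all starting in the same state and with corresponding weights $W_i$. Then the weighted average of returns is:
$$V_n = \frac{\sum_{k=1}^{n-1} W_k G_k}{\sum_{k=1}^{n-1} W_k}, \qquad n \ge 2$$
and the cumulative sum of weights is $C_{n+1} = C_n + W_{n+1}$, where $C_0 = 0$.
The update rule is $V_{n+1} = V_n + \frac{W_n}{C_n}\,[G_n - V_n]$ for $n \ge 1$, and $V_1$ is arbitrary.
Here $W_i$ will be the importance sampling weight, i.e. the ratio $\rho$ for the trajectory that produced $G_i$.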
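The update rule can be verified against the direct weighted average with a short sketch. The function name `weighted_average_incremental` and the sample returns and weights are assumptions chosen for illustration; skipping zero weights is an added guard against dividing by a zero cumulative weight.

```python
def weighted_average_incremental(returns, weights, v_init=0.0):
    """Incrementally compute sum(W_k * G_k) / sum(W_k) (illustrative helper)."""
    v = v_init   # current estimate V_n
    c = 0.0      # cumulative sum of weights C_n
    for g, w in zip(returns, weights):
        if w == 0.0:
            continue              # a zero weight leaves the estimate unchanged
        c += w                    # C_n = C_{n-1} + W_n
        v += (w / c) * (g - v)    # V_{n+1} = V_n + (W_n / C_n) * (G_n - V_n)
    return v

# Hypothetical returns and importance sampling weights.
returns = [1.0, -1.0, 1.0, 0.0]
weights = [2.0, 1.0, 4.0, 1.0]

incremental = weighted_average_incremental(returns, weights)
batch = sum(w * g for w, g in zip(weights, returns)) / sum(weights)
print(incremental, batch)
```

The first update uses $W_1 / C_1 = 1$, so the initial value is overwritten by the first return, which is why the starting estimate can be arbitrary.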
Greedy policy
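When the target policy is greedy, it is deterministic, and $\pi(A_t \mid S_t) = 1$ whenever the action taken is the greedy one. So:
$$\rho_{t:T-1} = \prod_{k=t}^{T-1} \frac{1}{b(A_k \mid S_k)}$$
and $W$ can be updated with $W \leftarrow W \cdot \frac{1}{b(A_t \mid S_t)}$ for each time step in the trajectory. If the action taken is not the greedy one, then $\pi(A_t \mid S_t) = 0$ and the weight of every earlier step is zero, so processing of that episode can stop.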
Off-Policy MC Control with Weighted Importance Sampling in Python
BlackJack Environment
import sys
import gym
import numpy as np
from collections import defaultdict
env = gym.make('Blackjack-v0')
print(vars(env), end='\n\n')
print(dir(env))
{'action_space': Discrete(2), 'observation_space': Tuple(Discrete(32), Discrete(11), Discrete(2)), 'natural': False, ...}
print(env.observation_space)
print(env.action_space)
Tuple(Discrete(32), Discrete(11), Discrete(2))
Discrete(2)
random_state = env.reset()
print('Random State', random_state)
random_action = env.action_space.sample()
print('Random Action', random_action)
Random State (18, 2, False)
Random Action 1
Generate Random Episodes
num_episodes = 3
for i in range(num_episodes):
    print('Episode : ', i+1)
    state = env.reset()
    step = 0
    while True:
        step += 1
        # Pick a uniformly random action (0 = stick, 1 = hit).
        action = env.action_space.sample()
        print('Step = {}\t State = {}\t Action Taken = {}'.format(step, state, action))
        state, reward, done, info = env.step(action)
        if done:
            print('Game Ended...')
            if reward > 0: print('Agent Won!\n')
            elif reward == 0: print('Draw!\n')
            else: print('Agent Lost!\n')
            break
Training
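def random_policy(nA):
    # Behavior policy: uniform random over nA actions, regardless of the observation.
    A = np.ones(nA, dtype=float) / nA
    def policy_fn(observation):
        return A
    return policy_fn

def greedy_policy(Q):
    # Target policy: probability 1 on the action with the highest Q-value in each state.
    def policy_fn(state):
        A = np.zeros_like(Q[state], dtype=float)
        best_action = np.argmax(Q[state])
        A[best_action] = 1.0
        return A
    return policy_fn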
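def mc_off_policy(env, num_episodes, behavior_policy, max_time=100, discount_factor=1.0):
    # Q[state][action]: action-value estimates.
    # C[state][action]: cumulative sum of importance sampling weights.
    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    C = defaultdict(lambda: np.zeros(env.action_space.n))
    # The target policy is greedy with respect to Q, so it improves as Q is updated.
    target_policy = greedy_policy(Q)

    for i_episode in range(1, num_episodes + 1):
        if i_episode % 1000 == 0:
            print("\rEpisode {}/{}.".format(i_episode, num_episodes), end="")
            sys.stdout.flush()

        # Generate an episode by following the behavior policy.
        episode = []
        state = env.reset()
        for t in range(max_time):
            probs = behavior_policy(state)
            action = np.random.choice(np.arange(len(probs)), p=probs)
            next_state, reward, done, _ = env.step(action)
            episode.append((state, action, reward))
            if done:
                break
            state = next_state

        G = 0.0  # discounted return from time t onward
        W = 1.0  # importance sampling weight
        # Walk the episode backwards and apply the weighted incremental update.
        for t in range(len(episode))[::-1]:
            state, action, reward = episode[t]
            G = discount_factor * G + reward
            C[state][action] += W
            Q[state][action] += (W / C[state][action]) * (G - Q[state][action])
            # A non-greedy action has probability 0 under the target policy,
            # so all earlier weights would be 0 and the rest of the episode is skipped.
            if action != np.argmax(target_policy(state)):
                break
            W = W * 1. / behavior_policy(state)[action]
    return Q, target_policy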
# Use a separate name so the random_policy factory function is not overwritten.
behavior_policy = random_policy(env.action_space.n)
Q, policy = mc_off_policy(env, num_episodes=500000, behavior_policy=behavior_policy)
Plot
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
def plot_policy(policy):

    def get_Z(x, y, usable_ace):
        if (x, y, usable_ace) in policy:
            return policy[x, y, usable_ace]
        else:
            return 1

    def get_figure(usable_ace, ax):
        x_range = np.arange(11, 22)
        y_range = np.arange(10, 0, -1)
        X, Y = np.meshgrid(x_range, y_range)
        Z = np.array([[get_Z(x, y, usable_ace) for x in x_range] for y in y_range])
        surf = ax.imshow(Z, cmap=plt.get_cmap('Pastel2', 2), vmin=0, vmax=1, extent=[10.5, 21.5, 0.5, 10.5])
        plt.xticks(x_range)
        plt.yticks(y_range)
        plt.gca().invert_yaxis()
        ax.set_xlabel('Player\'s Current Sum')
        ax.set_ylabel('Dealer\'s Showing Card')
        ax.grid(color='w', linestyle='-', linewidth=1)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.1)
        cbar = plt.colorbar(surf, ticks=[0, 1], cax=cax)
        cbar.ax.set_yticklabels(['0 (STICK)', '1 (HIT)'])

    fig = plt.figure(figsize=(15, 15))
    ax = fig.add_subplot(121)
    ax.set_title('Usable Ace')
    get_figure(True, ax)
    ax = fig.add_subplot(122)
    ax.set_title('No Usable Ace')
    get_figure(False, ax)
    plt.show()
policy = dict((k,np.argmax(v)) for k, v in Q.items())
plot_policy(policy)