Commit 6186bb3 ("first commit", 0 parents)

31 files changed: +638, -0 lines

.gitignore (+26 lines)

##### FILE TYPES #####

*.tar
*.pt
*.DS_Store
*.pyc
*.mp4
*.zip

##### DIRECTORIES #####

./pong-v0
./breakout-v0
./spaceinvaders-v0

./overfit-pong-v0
./overfit-breakout-v0
./overfit-spaceinvaders-v0

__pycache__
.ipynb_checkpoints/*


##### NAMES #####

*ubyte

README.md (+51 lines)

Visualizing and Understanding Atari Agents
=======
Sam Greydanus. October 2017. MIT License.

Oregon State University College of Engineering. [Explainable AI Project](http://twitter.com/DARPA/status/872547502616182785). Supported by DARPA.

_Written in PyTorch_

Strong agents
--------

![breakout-tunneling.gif](static/breakout-tunneling.gif)
![pong-killshot.gif](static/pong-killshot.gif)
![spaceinv-aiming.gif](static/spaceinv-aiming.gif)

Overfit agents
--------
* WITHOUT saliency:
  * overfit agent: https://youtu.be/TgTpF-EXPwc
  * control agent: https://youtu.be/i3Br2PzE49I
* WITH saliency:
  * overfit agent: https://youtu.be/eeXLUI73RTo
  * control agent: https://youtu.be/xXGC6CQW97E

Learning
--------
![breakout-learning-2000.gif](static/breakout-learning-2000.gif)

About
--------
Code for the results in the paper [Visualizing and Understanding Atari Agents](https://arxiv.org/).

For a quick comparison of Jacobian saliency with our perturbation-based saliency, see [this Jupyter notebook](https://nbviewer.jupyter.org/github/greydanus/visualize_atari/blob/master/jacobian-vs-perturbation.ipynb).

**Abstract.** Deep reinforcement learning (deep RL) agents have achieved remarkable success in a broad range of game-playing and continuous control tasks. While these agents are effective at maximizing rewards, it is often unclear what strategies they use to do so. In this paper, we take a step toward explaining deep RL agents through a case study in three Atari 2600 environments. In particular, we focus on understanding agents in terms of their visual attentional patterns during decision making. To this end, we introduce a method for generating rich saliency maps and use it to explain 1) what strong agents attend to, 2) whether agents are making decisions for the right or wrong reasons, and 3) how agents evolve during the learning phase. We also test our method on non-expert human subjects and find that it improves their ability to reason about these agents. Our techniques are general and, though we focus on Atari, our long-term objective is to produce tools that explain any deep RL policy.

Pretrained models
--------
Trained models were obtained using [this repo](https://github.com/greydanus/baby-a3c) (default hyperparameters).
1. Download from [https://goo.gl/fqwJDB](https://goo.gl/fqwJDB)
2. Unzip the file in this directory.

Dependencies
--------
All code is written in Python 3.6. You will need:

* NumPy
* SciPy
* Matplotlib
* [PyTorch 0.2](http://pytorch.org/): easier to write and debug than TensorFlow :)
* [Jupyter](https://jupyter.org/)
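
Quick start: a minimal sketch of how the pieces above fit together, assuming the pretrained checkpoints are unzipped into `./breakout-v0/`, an existing `./movies/` output directory, ffmpeg on the PATH, and the repo importable as the `visualize_atari` package.

```python
# Minimal quick-start sketch; the paths and package name are assumptions based on
# the code in this commit (checkpoints in ./breakout-v0/, output in ./movies/).
from visualize_atari import make_movie

make_movie('Breakout-v0', checkpoint='*.tar', num_frames=20, first_frame=150,
           resolution=75, save_dir='./movies/', density=5, radius=5, prefix='demo')
```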

__init__.py (+5 lines)

from .saliency import *
from .rollout import *
from .make_movie import *
from .policy import *
from .overfit_atari import *

jacobian-vs-perturbation.ipynb (+246 lines)

(Large notebook diff; not rendered by default.)
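
The notebook's contents are not rendered in this diff. For orientation, the Jacobian baseline it compares against amounts to backpropagating the chosen action's logit to the input pixels. Below is a minimal sketch of that idea, not the notebook's own code, reusing `prepro`, `NNPolicy`, and the `history` dict defined in the modules later in this commit.

```python
# Illustrative sketch of Jacobian (gradient) saliency; an assumption about the
# baseline, not code from the notebook. `model` is an NNPolicy, `history` comes
# from rollout(), and `prepro` is the 80x80 preprocessing lambda from saliency.py.
import torch
import numpy as np
from torch.autograd import Variable

def jacobian_saliency(model, history, ix):
    state = Variable(torch.Tensor(prepro(history['ins'][ix])).view(1, 1, 80, 80),
                     requires_grad=True)
    hx = Variable(torch.Tensor(history['hx'][ix - 1]).view(1, -1))
    cx = Variable(torch.Tensor(history['cx'][ix - 1]).view(1, -1))
    _, logit, _ = model((state, (hx, cx)))
    action = int(logit.data.max(1)[1][0])         # greedy action at this frame
    logit[0, action].backward()                   # d(logit_action) / d(input pixels)
    return np.abs(state.grad.data.numpy()[0, 0])  # 80x80 gradient-magnitude map
```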

make_movie.py (+78 lines)

# Visualizing and Understanding Atari Agents | Sam Greydanus | 2017 | MIT License

from __future__ import print_function
import warnings ; warnings.filterwarnings('ignore') # mute warnings, live dangerously

import matplotlib as mpl ; mpl.use("Agg") # select a non-interactive backend *before* importing pyplot
import matplotlib.pyplot as plt
import matplotlib.animation as manimation

import gym, os, sys, time, argparse

sys.path.append('..')
from visualize_atari import *

def make_movie(env_name, checkpoint='*.tar', num_frames=20, first_frame=0, resolution=75, \
        save_dir='./movies/', density=5, radius=5, prefix='default', overfit_mode=False):

    # set up dir variables and environment
    load_dir = '{}{}/'.format('overfit-' if overfit_mode else '', env_name.lower())
    meta = get_env_meta(env_name)
    env = gym.make(env_name) if not overfit_mode else OverfitAtari(env_name, load_dir+'expert/', seed=0) # make a seeded env

    # set up agent
    model = NNPolicy(channels=1, num_actions=env.action_space.n)
    model.try_load(load_dir, checkpoint=checkpoint)

    # get a rollout of the policy
    movie_title = "{}-{}-{}.mp4".format(prefix, num_frames, env_name.lower())
    print('\tmaking movie "{}" using checkpoint at {}{}'.format(movie_title, load_dir, checkpoint))
    max_ep_len = first_frame + num_frames + 1
    torch.manual_seed(0)
    history = rollout(model, env, max_ep_len=max_ep_len)
    print()

    # make the movie!
    start = time.time()
    FFMpegWriter = manimation.writers['ffmpeg']
    metadata = dict(title=movie_title, artist='greydanus', comment='atari-saliency-video')
    writer = FFMpegWriter(fps=8, metadata=metadata)

    prog = '' ; total_frames = len(history['ins'])
    f = plt.figure(figsize=[6, 6*1.3], dpi=resolution)
    with writer.saving(f, save_dir + movie_title, resolution):
        for i in range(num_frames):
            ix = first_frame+i
            if ix < total_frames: # prevent the loop from trying to process a frame ix greater than the rollout length
                frame = history['ins'][ix].squeeze().copy()
                actor_saliency = score_frame(model, history, ix, radius, density, interp_func=occlude, mode='actor')
                critic_saliency = score_frame(model, history, ix, radius, density, interp_func=occlude, mode='critic')

                frame = saliency_on_atari_frame(actor_saliency, frame, fudge_factor=meta['actor_ff'], channel=2)
                frame = saliency_on_atari_frame(critic_saliency, frame, fudge_factor=meta['critic_ff'], channel=0)

                plt.imshow(frame) ; plt.title(env_name.lower(), fontsize=15)
                writer.grab_frame() ; f.clear()

                tstr = time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start))
                print('\ttime: {} | progress: {:.1f}%'.format(tstr, 100*i/min(num_frames, total_frames)), end='\r')
    print('\nfinished.')

# the make_movie function can also be imported and called from another script; this block handles command-line use
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description=None)
    parser.add_argument('-e', '--env', default='Breakout-v0', type=str, help='gym environment')
    parser.add_argument('-d', '--density', default=5, type=int, help='density of grid of gaussian blurs')
    parser.add_argument('-r', '--radius', default=5, type=int, help='radius of gaussian blur')
    parser.add_argument('-f', '--num_frames', default=20, type=int, help='number of frames in movie')
    parser.add_argument('-i', '--first_frame', default=150, type=int, help='index of first frame')
    parser.add_argument('-dpi', '--resolution', default=75, type=int, help='resolution (dpi)')
    parser.add_argument('-s', '--save_dir', default='./movies/', type=str, help='dir to save agent logs and checkpoints')
    parser.add_argument('-p', '--prefix', default='default', type=str, help='prefix to help make video name unique')
    parser.add_argument('-c', '--checkpoint', default='*.tar', type=str, help='checkpoint name (in case there is more than one)')
    parser.add_argument('-o', '--overfit_mode', action='store_true', help='analyze an overfit environment (see paper)') # store_true avoids the type=bool pitfall where any non-empty string parses as True
    args = parser.parse_args()

    make_movie(args.env, args.checkpoint, args.num_frames, args.first_frame, args.resolution,
               args.save_dir, args.density, args.radius, args.prefix, args.overfit_mode)
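
For the overfit-agent videos linked in the README, the same entry point runs with `overfit_mode=True`. A sketch of that call, assuming the overfit checkpoints were unzipped into `./overfit-breakout-v0/` with an `expert/` subdirectory (the layout that `load_dir + 'expert/'` above expects):

```python
# Sketch only: reproduce an overfit-agent saliency movie. The directory layout
# (./overfit-breakout-v0/ with *.tar checkpoints and an expert/ subfolder) is an
# assumption based on load_dir above; adjust to wherever the files actually live.
make_movie('Breakout-v0', num_frames=20, first_frame=150,
           prefix='overfit', overfit_mode=True)
```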

overfit_atari.py (+49 lines)

# Visualizing and Understanding Atari Agents | Sam Greydanus | 2017 | MIT License

from __future__ import print_function
import warnings ; warnings.filterwarnings('ignore') # mute warnings, live dangerously ;)

import torch
from torch.autograd import Variable
import torch.nn.functional as F

import gym, sys
import numpy as np
from scipy.misc import imresize # preserves single-pixel info _unlike_ img = img[::2,::2]

sys.path.append('..')
from visualize_atari import *

prepro = lambda img: imresize(img[35:195].mean(2), (80,80)).astype(np.float32).reshape(1,80,80)/255.

class OverfitAtari():
    def __init__(self, env_name, expert_dir, seed=0):
        self.atari = gym.make(env_name) ; self.atari.seed(seed)
        self.action_space = self.atari.action_space
        self.expert = NNPolicy(channels=1, num_actions=self.action_space.n)
        self.expert.try_load(expert_dir)
        self.cx = Variable(torch.zeros(1, 256)) # lstm memory vector
        self.hx = Variable(torch.zeros(1, 256)) # lstm activation vector

    def seed(self, s):
        self.atari.seed(s) ; torch.manual_seed(s)

    def reset(self):
        self.cx = Variable(torch.zeros(1, 256))
        self.hx = Variable(torch.zeros(1, 256))
        return self.atari.reset()

    def step(self, action):
        state, reward, done, info = self.atari.step(action)

        expert_state = torch.Tensor(prepro(state)) # get the expert policy and incorporate it into the environment
        _, logit, (hx, cx) = self.expert((Variable(expert_state.view(1,1,80,80)), (self.hx, self.cx)))
        self.hx, self.cx = Variable(hx.data), Variable(cx.data)

        expert_action = int(F.softmax(logit).data.max(1)[1][0,0])
        target = torch.zeros(logit.size()) ; target[0,expert_action] = 1 # one-hot expert policy, returned in place of gym's `info`
        j = 72 ; k = 5 # position and spacing of the on-screen action marker
        expert_action = expert_action if False else np.random.randint(self.atari.action_space.n) # NOTE: the displayed marker is driven by a *random* action; flip the `if False` toggle to display the expert's action instead
        for i in range(self.atari.action_space.n):
            state[37:41, j + k*i: j+1+k*i,:] = 250 if expert_action == i else 50 # draw the marker into the frame
        return state, reward, done, target
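
A small usage sketch, with an assumed checkpoint path, showing the non-standard return value: `step` hands back the expert's one-hot policy where a gym env would normally return `info`.

```python
# Sketch, assuming expert weights in ./overfit-breakout-v0/expert/ (path is an assumption).
env = OverfitAtari('Breakout-v0', './overfit-breakout-v0/expert/', seed=0)
obs = env.reset()
for _ in range(5):
    obs, reward, done, expert_target = env.step(env.action_space.sample())
    # expert_target is a 1 x num_actions tensor, one-hot on the expert's chosen action
    print(expert_target.numpy())
```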

policy.py (+42 lines)

# Visualizing and Understanding Atari Agents | Sam Greydanus | 2017 | MIT License

from __future__ import print_function
import warnings ; warnings.filterwarnings('ignore') # mute warnings, live dangerously ;)

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn

import glob
import numpy as np
from scipy.misc import imresize # preserves single-pixel info _unlike_ img = img[::2,::2]

class NNPolicy(torch.nn.Module): # an actor-critic neural network
    def __init__(self, channels, num_actions):
        super(NNPolicy, self).__init__()
        self.conv1 = nn.Conv2d(channels, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.lstm = nn.LSTMCell(32 * 5 * 5, 256)
        self.critic_linear, self.actor_linear = nn.Linear(256, 1), nn.Linear(256, num_actions)

    def forward(self, inputs):
        inputs, (hx, cx) = inputs
        x = F.elu(self.conv1(inputs))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        x = F.elu(self.conv4(x))
        x = x.view(-1, 32 * 5 * 5)
        hx, cx = self.lstm(x, (hx, cx))
        return self.critic_linear(hx), self.actor_linear(hx), (hx, cx)

    def try_load(self, save_dir, checkpoint='*.tar'):
        paths = glob.glob(save_dir + checkpoint) ; step = 0
        if len(paths) > 0:
            ckpts = [int(s.split('.')[-2]) for s in paths]
            ix = np.argmax(ckpts) ; step = ckpts[ix]
            self.load_state_dict(torch.load(paths[ix]))
        print("\tno saved models") if step == 0 else print("\tloaded model: {}".format(paths[ix])) # use == rather than `is` for integer comparison
        return step
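
For reference, a minimal sketch of a single forward pass: four stride-2 convolutions shrink an 80x80 input to 5x5, which matches the `32 * 5 * 5` LSTM input size.

```python
# Sketch: one forward step through the actor-critic on a blank 80x80 frame.
import torch
from torch.autograd import Variable

model = NNPolicy(channels=1, num_actions=4)  # Breakout-v0 has 4 discrete actions
state = Variable(torch.zeros(1, 1, 80, 80))
hx, cx = Variable(torch.zeros(1, 256)), Variable(torch.zeros(1, 256))
value, logit, (hx, cx) = model((state, (hx, cx)))
print(value.size(), logit.size())  # torch.Size([1, 1]) torch.Size([1, 4])
```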

rollout.py (+43 lines)

# Visualizing and Understanding Atari Agents | Sam Greydanus | 2017 | MIT License

from __future__ import print_function
import warnings ; warnings.filterwarnings('ignore') # mute warnings, live dangerously ;)

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

import numpy as np
from scipy.misc import imresize # preserves single-pixel info _unlike_ img = img[::2,::2]

prepro = lambda img: imresize(img[35:195].mean(2), (80,80)).astype(np.float32).reshape(1,80,80)/255.

def rollout(model, env, max_ep_len=3e3, render=False):
    history = {'ins': [], 'logits': [], 'values': [], 'outs': [], 'hx': [], 'cx': []}

    state = torch.Tensor(prepro(env.reset())) # get first state
    episode_length, epr, eploss, done = 0, 0, 0, False # bookkeeping
    hx, cx = Variable(torch.zeros(1, 256)), Variable(torch.zeros(1, 256))

    while not done and episode_length <= max_ep_len:
        episode_length += 1
        value, logit, (hx, cx) = model((Variable(state.view(1,1,80,80)), (hx, cx)))
        hx, cx = Variable(hx.data), Variable(cx.data)
        prob = F.softmax(logit)

        action = prob.max(1)[1].data # greedy action; use prob.multinomial().data[0] to sample instead
        obs, reward, done, expert_policy = env.step(action.numpy()[0])
        if render: env.render()
        state = torch.Tensor(prepro(obs)) ; epr += reward

        # save info!
        history['ins'].append(obs)
        history['hx'].append(hx.squeeze(0).data.numpy())
        history['cx'].append(cx.squeeze(0).data.numpy())
        history['logits'].append(logit.data.numpy()[0])
        history['values'].append(value.data.numpy()[0])
        history['outs'].append(prob.data.numpy()[0])
        print('\tstep # {}, reward {:.0f}'.format(episode_length, epr), end='\r')

    return history
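
A usage sketch for collecting the `history` dict that the saliency code below consumes, assuming a checkpoint directory laid out as in the README.

```python
# Sketch, assuming pretrained checkpoints in ./breakout-v0/ (path is an assumption).
import gym, torch

env = gym.make('Breakout-v0') ; env.seed(0)
model = NNPolicy(channels=1, num_actions=env.action_space.n)
model.try_load('./breakout-v0/')
torch.manual_seed(0)
history = rollout(model, env, max_ep_len=200)
print(len(history['ins']), history['outs'][0].shape)  # frames collected, action-prob shape
```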

saliency.py (+76 lines)

# Visualizing and Understanding Atari Agents | Sam Greydanus | 2017 | MIT License

from __future__ import print_function
import warnings ; warnings.filterwarnings('ignore') # mute warnings, live dangerously ;)

import torch
from torch.autograd import Variable
import torch.nn.functional as F

import numpy as np
from scipy.ndimage.filters import gaussian_filter
from scipy.misc import imresize # preserves single-pixel info _unlike_ img = img[::2,::2]

prepro = lambda img: imresize(img[35:195].mean(2), (80,80)).astype(np.float32).reshape(1,80,80)/255.
searchlight = lambda I, mask: I*mask + gaussian_filter(I, sigma=3)*(1-mask) # choose an area NOT to blur
occlude = lambda I, mask: I*(1-mask) + gaussian_filter(I, sigma=3)*mask # choose an area to blur

def get_mask(center, size, r):
    y,x = np.ogrid[-center[0]:size[0]-center[0], -center[1]:size[1]-center[1]]
    keep = x*x + y*y <= 1
    mask = np.zeros(size) ; mask[keep] = 1 # select a small circle of pixels at the center
    mask = gaussian_filter(mask, sigma=r) # blur it into an approximately 2D-Gaussian bump of std r
    return mask/mask.max()

def run_through_model(model, history, ix, interp_func=None, mask=None, blur_memory=None, mode='actor'):
    if mask is None:
        im = prepro(history['ins'][ix])
    else:
        assert interp_func is not None, "interp func cannot be none"
        im = interp_func(prepro(history['ins'][ix]).squeeze(), mask).reshape(1,80,80) # perturb input I -> I'
    tens_state = torch.Tensor(im)
    state = Variable(tens_state.unsqueeze(0), volatile=True)
    hx = Variable(torch.Tensor(history['hx'][ix-1]).view(1,-1))
    cx = Variable(torch.Tensor(history['cx'][ix-1]).view(1,-1))
    if blur_memory is not None: cx.mul_(1-blur_memory) # perturb memory vector
    return model((state, (hx, cx)))[0] if mode == 'critic' else model((state, (hx, cx)))[1]

def score_frame(model, history, ix, r, d, interp_func, mode='actor'):
    # r: radius of blur
    # d: density of scores (if d==1, get a score for every pixel;
    # if d==2, every other pixel, i.e. 25% of the pixels of a 2D image)
    assert mode in ['actor', 'critic'], 'mode must be either "actor" or "critic"'
    L = run_through_model(model, history, ix, interp_func, mask=None, mode=mode)
    scores = np.zeros((int(80/d)+1,int(80/d)+1)) # saliency scores S(t,i,j)
    for i in range(0,80,d):
        for j in range(0,80,d):
            mask = get_mask(center=[i,j], size=[80,80], r=r)
            l = run_through_model(model, history, ix, interp_func, mask=mask, mode=mode)
            scores[int(i/d),int(j/d)] = (L-l).pow(2).sum().mul_(.5).data[0]
    pmax = scores.max()
    scores = imresize(scores, size=[80,80], interp='bilinear').astype(np.float32)
    return pmax * scores / scores.max()

def saliency_on_atari_frame(saliency, atari, fudge_factor, channel=2, sigma=0):
    # sometimes saliency maps are a bit clearer if you blur them
    # slightly...sigma adjusts the radius of that blur
    pmax = saliency.max()
    S = imresize(saliency, size=[160,160], interp='bilinear').astype(np.float32)
    S = S if sigma == 0 else gaussian_filter(S, sigma=sigma)
    S -= S.min() ; S = fudge_factor*pmax * S / S.max()
    I = atari.astype('uint16')
    I[35:195,:,channel] += S.astype('uint16') # add the saliency heatmap to one color channel of the game area
    I = I.clip(1,255).astype('uint8')
    return I

def get_env_meta(env_name):
    meta = {}
    if env_name=="Pong-v0":
        meta['critic_ff'] = 600 ; meta['actor_ff'] = 500
    elif env_name=="Breakout-v0":
        meta['critic_ff'] = 600 ; meta['actor_ff'] = 300
    elif env_name=="SpaceInvaders-v0":
        meta['critic_ff'] = 400 ; meta['actor_ff'] = 400
    else:
        print('environment "{}" not supported'.format(env_name))
    return meta
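
Putting the pieces together, a minimal sketch that scores one frame of a rollout and overlays the actor saliency on the raw Atari image, mirroring what make_movie.py does per frame; `model` and `history` are assumed to come from the rollout example above.

```python
# Sketch: perturbation saliency for a single frame; assumes ix is within the
# rollout and ix > 0 so the stored LSTM state from the previous step exists.
ix = 50
meta = get_env_meta('Breakout-v0')
actor_sal = score_frame(model, history, ix, r=5, d=5, interp_func=occlude, mode='actor')
frame = history['ins'][ix].squeeze().copy()
frame = saliency_on_atari_frame(actor_sal, frame, fudge_factor=meta['actor_ff'], channel=2)
print(frame.shape)  # (210, 160, 3) RGB frame, blue channel carrying the actor saliency
```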
