Commit 26c07ad
Update QRDQN and C51
1 parent 157348b

6 files changed: +64 -240 lines

deep_rl/agent/QuantileRegressionDQN_agent.py

Lines changed: 35 additions & 73 deletions
@@ -8,34 +8,19 @@
 from ..component import *
 from ..utils import *
 from .BaseAgent import *
+from .DQN_agent import *


-class QuantileRegressionDQNActor(BaseActor):
+class QuantileRegressionDQNActor(DQNActor):
     def __init__(self, config):
-        BaseActor.__init__(self, config)
-        self.config = config
-        self.start()
-
-    def _transition(self):
-        if self._state is None:
-            self._state = self._task.reset()
-        config = self.config
-        with config.lock:
-            q_values = self._network(config.state_normalizer(self._state)).mean(-1)
-        q_values = to_np(q_values).flatten()
-        if self._total_steps < config.exploration_steps \
-                or np.random.rand() < config.random_action_prob():
-            action = np.random.randint(0, len(q_values))
-        else:
-            action = np.argmax(q_values)
-        next_state, reward, done, info = self._task.step([action])
-        entry = [self._state[0], action, reward[0], next_state[0], int(done[0]), info]
-        self._total_steps += 1
-        self._state = next_state
-        return entry
-
-
-class QuantileRegressionDQNAgent(BaseAgent):
+        super().__init__(config)
+
+    def compute_q(self, prediction):
+        q_values = prediction['quantile'].mean(-1)
+        return to_np(q_values)
+
+
+class QuantileRegressionDQNAgent(DQNAgent):
     def __init__(self, config):
         BaseAgent.__init__(self, config)
         self.config = config
@@ -53,63 +38,40 @@ def __init__(self, config):
         self.actor.set_network(self.network)

         self.total_steps = 0
-        self.batch_indices = range_tensor(self.replay.batch_size)
+        self.batch_indices = range_tensor(config.batch_size)

         self.quantile_weight = 1.0 / self.config.num_quantiles
         self.cumulative_density = tensor(
             (2 * np.arange(self.config.num_quantiles) + 1) / (2.0 * self.config.num_quantiles)).view(1, -1)

-    def close(self):
-        close_obj(self.replay)
-        close_obj(self.actor)
-
     def eval_step(self, state):
         self.config.state_normalizer.set_read_only()
         state = self.config.state_normalizer(state)
-        q = self.network(state).mean(-1)
+        q = self.network(state)['quantile'].mean(-1)
         action = np.argmax(to_np(q).flatten())
         self.config.state_normalizer.unset_read_only()
         return [action]

-    def step(self):
-        config = self.config
-        transitions = self.actor.step()
-        experiences = []
-        for state, action, reward, next_state, done, info in transitions:
-            self.record_online_return(info)
-            self.total_steps += 1
-            reward = config.reward_normalizer(reward)
-            experiences.append([state, action, reward, next_state, done])
-        self.replay.feed_batch(experiences)
-
-        if self.total_steps > self.config.exploration_steps:
-            experiences = self.replay.sample()
-            states, actions, rewards, next_states, terminals = experiences
-            states = self.config.state_normalizer(states)
-            next_states = self.config.state_normalizer(next_states)
-
-            quantiles_next = self.target_network(next_states).detach()
-            a_next = torch.argmax(quantiles_next.sum(-1), dim=-1)
-            quantiles_next = quantiles_next[self.batch_indices, a_next, :]
-
-            rewards = tensor(rewards).unsqueeze(-1)
-            terminals = tensor(terminals).unsqueeze(-1)
-            quantiles_next = rewards + self.config.discount * (1 - terminals) * quantiles_next
-
-            quantiles = self.network(states)
-            actions = tensor(actions).long()
-            quantiles = quantiles[self.batch_indices, actions, :]
-
-            quantiles_next = quantiles_next.t().unsqueeze(-1)
-            diff = quantiles_next - quantiles
-            loss = huber(diff) * (self.cumulative_density - (diff.detach() < 0).float()).abs()
-
-            self.optimizer.zero_grad()
-            loss.mean(0).mean(1).sum().backward()
-            nn.utils.clip_grad_norm_(self.network.parameters(), self.config.gradient_clip)
-            with config.lock:
-                self.optimizer.step()
-
-            if self.total_steps / self.config.sgd_update_frequency % \
-                    self.config.target_network_update_freq == 0:
-                self.target_network.load_state_dict(self.network.state_dict())
+    def compute_loss(self, transitions):
+        states = self.config.state_normalizer(transitions.state)
+        next_states = self.config.state_normalizer(transitions.next_state)
+
+        quantiles_next = self.target_network(next_states)['quantile'].detach()
+        a_next = torch.argmax(quantiles_next.sum(-1), dim=-1)
+        quantiles_next = quantiles_next[self.batch_indices, a_next, :]
+
+        rewards = tensor(transitions.reward).unsqueeze(-1)
+        masks = tensor(transitions.mask).unsqueeze(-1)
+        quantiles_next = rewards + self.config.discount * masks * quantiles_next
+
+        quantiles = self.network(states)['quantile']
+        actions = tensor(transitions.action).long()
+        quantiles = quantiles[self.batch_indices, actions, :]
+
+        quantiles_next = quantiles_next.t().unsqueeze(-1)
+        diff = quantiles_next - quantiles
+        loss = huber(diff) * (self.cumulative_density - (diff.detach() < 0).float()).abs()
+        return loss.sum(-1).mean(1)
+
+    def reduce_loss(self, loss):
+        return loss.mean()
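
For reference, the new compute_loss above is the standard QR-DQN quantile Huber loss: for a TD error u between a target quantile and the prediction at fraction tau, the per-pair term is |tau - 1{u < 0}| * huber(u). The sketch below isolates that pairwise term with illustrative shape comments; the huber helper shown here is an assumption about the repo's utility (elementwise Huber with kappa = 1), not code from this commit.

import torch

def huber(x, k=1.0):
    # Elementwise Huber loss; assumed to match the repo's huber utility with kappa = 1.
    return torch.where(x.abs() < k, 0.5 * x.pow(2), k * (x.abs() - 0.5 * k))

def pairwise_quantile_huber(pred, target, taus):
    # pred:   [batch, N] predicted quantiles for the chosen actions
    # target: [batch, N] Bellman target quantiles (already detached)
    # taus:   [1, N] quantile midpoints (2i + 1) / (2N), i.e. cumulative_density above
    diff = target.t().unsqueeze(-1) - pred  # [N, batch, N]; diff[j, b, i] = target_j - pred_i
    return huber(diff) * (taus - (diff.detach() < 0).float()).abs()  # reduced via .sum(-1).mean(1) above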

deep_rl/agent/Rainbow_agent.py

Lines changed: 0 additions & 155 deletions
This file was deleted.

deep_rl/agent/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -7,4 +7,3 @@
 from .PPO_agent import *
 from .OptionCritic_agent import *
 from .TD3_agent import *
-from .Rainbow_agent import *

deep_rl/network/network_heads.py

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@ def forward(self, x):
         pre_prob = self.fc_categorical(phi).view((-1, self.action_dim, self.num_atoms))
         prob = F.softmax(pre_prob, dim=-1)
         log_prob = F.log_softmax(pre_prob, dim=-1)
-        return prob, log_prob
+        return dict(prob=prob, log_prob=log_prob)


 class RainbowNet(nn.Module, BaseNet):
@@ -99,7 +99,7 @@ def forward(self, x):
         phi = self.body(tensor(x))
         quantiles = self.fc_quantiles(phi)
         quantiles = quantiles.view((-1, self.action_dim, self.num_quantiles))
-        return quantiles
+        return dict(quantile=quantiles)


 class OptionCriticNet(nn.Module, BaseNet):
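
Both heads now return a dict rather than a bare tensor or tuple, so agents pick outputs by key (as compute_q does above with prediction['quantile']). Below is a minimal sketch of how a C51-style caller might consume the CategoricalNet output under this convention; the atom-support handling is an assumption for illustration, not code from this commit.

import torch

def greedy_action_from_categorical(prediction, atoms):
    # prediction: dict returned by CategoricalNet.forward above, with
    #   prediction['prob'] of shape [batch, action_dim, num_atoms]
    # atoms: [num_atoms] tensor of return-support values (e.g. linspace(v_min, v_max, num_atoms))
    q_values = (prediction['prob'] * atoms).sum(-1)  # expected return per action
    return torch.argmax(q_values, dim=-1)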

deep_rl/utils/config.py

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ def __init__(self):
         self.decaying_lr = False
         self.shared_repr = False
         self.noisy_linear = False
+        self.n_step = 1

     @property
     def eval_env(self):
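
The new n_step = 1 default controls the length of the multi-step return used for the TD target (the dqn_pixel(game=game, n_step=1, ...) call in examples.py below passes it through). As a generic illustration only, not code from this commit, an n-step target is accumulated like this:

def n_step_return(rewards, bootstrap_value, discount, mask=1.0):
    # Generic n-step target, assuming no episode boundary inside the reward window:
    #   R_t = r_t + g * r_{t+1} + ... + g^{n-1} * r_{t+n-1} + g^n * mask * V(s_{t+n})
    # rewards: [r_t, ..., r_{t+n-1}]; mask is 0.0 if s_{t+n} is terminal, else 1.0.
    ret = bootstrap_value * mask
    for r in reversed(rewards):
        ret = r + discount * ret
    return ret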

examples.py

Lines changed: 26 additions & 9 deletions
@@ -108,8 +108,11 @@ def quantile_regression_dqn_feature(**kwargs):
     config.optimizer_fn = lambda params: torch.optim.RMSprop(params, 0.001)
     config.network_fn = lambda: QuantileNet(config.action_dim, config.num_quantiles, FCBody(config.state_dim))

-    # config.replay_fn = lambda: Replay(memory_size=int(1e4), batch_size=10)
-    config.replay_fn = lambda: AsyncReplay(memory_size=int(1e4), batch_size=10)
+    config.batch_size = 10
+    replay_kwargs = dict(
+        memory_size=int(1e4),
+        batch_size=config.batch_size)
+    config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs, async=True)

     config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4)
     config.discount = 0.99
@@ -136,8 +139,13 @@ def quantile_regression_dqn_pixel(**kwargs):
     config.network_fn = lambda: QuantileNet(config.action_dim, config.num_quantiles, NatureConvBody())
     config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6)

-    # config.replay_fn = lambda: Replay(memory_size=int(1e6), batch_size=32)
-    config.replay_fn = lambda: AsyncReplay(memory_size=int(1e6), batch_size=32)
+    config.batch_size = 32
+    replay_kwargs = dict(
+        memory_size=int(1e6),
+        batch_size=config.batch_size,
+        history_length=4,
+    )
+    config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs, async=True)

     config.state_normalizer = ImageNormalizer()
     config.reward_normalizer = SignNormalizer()
@@ -164,8 +172,11 @@ def categorical_dqn_feature(**kwargs):
     config.network_fn = lambda: CategoricalNet(config.action_dim, config.categorical_n_atoms, FCBody(config.state_dim))
     config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4)

-    # config.replay_fn = lambda: Replay(memory_size=10000, batch_size=10)
-    config.replay_fn = lambda: AsyncReplay(memory_size=10000, batch_size=10)
+    config.batch_size = 10
+    replay_kwargs = dict(
+        memory_size=int(1e4),
+        batch_size=config.batch_size)
+    config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs, async=True)

     config.discount = 0.99
     config.target_network_update_freq = 200
@@ -193,8 +204,13 @@ def categorical_dqn_pixel(**kwargs):
     config.network_fn = lambda: CategoricalNet(config.action_dim, config.categorical_n_atoms, NatureConvBody())
     config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6)

-    # config.replay_fn = lambda: Replay(memory_size=int(1e6), batch_size=32)
-    config.replay_fn = lambda: AsyncReplay(memory_size=int(1e6), batch_size=32)
+    config.batch_size = 32
+    replay_kwargs = dict(
+        memory_size=int(1e6),
+        batch_size=config.batch_size,
+        history_length=4,
+    )
+    config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs, async=True)

     config.discount = 0.99
     config.state_normalizer = ImageNormalizer()
@@ -605,6 +621,7 @@ def td3_continuous(**kwargs):
     mkdir('tf_log')
     set_one_thread()
     random_seed()
+    # -1 is CPU, a positive integer is the index of GPU
     select_device(-1)
     # select_device(0)

@@ -627,7 +644,7 @@ def td3_continuous(**kwargs):
     game = 'BreakoutNoFrameskip-v4'
     # dqn_pixel(game=game, n_step=1, replay_cls=UniformReplay, async_replay=True)
     # quantile_regression_dqn_pixel(game=game)
-    # categorical_dqn_pixel(game=game)
+    categorical_dqn_pixel(game=game)
     # rainbow_pixel(game=game)
     # a2c_pixel(game=game)
     # n_step_dqn_pixel(game=game)
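
Each example now stores the batch size on the config and reuses it in replay_kwargs, which keeps it consistent with the agent's batch_indices = range_tensor(config.batch_size) above. The config.replay_fn entries remain zero-argument factories, so the buffer is only constructed when the agent starts; a hypothetical stand-in for that pattern (not the repo's actual ReplayWrapper) is sketched below. Note that async became a reserved keyword in Python 3.7, so the keyword argument spelled this way only parses on older interpreters.

def make_replay_fn(replay_cls, replay_kwargs):
    # Hypothetical helper: defer replay construction until the agent calls config.replay_fn().
    return lambda: replay_cls(**replay_kwargs)

# Usage, mirroring the examples above (names assumed):
# config.replay_fn = make_replay_fn(UniformReplay, dict(memory_size=int(1e6), batch_size=32))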
