 from ..component import *
 from ..utils import *
 from .BaseAgent import *
+from .DQN_agent import *


-class QuantileRegressionDQNActor(BaseActor):
+class QuantileRegressionDQNActor(DQNActor):
     def __init__(self, config):
-        BaseActor.__init__(self, config)
-        self.config = config
-        self.start()
-
-    def _transition(self):
-        if self._state is None:
-            self._state = self._task.reset()
-        config = self.config
-        with config.lock:
-            q_values = self._network(config.state_normalizer(self._state)).mean(-1)
-        q_values = to_np(q_values).flatten()
-        if self._total_steps < config.exploration_steps \
-                or np.random.rand() < config.random_action_prob():
-            action = np.random.randint(0, len(q_values))
-        else:
-            action = np.argmax(q_values)
-        next_state, reward, done, info = self._task.step([action])
-        entry = [self._state[0], action, reward[0], next_state[0], int(done[0]), info]
-        self._total_steps += 1
-        self._state = next_state
-        return entry
-
-
-class QuantileRegressionDQNAgent(BaseAgent):
+        super().__init__(config)
+
+    def compute_q(self, prediction):
+        q_values = prediction['quantile'].mean(-1)
+        return to_np(q_values)
+
+
+class QuantileRegressionDQNAgent(DQNAgent):
     def __init__(self, config):
         BaseAgent.__init__(self, config)
         self.config = config
@@ -53,63 +38,40 @@ def __init__(self, config):
         self.actor.set_network(self.network)

         self.total_steps = 0
-        self.batch_indices = range_tensor(self.replay.batch_size)
+        self.batch_indices = range_tensor(config.batch_size)

         self.quantile_weight = 1.0 / self.config.num_quantiles
         self.cumulative_density = tensor(
             (2 * np.arange(self.config.num_quantiles) + 1) / (2.0 * self.config.num_quantiles)).view(1, -1)

-    def close(self):
-        close_obj(self.replay)
-        close_obj(self.actor)
-
     def eval_step(self, state):
         self.config.state_normalizer.set_read_only()
         state = self.config.state_normalizer(state)
-        q = self.network(state).mean(-1)
+        q = self.network(state)['quantile'].mean(-1)
         action = np.argmax(to_np(q).flatten())
         self.config.state_normalizer.unset_read_only()
         return [action]

-    def step(self):
-        config = self.config
-        transitions = self.actor.step()
-        experiences = []
-        for state, action, reward, next_state, done, info in transitions:
-            self.record_online_return(info)
-            self.total_steps += 1
-            reward = config.reward_normalizer(reward)
-            experiences.append([state, action, reward, next_state, done])
-        self.replay.feed_batch(experiences)
-
-        if self.total_steps > self.config.exploration_steps:
-            experiences = self.replay.sample()
-            states, actions, rewards, next_states, terminals = experiences
-            states = self.config.state_normalizer(states)
-            next_states = self.config.state_normalizer(next_states)
-
-            quantiles_next = self.target_network(next_states).detach()
-            a_next = torch.argmax(quantiles_next.sum(-1), dim=-1)
-            quantiles_next = quantiles_next[self.batch_indices, a_next, :]
-
-            rewards = tensor(rewards).unsqueeze(-1)
-            terminals = tensor(terminals).unsqueeze(-1)
-            quantiles_next = rewards + self.config.discount * (1 - terminals) * quantiles_next
-
-            quantiles = self.network(states)
-            actions = tensor(actions).long()
-            quantiles = quantiles[self.batch_indices, actions, :]
-
-            quantiles_next = quantiles_next.t().unsqueeze(-1)
-            diff = quantiles_next - quantiles
-            loss = huber(diff) * (self.cumulative_density - (diff.detach() < 0).float()).abs()
-
-            self.optimizer.zero_grad()
-            loss.mean(0).mean(1).sum().backward()
-            nn.utils.clip_grad_norm_(self.network.parameters(), self.config.gradient_clip)
-            with config.lock:
-                self.optimizer.step()
-
-            if self.total_steps / self.config.sgd_update_frequency % \
-                    self.config.target_network_update_freq == 0:
-                self.target_network.load_state_dict(self.network.state_dict())
+    def compute_loss(self, transitions):
+        states = self.config.state_normalizer(transitions.state)
+        next_states = self.config.state_normalizer(transitions.next_state)
+
+        quantiles_next = self.target_network(next_states)['quantile'].detach()
+        a_next = torch.argmax(quantiles_next.sum(-1), dim=-1)
+        quantiles_next = quantiles_next[self.batch_indices, a_next, :]
+
+        rewards = tensor(transitions.reward).unsqueeze(-1)
+        masks = tensor(transitions.mask).unsqueeze(-1)
+        quantiles_next = rewards + self.config.discount * masks * quantiles_next
+
+        quantiles = self.network(states)['quantile']
+        actions = tensor(transitions.action).long()
+        quantiles = quantiles[self.batch_indices, actions, :]
+
+        quantiles_next = quantiles_next.t().unsqueeze(-1)
+        diff = quantiles_next - quantiles
+        loss = huber(diff) * (self.cumulative_density - (diff.detach() < 0).float()).abs()
+        return loss.sum(-1).mean(1)
+
+    def reduce_loss(self, loss):
+        return loss.mean()
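
Editorial note, not part of the commit: the elementwise term built in compute_loss is the quantile Huber loss from the QR-DQN paper with kappa = 1. Writing u for a pairwise TD error between a Bellman target quantile and a predicted quantile, the per-element loss is

    \rho^{\kappa}_{\hat\tau_i}(u) = \lvert \hat\tau_i - \mathbf{1}\{u < 0\} \rvert \, L_{\kappa}(u),
    \qquad \hat\tau_i = \frac{2i + 1}{2N}, \quad i = 0, \dots, N - 1,

where L_kappa is the Huber loss supplied by huber(diff), N = config.num_quantiles, and self.cumulative_density holds the quantile midpoints \hat\tau_i; the code realizes this as huber(diff) * (self.cumulative_density - (diff.detach() < 0).float()).abs().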
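As a further aid, also not from the commit, the broadcasting behind the pairwise TD errors in compute_loss can be checked with a small standalone sketch; the sizes batch = 2 and N = 5 below are arbitrary placeholders:

import torch

batch, N = 2, 5
quantiles = torch.zeros(batch, N)        # predicted quantiles theta_i(s, a), shape (batch, N)
quantiles_next = torch.zeros(batch, N)   # Bellman targets r + gamma * theta'_j, shape (batch, N)

# (batch, N) -> (N, batch, 1): one target quantile j per leading index.
quantiles_next = quantiles_next.t().unsqueeze(-1)

# Broadcasting yields every pairwise error u[j, b, i] = target_j - predicted_i.
diff = quantiles_next - quantiles        # shape (N, batch, N)

# Quantile midpoints tau_i = (2i + 1) / (2N), as in self.cumulative_density.
cumulative_density = ((2 * torch.arange(N) + 1) / (2.0 * N)).view(1, -1)
weight = (cumulative_density - (diff.detach() < 0).float()).abs()

print(diff.shape, weight.shape)          # torch.Size([5, 2, 5]) torch.Size([5, 2, 5])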