Commit fcc132d

Fix a bug for epsilon greedy policy with multiple workers
1 parent e71d7b1 commit fcc132d

1 file changed: +1 -1 lines changed
deep_rl/utils/torch_utils.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def epsilon_greedy(epsilon, x):
     if len(x.shape) == 1:
         return np.random.randint(len(x)) if np.random.rand() < epsilon else np.argmax(x)
     elif len(x.shape) == 2:
-        random_actions = np.random.randint(x.shape[1])
+        random_actions = np.random.randint(x.shape[1], size=x.shape[0])
         greedy_actions = np.argmax(x, axis=-1)
         dice = np.random.rand(x.shape[0])
         return np.where(dice < epsilon, random_actions, greedy_actions)
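
The effect of the one-line change can be illustrated with a small, self-contained sketch. It assumes `x` is a 2-D array of action values with one row per worker, which is what the `len(x.shape) == 2` branch suggests. The name `epsilon_greedy_before_fix` and the test array `q_values` are hypothetical and exist only for this comparison; the final `raise` is also an addition, not part of the commit. Before the fix, `np.random.randint(x.shape[1])` returns a single integer that `np.where` broadcasts, so every exploring worker takes the same random action. With `size=x.shape[0]`, each worker gets its own draw.

```python
import numpy as np


def epsilon_greedy_before_fix(epsilon, x):
    """Hypothetical copy of the 2-D branch as it was before the commit."""
    # A single scalar draw: np.where broadcasts it to every worker.
    random_actions = np.random.randint(x.shape[1])
    greedy_actions = np.argmax(x, axis=-1)
    dice = np.random.rand(x.shape[0])
    return np.where(dice < epsilon, random_actions, greedy_actions)


def epsilon_greedy(epsilon, x):
    """Sketch of the function after the commit."""
    if len(x.shape) == 1:
        return np.random.randint(len(x)) if np.random.rand() < epsilon else np.argmax(x)
    elif len(x.shape) == 2:
        # One independent random action per worker (row of x).
        random_actions = np.random.randint(x.shape[1], size=x.shape[0])
        greedy_actions = np.argmax(x, axis=-1)
        dice = np.random.rand(x.shape[0])
        return np.where(dice < epsilon, random_actions, greedy_actions)
    # Added for this sketch only; the diff does not show other branches.
    raise ValueError("x must be 1-D or 2-D")


# Hypothetical input: 1000 workers, 4 actions, all action values equal.
q_values = np.zeros((1000, 4))

# epsilon = 1.0 forces every worker to explore.
old_actions = epsilon_greedy_before_fix(1.0, q_values)
new_actions = epsilon_greedy(1.0, q_values)

print("distinct actions before fix:", len(np.unique(old_actions)))
print("distinct actions after fix: ", len(np.unique(new_actions)))
```

With `epsilon = 1.0`, the pre-fix version always reports exactly one distinct action across all workers, while the patched version samples each worker's action independently and so should report more than one for any reasonably large worker count.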
