We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e71d7b1 commit fcc132dCopy full SHA for fcc132d
deep_rl/utils/torch_utils.py
@@ -43,7 +43,7 @@ def epsilon_greedy(epsilon, x):
43
if len(x.shape) == 1:
44
return np.random.randint(len(x)) if np.random.rand() < epsilon else np.argmax(x)
45
elif len(x.shape) == 2:
46
- random_actions = np.random.randint(x.shape[1])
+ random_actions = np.random.randint(x.shape[1], size=x.shape[0])
47
greedy_actions = np.argmax(x, axis=-1)
48
dice = np.random.rand(x.shape[0])
49
return np.where(dice < epsilon, random_actions, greedy_actions)
0 commit comments