Skip to content

Commit 2959775

Browse files
authored
correctly pass label smoothing from the arguments
1 parent bcb1e23 commit 2959775

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

example/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@
5050
# first forward-backward step
5151
enable_running_stats(model)
5252
predictions = model(inputs)
53-
loss = smooth_crossentropy(predictions, targets)
53+
loss = smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing)
5454
loss.mean().backward()
5555
optimizer.first_step(zero_grad=True)
5656

5757
# second forward-backward step
5858
disable_running_stats(model)
59-
smooth_crossentropy(model(inputs), targets).mean().backward()
59+
smooth_crossentropy(model(inputs), targets, smoothing=args.label_smoothing).mean().backward()
6060
optimizer.second_step(zero_grad=True)
6161

6262
with torch.no_grad():

0 commit comments

Comments
 (0)