@@ -356,19 +356,6 @@ def collect_leaves(module):
         yield module


-class Identity(torch.autograd.Function):
-    '''Identity to add a grad_fn to a tensor, so a backward hook can be applied.'''
-    @staticmethod
-    def forward(ctx, *inputs):
-        '''Forward identity.'''
-        return inputs
-
-    @staticmethod
-    def backward(ctx, *grad_outputs):
-        '''Backward identity.'''
-        return grad_outputs
-
-
 class Hook:
     '''Base class for hooks to be used to compute layer-wise attributions.'''
     def __init__(self):
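The Identity autograd function removed above existed only to give input tensors a grad_fn on which a backward hook could be registered; the hunks below achieve the same with a no-op view. A minimal standalone sketch of that mechanism (the tensors and the hook here are illustrative, not part of the diff):

    import torch

    x = torch.ones(2, 3, requires_grad=True)
    # a no-op view adds a grad_fn node even though the values are unchanged
    view = x.view_as(x)
    assert view.grad_fn is not None

    def node_hook(grad_inputs, grad_outputs):
        # grad_outputs holds the gradient flowing into the view during backward
        print('gradient w.r.t. the view:', grad_outputs[0].sum().item())

    handle = view.grad_fn.register_hook(node_hook)
    (view * 2.0).sum().backward()
    handle.remove()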
@@ -381,29 +368,41 @@ def pre_forward(self, module, input):
         hook_ref = weakref.ref(self)

         @functools.wraps(self.backward)
-        def wrapper(grad_input, grad_output):
+        def wrapper(grad_input, grad_output, grad_sink):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
+                result = hook.backward(module, grad_output, hook.stored_tensors['grad_output'], grad_sink=grad_sink)
+                if not isinstance(result, tuple):
+                    result = (result,)
+                if grad_input is None:
+                    return result[0]
+                return result
             return None

         if not isinstance(input, tuple):
             input = (input,)

-        # only if gradient required
-        if input[0].requires_grad:
-            # add identity to ensure .grad_fn exists
-            post_input = Identity.apply(*input)
+        post_input = tuple(tensor.view_as(tensor) for tensor in input)
+
+        # hook required gradient sinks
+        for grad_sink, tensor in enumerate(post_input):
             # register the input tensor gradient hook
-            self.tensor_handles.append(
-                post_input[0].grad_fn.register_hook(wrapper)
-            )
-            # work around to support in-place operations
-            post_input = tuple(elem.clone() for elem in post_input)
-        else:
-            # no gradient required
-            post_input = input
-        return post_input[0] if len(post_input) == 1 else post_input
+            if tensor.grad_fn is not None:
+                # grad_fn for inputs is here the view function applied above
+                self.tensor_handles.append(
+                    tensor.grad_fn.register_hook(functools.partial(wrapper, grad_sink=grad_sink))
+                )
+        # hook required gradient sinks
+        for grad_sink, tensor in module.named_parameters():
+            if tensor.requires_grad:
+                # TODO: use grad_fn (need to store parameter views for the model...), otherwise the hook could be
+                # called for unrelated gradients
+                self.tensor_handles.append(
+                    tensor.register_hook(functools.partial(wrapper, None, grad_sink=grad_sink))
+                )
+
+        # torch.nn.Module converts single tensors to tuples anyway, so we can always return a tuple here
+        return post_input

     def post_forward(self, module, input, output):
         '''Register a backward-hook to the resulting tensor right after the forward.'''
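One wrapper now serves every gradient sink of a module: the sink index (or parameter name) is pre-bound with functools.partial, and the pre-bound None distinguishes plain tensor hooks (called with a single gradient) from grad_fn node hooks (called with grad_inputs and grad_outputs). A rough standalone illustration of that dispatch, with a print in place of the real hook logic:

    import functools
    import torch

    def wrapper(grad_input, grad_output, grad_sink):
        # grad_sink tells which input position (int) or parameter name (str) the
        # gradient belongs to; returning None leaves the gradient untouched
        print('gradient arrived for sink', repr(grad_sink))

    module = torch.nn.Linear(3, 2)
    x = torch.ones(1, 3, requires_grad=True)
    view = x.view_as(x)

    # node hook for the input: called as hook(grad_inputs, grad_outputs)
    view.grad_fn.register_hook(functools.partial(wrapper, grad_sink=0))
    # tensor hook for a parameter: called as hook(grad), so grad_input is pre-bound to None
    module.weight.register_hook(functools.partial(wrapper, None, grad_sink='weight'))

    module(view).sum().backward()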
@@ -413,28 +412,28 @@ def post_forward(self, module, input, output):
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.pre_backward(module, grad_input, grad_output)
+                return hook.pre_backward(module, grad_output)
             return None

-        if not isinstance(output, tuple):
-            output = (output,)
+        hookable_output = output
+        if not isinstance(hookable_output, tuple):
+            hookable_output = (hookable_output,)

         # only if gradient required
-        if output[0].grad_fn is not None:
+        if hookable_output[0].requires_grad:
             # register the output tensor gradient hook
             self.tensor_handles.append(
-                output[0].grad_fn.register_hook(wrapper)
+                hookable_output[0].grad_fn.register_hook(wrapper)
             )
-        return output[0] if len(output) == 1 else output

-    def pre_backward(self, module, grad_input, grad_output):
+    def pre_backward(self, module, grad_output):
         '''Store the grad_output for the backward hook'''
         self.stored_tensors['grad_output'] = grad_output

     def forward(self, module, input, output):
         '''Hook applied during forward-pass'''

-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Hook applied during backward-pass'''

     def copy(self):
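With grad_input dropped from pre_backward and grad_sink added to backward, a subclass is invoked once per gradient sink rather than only for the first module input. A hypothetical subclass (not part of this change set) illustrating the new signature might look like:

    class PrintSinkHook(Hook):
        '''Sketch of a hook using the new signature: gradients pass through
        unchanged, only the targeted sink is reported.'''
        def backward(self, module, grad_input, grad_output, grad_sink):
            # grad_sink is an input position (int) or a parameter name (str)
            print(type(module).__name__, 'computing relevance for sink', repr(grad_sink))
            # returning the incoming gradient keeps the attribution unchanged
            return grad_input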
@@ -522,18 +521,18 @@ def forward(self, module, input, output):
         '''Forward hook to save module in-/outputs.'''
         self.stored_tensors['input'] = input

-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Backward hook to compute LRP based on the class attributes.'''
-        original_input = self.stored_tensors['input'][0].clone()
+        original_inputs = [tensor.view_as(tensor) for tensor in self.stored_tensors['input']]
         inputs = []
         outputs = []
         for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
-            input = in_mod(original_input).requires_grad_()
+            input_args = [in_mod(tensor).requires_grad_() for tensor in original_inputs]
             with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
-                output = modified.forward(input)
-                output = out_mod(output)
-            inputs.append(input)
-            outputs.append(output)
+                output = modified.forward(*input_args)
+            # decide for which argument to compute the relevance
+            inputs.append(input_args[grad_sink] if isinstance(grad_sink, int) else getattr(modified, grad_sink))
+            outputs.append(out_mod(output))
         grad_outputs = self.gradient_mapper(grad_output[0], outputs)
         gradients = torch.autograd.grad(
             outputs,
@@ -542,7 +541,7 @@ def backward(self, module, grad_input, grad_output):
             create_graph=grad_output[0].requires_grad
         )
         relevance = self.reducer(inputs, gradients)
-        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)
+        return relevance

     def copy(self):
         '''Return a copy of this hook.
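For reference, the modified-gradient pattern that this backward builds on (modified forward pass, gradient_mapper scaling the incoming relevance, reducer combining input and gradient) can be reproduced by hand for a single plain linear layer, roughly along the lines of the epsilon rule; the stabilizer term below is illustrative only:

    import torch

    module = torch.nn.Linear(4, 3)
    x = torch.randn(2, 4)
    incoming = torch.randn(2, 3)  # relevance arriving from the layer above

    inp = x.clone().requires_grad_()
    with torch.autograd.enable_grad():
        out = module(inp)

    # gradient_mapper: divide the incoming relevance by the stabilized output
    stabilized = out + 1e-6 * ((out == 0.).to(out) + out.sign())
    grad, = torch.autograd.grad(out, inp, grad_outputs=incoming / stabilized)

    # reducer: multiply the input by the resulting gradient to obtain the relevance
    relevance = inp * grad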