import functools
import weakref
from contextlib import contextmanager
+from itertools import compress, repeat, islice, chain
+from inspect import signature

import torch
@@ -234,6 +236,43 @@ def modifier_wrapper(input, name):
    return zero_params_wrapper


+def uncompress(data, selector, compressed):
+    '''Generator which, given a compressed iterable produced by :py:obj:`itertools.compress` and (some iterable similar
+    to) the original data and selector used for :py:obj:`~itertools.compress`, yields values from `compressed` or
+    `data` depending on `selector`. `True` values in `selector` skip `data` one ahead and yield a value from
+    `compressed`, while `False` values yield one value from `data`.
+
+    Parameters
+    ----------
+    data : iterable
+        The iterable (similar to the) original data. `False` values in `selector` will be filled with values from
+        this iterator, while `True` values will cause this iterable to be skipped one ahead.
+    selector : iterable of bool
+        The original selector used to produce `compressed`. Chooses whether elements from `data` or from `compressed`
+        will be yielded.
+    compressed : iterable
+        The result of :py:obj:`itertools.compress`. One element will be yielded for each `True` element in `selector`.
+
+    Yields
+    ------
+    object
+        An element of `data` if the associated element of `selector` is `False`, otherwise an element of `compressed`,
+        while skipping `data` one ahead.
+
+    '''
+    its = iter(selector)
+    itc = iter(compressed)
+    itd = iter(data)
+    try:
+        while True:
+            if next(its):
+                next(itd)
+                yield next(itc)
+            else:
+                yield next(itd)
+    except StopIteration:
+        return
+
+
class ParamMod:
    '''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
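Aside (not part of the diff): a minimal sketch of how the added `uncompress` inverts :py:obj:`itertools.compress`, assuming the helper defined above is in scope.

# illustrative sketch only; assumes the `uncompress` added above is in scope
from itertools import compress, repeat

data = ['a', 'b', 'c', 'd']
selector = [True, False, True, False]
packed = list(compress(data, selector))                  # ['a', 'c']
print(list(uncompress(data, selector, packed)))          # ['a', 'b', 'c', 'd']
# pattern used throughout this diff: pad the non-selected slots with None
print(list(uncompress(repeat(None), selector, packed)))  # ['a', None, 'c', None]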
@@ -360,6 +399,7 @@ class Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *inputs):
        '''Forward identity.'''
+        ctx.mark_non_differentiable(*[elem for elem in inputs if not elem.requires_grad])
        return inputs

    @staticmethod
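Aside (not part of the diff): a sketch of what `mark_non_differentiable` changes here, assuming the patched `Identity` above is in scope. Without it, every output of the pass-through autograd Function would report `requires_grad=True` as soon as any input requires a gradient; marking the non-differentiable inputs is expected to keep the per-tensor flags intact.

# illustrative sketch; assumes `Identity` (as patched above) is in scope
import torch

a = torch.ones(3, requires_grad=True)
b = torch.ones(3)                      # does not require a gradient
out_a, out_b = Identity.apply(a, b)
print(out_a.requires_grad, out_b.requires_grad)   # expected: True False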
@@ -375,62 +415,94 @@ def __init__(self):
        self.active = True
        self.tensor_handles = RemovableHandleList()

-    def pre_forward(self, module, input):
+    @staticmethod
+    def _inject_grad_fn(args):
+        tensor_mask = tuple(isinstance(elem, torch.Tensor) for elem in args)
+        tensors = tuple(compress(args, tensor_mask))
+        # tensors = [(n, elem) for elem in enumerate(args) if isinstance(elem, torch.Tensor)]
+
+        # only if gradient required
+        if not any(tensor.requires_grad for tensor in tensors):
+            return None, args, tensor_mask
+
+        # add identity to ensure .grad_fn exists and all tensors share the same .grad_fn
+        post_tensors = Identity.apply(*tensors)
+        grad_fn = next((tensor.grad_fn for tensor in post_tensors if tensor.grad_fn is not None), None)
+        if grad_fn is None:
+            raise RuntimeError('Backward hook could not be registered!')
+
+        # work-around to support in-place operations
+        post_tensors = tuple(elem.clone() for elem in post_tensors)
+        post_args = tuple(uncompress(args, tensor_mask, post_tensors))
+        return grad_fn, post_args, tensor_mask
+
+    def pre_forward(self, module, args, kwargs):
        '''Apply an Identity to the input before the module to register a backward hook.'''
        hook_ref = weakref.ref(self)

+        grad_fn, post_args, input_tensor_mask = self._inject_grad_fn(args)
+        if grad_fn is None:
+            return
+
        @functools.wraps(self.backward)
        def wrapper(grad_input, grad_output):
            hook = hook_ref()
            if hook is not None and hook.active:
-                return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
+                return hook.backward(
+                    module,
+                    list(uncompress(
+                        repeat(None),
+                        input_tensor_mask,
+                        grad_input,
+                    )),
+                    hook.stored_tensors['grad_output'],
+                )
            return None

-        if not isinstance(input, tuple):
-            input = (input,)
+        # register the input tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))

-        # only if gradient required
-        if input[0].requires_grad:
-            # add identity to ensure .grad_fn exists
-            post_input = Identity.apply(*input)
-            # register the input tensor gradient hook
-            self.tensor_handles.append(
-                post_input[0].grad_fn.register_hook(wrapper)
-            )
-            # work around to support in-place operations
-            post_input = tuple(elem.clone() for elem in post_input)
-        else:
-            # no gradient required
-            post_input = input
-        return post_input[0] if len(post_input) == 1 else post_input
+        return post_args, kwargs

-    def post_forward(self, module, input, output):
+    def post_forward(self, module, args, kwargs, output):
        '''Register a backward-hook to the resulting tensor right after the forward.'''
        hook_ref = weakref.ref(self)

+        single = not isinstance(output, tuple)
+        if single:
+            output = (output,)
+
+        grad_fn, post_output, output_tensor_mask = self._inject_grad_fn(output)
+        if grad_fn is None:
+            return
+
        @functools.wraps(self.pre_backward)
        def wrapper(grad_input, grad_output):
            hook = hook_ref()
            if hook is not None and hook.active:
-                return hook.pre_backward(module, grad_input, grad_output)
+                return hook.pre_backward(
+                    module,
+                    grad_input,
+                    tuple(uncompress(
+                        repeat(None),
+                        output_tensor_mask,
+                        grad_output
+                    ))
+                )
            return None

-        if not isinstance(output, tuple):
-            output = (output,)
+        # register the output tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))

-        # only if gradient required
-        if output[0].grad_fn is not None:
-            # register the output tensor gradient hook
-            self.tensor_handles.append(
-                output[0].grad_fn.register_hook(wrapper)
-            )
-        return output[0] if len(output) == 1 else output
+        if single:
+            return post_output[0]
+        return post_output

    def pre_backward(self, module, grad_input, grad_output):
        '''Store the grad_output for the backward hook'''
        self.stored_tensors['grad_output'] = grad_output

-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
        '''Hook applied during forward-pass'''

    def backward(self, module, grad_input, grad_output):
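Aside (not part of the diff): a small sketch, with made-up values, of the padding convention used by the wrappers above; it assumes the `uncompress` helper from this module is in scope. Gradients exist only for tensor arguments, so the wrapper re-expands them to the full argument structure with `None` in the non-tensor slots before calling `hook.backward`.

# illustrative sketch with assumed example values, not part of the diff
from itertools import repeat
import torch

args = (torch.ones(2, requires_grad=True), 'some flag', torch.zeros(2))
tensor_mask = tuple(isinstance(elem, torch.Tensor) for elem in args)    # (True, False, True)
grad_input = (torch.full((2,), 0.5), torch.full((2,), 1.5))             # one gradient per tensor argument
padded = list(uncompress(repeat(None), tensor_mask, grad_input))
print(padded)   # [tensor([0.5000, 0.5000]), None, tensor([1.5000, 1.5000])]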
@@ -449,11 +521,14 @@ def remove(self):
    def register(self, module):
        '''Register this instance by registering all hooks to the supplied module.'''
+        # assume with_kwargs if forward does not have exactly 3 parameters and its 3rd parameter is not called 'output'
+        forward_params = signature(self.forward).parameters
+        with_kwargs = len(forward_params) != 3 and list(forward_params)[2] != 'output'
        return RemovableHandleList([
            RemovableHandle(self),
-            module.register_forward_pre_hook(self.pre_forward),
-            module.register_forward_hook(self.post_forward),
-            module.register_forward_hook(self.forward),
+            module.register_forward_pre_hook(self.pre_forward, with_kwargs=True),
+            module.register_forward_hook(self.post_forward, with_kwargs=True),
+            module.register_forward_hook(self.forward, with_kwargs=with_kwargs),
        ])
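Aside (not part of the diff): a sketch of what the `signature`-based check above distinguishes, using two hypothetical hook classes. For a bound method, `self` does not appear in the signature, so a legacy `forward(self, module, input, output)` exposes exactly three parameters with the third named 'output', while a kwargs-aware forward does not.

# hypothetical classes for illustration only
from inspect import signature

class LegacyHook:
    def forward(self, module, input, output):
        pass

class KwargsHook:
    def forward(self, module, args, kwargs, output):
        pass

for hook in (LegacyHook(), KwargsHook()):
    forward_params = signature(hook.forward).parameters
    with_kwargs = len(forward_params) != 3 and list(forward_params)[2] != 'output'
    print(type(hook).__name__, with_kwargs)    # LegacyHook False, KwargsHook True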
@@ -517,31 +592,61 @@ def __init__(
        self.gradient_mapper = gradient_mapper
        self.reducer = reducer

-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
        '''Forward hook to save module in-/outputs.'''
-        self.stored_tensors['input'] = input
+        self.stored_tensors['input'] = args
+        self.stored_tensors['kwargs'] = kwargs

    def backward(self, module, grad_input, grad_output):
        '''Backward hook to compute LRP based on the class attributes.'''
-        original_input = self.stored_tensors['input'][0].clone()
+        input_mask = [elem is not None for elem in self.stored_tensors['input']]
+        output_mask = [elem is not None for elem in grad_output]
+        cgrad_output = tuple(compress(grad_output, output_mask))
+
+        original_inputs = [tensor.clone() for tensor in self.stored_tensors['input']]
+        kwargs = self.stored_tensors['kwargs']
        inputs = []
        outputs = []
        for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
-            input = in_mod(original_input).requires_grad_()
+            mod_args = (in_mod(tensor).requires_grad_() for tensor in compress(original_inputs, input_mask))
+            args = tuple(uncompress(original_inputs, input_mask, mod_args))
            with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
-                output = modified.forward(input)
-                output = out_mod(output)
-            inputs.append(input)
+                output = modified.forward(*args, **kwargs)
+                if not isinstance(output, tuple):
+                    output = (output,)
+                output = tuple(out_mod(tensor) for tensor in compress(output, output_mask))
+            inputs.append(compress(args, input_mask))
            outputs.append(output)
-        grad_outputs = self.gradient_mapper(grad_output[0], outputs)
-        gradients = torch.autograd.grad(
-            outputs,
-            inputs,
+
+        inputs = list(zip(*inputs))
+        outputs = list(zip(*outputs))
+        input_struct = [len(elem) for elem in inputs]
+        output_struct = [len(elem) for elem in outputs]
+
+        grad_outputs = tuple(
+            self.gradient_mapper(gradout, outs)
+            for gradout, outs in zip(cgrad_output, outputs)
+        )
+        inputs_flat = tuple(chain.from_iterable(inputs))
+        outputs_flat = tuple(chain.from_iterable(outputs))
+        if not all(isinstance(elem, torch.Tensor) for elem in grad_outputs):
+            # if there is only a single output modifier, grad_outputs may contain tensors
+            grad_outputs = tuple(chain.from_iterable(grad_outputs))
+
+        gradients_flat = torch.autograd.grad(
+            outputs_flat,
+            inputs_flat,
            grad_outputs=grad_outputs,
-            create_graph=grad_output[0].requires_grad
+            create_graph=any(tensor.requires_grad for tensor in cgrad_output)
        )
-        relevance = self.reducer(inputs, gradients)
-        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)
+
+        # input_it = iter(inputs)
+        # inputs_re = [tuple(islice(input_it, size)) for size in input_struct]
+        gradient_it = iter(gradients_flat)
+        gradients = [tuple(islice(gradient_it, size)) for size in input_struct]
+
+        relevances = (self.reducer(inp, grad) for inp, grad in zip(inputs, gradients))
+        return tuple(uncompress(repeat(None), input_mask, relevances))

    def copy(self):
        '''Return a copy of this hook.
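Trailing note (not part of the diff): the flatten/regroup bookkeeping in the new `backward` above boils down to a `chain.from_iterable` / `islice` round trip over the recorded per-input structure; a minimal sketch with plain numbers standing in for tensors:

# illustrative sketch with plain values standing in for tensors
from itertools import chain, islice

# one tuple per input position, one entry per input/param/output modifier triple
inputs = [(10, 11), (20, 21)]
input_struct = [len(elem) for elem in inputs]          # [2, 2]
inputs_flat = tuple(chain.from_iterable(inputs))       # (10, 11, 20, 21)
# pretend these are the gradients returned by torch.autograd.grad for inputs_flat
gradients_flat = (0.1, 0.2, 0.3, 0.4)
gradient_it = iter(gradients_flat)
gradients = [tuple(islice(gradient_it, size)) for size in input_struct]
print(gradients)    # [(0.1, 0.2), (0.3, 0.4)]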