
Commit d968940

Core: Multiple Inputs and Keyword Arguments
- use the additions to forward hooks in torch 2.0.0 to pass keyword arguments (kwargs) to hooks
- handle multiple inputs and outputs in core.Hook and core.BasicHook by passing all required grad_outputs and inputs to the backward implementation

TODO:
- attribution scores are currently wrong in BasicHook, likely an issue with the gradient inside BasicHook? There might be some cross-terms interacting that should not interact
- finish draft and test implementation
- add tests
- add documentation

This stands in conflict with #168, but promises a better implementation by handling inputs and outputs as common to a single function, rather than individually as proposed in #168.
1 parent 065c821 commit d968940
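
For context, a minimal sketch (not part of the commit itself) of the torch 2.0.0 forward-hook additions this change relies on: register_forward_pre_hook and register_forward_hook accept with_kwargs=True, so the hooks also receive, and the pre-hook may replace, the keyword arguments of the module call. The module and shapes below are arbitrary examples.

import torch


def pre_hook(module, args, kwargs):
    # with with_kwargs=True, the pre-hook sees positional and keyword arguments
    # and may return a (new_args, new_kwargs) tuple to replace them
    print('pre-forward:', [tuple(a.shape) for a in args], dict(kwargs))
    return None


def post_hook(module, args, kwargs, output):
    # the forward hook additionally receives the output and may replace it
    print('post-forward:', tuple(output.shape))
    return None


module = torch.nn.Linear(4, 2)
module.register_forward_pre_hook(pre_hook, with_kwargs=True)
module.register_forward_hook(post_hook, with_kwargs=True)
module(torch.randn(3, 4))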

2 files changed: +153, -48 lines changed


setup.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def replace(mobj):
         'click',
         'numpy',
         'Pillow',
-        'torch>=1.7.0',
+        'torch>=2.0.0',
         'torchvision',
     ],
     setup_requires=[

src/zennit/core.py

Lines changed: 152 additions & 47 deletions
@@ -19,6 +19,8 @@
 import functools
 import weakref
 from contextlib import contextmanager
+from itertools import compress, repeat, islice, chain
+from inspect import signature

 import torch

@@ -234,6 +236,44 @@ def modifier_wrapper(input, name):
     return zero_params_wrapper


+def uncompress(data, selector, compressed):
+    '''Generator which, given a compressed iterable produced by :py:obj:`itertools.compress` and (some iterable similar
+    to) the original data and selector used for :py:obj:`~itertools.compress`, yields values from `compressed` or
+    `data` depending on `selector`. `True` values in `selector` skip `data` one ahead and yield a value from
+    `compressed`, while `False` values yield one value from `data`.
+
+    Parameters
+    ----------
+    data : iterable
+        The iterable (similar to the) original data. `False` values in the `selector` will be filled with values from
+        this iterator, while `True` values will cause this iterable to be skipped.
+    selector : iterable of bool
+        The original selector used to produce `compressed`. Chooses whether elements from `data` or from `compressed`
+        will be yielded.
+    compressed : iterable
+        The results of :py:obj:`itertools.compress`. Will be yielded for each `True` element in `selector`.
+
+    Yields
+    ------
+    object
+        An element of `data` if the associated element of `selector` is `False`, otherwise an element of `compressed`
+        while skipping `data` one ahead.
+
+    '''
+    its = iter(selector)
+    itc = iter(compressed)
+    itd = iter(data)
+    try:
+        while True:
+            if next(its):
+                next(itd)
+                yield next(itc)
+            else:
+                yield next(itd)
+    except StopIteration:
+        return
+
+
 class ParamMod:
     '''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
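
A short usage sketch for the uncompress helper added above (assuming the loop as reconstructed, and hypothetically importing it from zennit.core on this branch): itertools.compress filters a sequence by a boolean mask, while uncompress splices the possibly modified filtered values back into their original positions.

from itertools import compress

from zennit.core import uncompress  # assumes this branch of zennit is installed

data = ('a', 1, 'b', 2)
selector = (False, True, False, True)

filtered = tuple(compress(data, selector))              # (1, 2)
modified = tuple(value * 10 for value in filtered)      # (10, 20)
restored = tuple(uncompress(data, selector, modified))  # ('a', 10, 'b', 20)
print(restored)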
@@ -360,6 +399,7 @@ class Identity(torch.autograd.Function):
     @staticmethod
     def forward(ctx, *inputs):
         '''Forward identity.'''
+        ctx.mark_non_differentiable(*[elem for elem in inputs if not elem.requires_grad])
         return inputs

     @staticmethod
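
A small sketch of what the added ctx.mark_non_differentiable call does: outputs passed to it are excluded from gradient computation, so tensors that do not require grad keep requires_grad == False after passing through the identity. MarkedIdentity below is a hypothetical stand-in for illustration, not zennit's Identity.

import torch


class MarkedIdentity(torch.autograd.Function):
    '''Illustrative no-op function mirroring the change above.'''
    @staticmethod
    def forward(ctx, *inputs):
        # mark outputs that do not require grad as non-differentiable
        ctx.mark_non_differentiable(*[elem for elem in inputs if not elem.requires_grad])
        return inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        return grad_outputs


a = torch.randn(2, requires_grad=True)
b = torch.randn(2)  # gradient not required

pa, pb = MarkedIdentity.apply(a, b)
print(pa.requires_grad, pb.requires_grad)  # True False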
@@ -375,62 +415,94 @@ def __init__(self):
         self.active = True
         self.tensor_handles = RemovableHandleList()

-    def pre_forward(self, module, input):
+    @staticmethod
+    def _inject_grad_fn(args):
+        tensor_mask = tuple(isinstance(elem, torch.Tensor) for elem in args)
+        tensors = tuple(compress(args, tensor_mask))
+        # tensors = [(n, elem) for elem in enumerate(args) if isinstance(elem, torch.Tensor)]
+
+        # only if gradient required
+        if not any(tensor.requires_grad for tensor in tensors):
+            return None, args, tensor_mask
+
+        # add identity to ensure .grad_fn exists and all tensors share the same .grad_fn
+        post_tensors = Identity.apply(*tensors)
+        grad_fn = next((tensor.grad_fn for tensor in post_tensors if tensor.grad_fn is not None), None)
+        if grad_fn is None:
+            raise RuntimeError('Backward hook could not be registered!')
+
+        # work-around to support in-place operations
+        post_tensors = tuple(elem.clone() for elem in post_tensors)
+        post_args = tuple(uncompress(args, tensor_mask, post_tensors))
+        return grad_fn, post_args, tensor_mask
+
+    def pre_forward(self, module, args, kwargs):
         '''Apply an Identity to the input before the module to register a backward hook.'''
         hook_ref = weakref.ref(self)

+        grad_fn, post_args, input_tensor_mask = self._inject_grad_fn(args)
+        if grad_fn is None:
+            return
+
         @functools.wraps(self.backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
+                return hook.backward(
+                    module,
+                    list(uncompress(
+                        repeat(None),
+                        input_tensor_mask,
+                        grad_input,
+                    )),
+                    hook.stored_tensors['grad_output'],
+                )
             return None

-        if not isinstance(input, tuple):
-            input = (input,)
+        # register the input tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))

-        # only if gradient required
-        if input[0].requires_grad:
-            # add identity to ensure .grad_fn exists
-            post_input = Identity.apply(*input)
-            # register the input tensor gradient hook
-            self.tensor_handles.append(
-                post_input[0].grad_fn.register_hook(wrapper)
-            )
-            # work around to support in-place operations
-            post_input = tuple(elem.clone() for elem in post_input)
-        else:
-            # no gradient required
-            post_input = input
-        return post_input[0] if len(post_input) == 1 else post_input
+        return post_args, kwargs

-    def post_forward(self, module, input, output):
+    def post_forward(self, module, args, kwargs, output):
         '''Register a backward-hook to the resulting tensor right after the forward.'''
         hook_ref = weakref.ref(self)

+        single = not isinstance(output, tuple)
+        if single:
+            output = (output,)
+
+        grad_fn, post_output, output_tensor_mask = self._inject_grad_fn(output)
+        if grad_fn is None:
+            return
+
         @functools.wraps(self.pre_backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.pre_backward(module, grad_input, grad_output)
+                return hook.pre_backward(
+                    module,
+                    grad_input,
+                    tuple(uncompress(
+                        repeat(None),
+                        output_tensor_mask,
+                        grad_output
+                    ))
+                )
             return None

-        if not isinstance(output, tuple):
-            output = (output,)
+        # register the output tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))

-        # only if gradient required
-        if output[0].grad_fn is not None:
-            # register the output tensor gradient hook
-            self.tensor_handles.append(
-                output[0].grad_fn.register_hook(wrapper)
-            )
-        return output[0] if len(output) == 1 else output
+        if single:
+            return post_output[0]
+        return post_output

     def pre_backward(self, module, grad_input, grad_output):
         '''Store the grad_output for the backward hook'''
         self.stored_tensors['grad_output'] = grad_output

-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Hook applied during forward-pass'''

     def backward(self, module, grad_input, grad_output):
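
For the reworked pre_forward and post_forward above, a minimal sketch (not zennit code) of the underlying mechanism: routing several tensors through one autograd Function makes them share a single .grad_fn node, and a hook registered on that node observes, and could replace, the gradients for all of them at once. MiniIdentity and node_hook are hypothetical names used only for this illustration.

import torch


class MiniIdentity(torch.autograd.Function):
    '''Illustrative no-op so that all inputs share one grad_fn node.'''
    @staticmethod
    def forward(ctx, *inputs):
        return inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        return grad_outputs


x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)

px, py = MiniIdentity.apply(x, y)
print(px.grad_fn is py.grad_fn)  # True, both outputs share the same backward node


def node_hook(grad_inputs, grad_outputs):
    # grad_inputs are the gradients the node passes on towards x and y;
    # returning a tuple here would replace them, which is how the hook wrapper
    # can substitute modified relevance for the plain gradient
    print([tuple(g.shape) for g in grad_inputs])
    return None


px.grad_fn.register_hook(node_hook)
(px * py).sum().backward()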
@@ -449,11 +521,14 @@ def remove(self):

     def register(self, module):
         '''Register this instance by registering all hooks to the supplied module.'''
+        # assume with_kwargs if forward does not have 3 parameters and its 3rd is not called 'output'
+        forward_params = signature(self.forward).parameters
+        with_kwargs = len(forward_params) != 3 and list(forward_params)[2] != 'output'
         return RemovableHandleList([
             RemovableHandle(self),
-            module.register_forward_pre_hook(self.pre_forward),
-            module.register_forward_hook(self.post_forward),
-            module.register_forward_hook(self.forward),
+            module.register_forward_pre_hook(self.pre_forward, with_kwargs=True),
+            module.register_forward_hook(self.post_forward, with_kwargs=True),
+            module.register_forward_hook(self.forward, with_kwargs=with_kwargs),
         ])

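A standalone sketch of the with_kwargs detection introduced in register above: for a bound method, inspect.signature omits self, so the legacy hook signature (module, input, output) has three parameters with 'output' third, while the kwargs-aware signature (module, args, kwargs, output) does not. LegacyHook, KwargsHook and wants_kwargs are hypothetical names used only for this illustration.

from inspect import signature


class LegacyHook:
    def forward(self, module, input, output):
        pass


class KwargsHook:
    def forward(self, module, args, kwargs, output):
        pass


def wants_kwargs(hook):
    # bound method: 'self' is not part of the signature
    params = signature(hook.forward).parameters
    return len(params) != 3 and list(params)[2] != 'output'


print(wants_kwargs(LegacyHook()))  # False
print(wants_kwargs(KwargsHook()))  # True
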
@@ -517,31 +592,61 @@ def __init__(
         self.gradient_mapper = gradient_mapper
         self.reducer = reducer

-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Forward hook to save module in-/outputs.'''
-        self.stored_tensors['input'] = input
+        self.stored_tensors['input'] = args
+        self.stored_tensors['kwargs'] = kwargs

     def backward(self, module, grad_input, grad_output):
         '''Backward hook to compute LRP based on the class attributes.'''
-        original_input = self.stored_tensors['input'][0].clone()
+        input_mask = [elem is not None for elem in self.stored_tensors['input']]
+        output_mask = [elem is not None for elem in grad_output]
+        cgrad_output = tuple(compress(grad_output, output_mask))
+
+        original_inputs = [tensor.clone() for tensor in self.stored_tensors['input']]
+        kwargs = self.stored_tensors['kwargs']
         inputs = []
         outputs = []
         for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
-            input = in_mod(original_input).requires_grad_()
+            mod_args = (in_mod(tensor).requires_grad_() for tensor in compress(original_inputs, input_mask))
+            args = tuple(uncompress(original_inputs, input_mask, mod_args))
             with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
-                output = modified.forward(input)
-                output = out_mod(output)
-            inputs.append(input)
+                output = modified.forward(*args, **kwargs)
+                if not isinstance(output, tuple):
+                    output = (output,)
+                output = tuple(out_mod(tensor) for tensor in compress(output, output_mask))
+            inputs.append(compress(args, input_mask))
             outputs.append(output)
-        grad_outputs = self.gradient_mapper(grad_output[0], outputs)
-        gradients = torch.autograd.grad(
-            outputs,
-            inputs,
+
+        inputs = list(zip(*inputs))
+        outputs = list(zip(*outputs))
+        input_struct = [len(elem) for elem in inputs]
+        output_struct = [len(elem) for elem in outputs]
+
+        grad_outputs = tuple(
+            self.gradient_mapper(gradout, outs)
+            for gradout, outs in zip(cgrad_output, outputs)
+        )
+        inputs_flat = tuple(chain.from_iterable(inputs))
+        outputs_flat = tuple(chain.from_iterable(outputs))
+        if not all(isinstance(elem, torch.Tensor) for elem in grad_outputs):
+            # if there is only a single output modifier, grad_outputs may contain tensors
+            grad_outputs = tuple(chain.from_iterable(grad_outputs))
+
+        gradients_flat = torch.autograd.grad(
+            outputs_flat,
+            inputs_flat,
             grad_outputs=grad_outputs,
-            create_graph=grad_output[0].requires_grad
+            create_graph=any(tensor.requires_grad for tensor in cgrad_output)
         )
-        relevance = self.reducer(inputs, gradients)
-        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)
+
+        # input_it = iter(inputs)
+        # inputs_re = [tuple(islice(input_it, size)) for size in input_struct]
+        gradient_it = iter(gradients_flat)
+        gradients = [tuple(islice(gradient_it, size)) for size in input_struct]
+
+        relevances = (self.reducer(inp, grad) for inp, grad in zip(inputs, gradients))
+        return tuple(uncompress(repeat(None), input_mask, relevances))

     def copy(self):
         '''Return a copy of this hook.
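
A plain-Python sketch of the flatten/regroup bookkeeping in the new BasicHook.backward above: the per-modifier input tuples are transposed per input position, flattened for a single torch.autograd.grad call, and the flat gradients are regrouped with islice according to the recorded structure. Strings stand in for tensors here.

from itertools import chain, islice

# e.g. two inputs and three modifiers: inputs[i] holds the three modified copies of input i
inputs = [('a0', 'a1', 'a2'), ('b0', 'b1', 'b2')]
input_struct = [len(elem) for elem in inputs]                # [3, 3]

inputs_flat = tuple(chain.from_iterable(inputs))             # ('a0', 'a1', 'a2', 'b0', 'b1', 'b2')
gradients_flat = tuple(f'grad({x})' for x in inputs_flat)    # stand-in for torch.autograd.grad

gradient_it = iter(gradients_flat)
gradients = [tuple(islice(gradient_it, size)) for size in input_struct]
print(gradients)
# [('grad(a0)', 'grad(a1)', 'grad(a2)'), ('grad(b0)', 'grad(b1)', 'grad(b2)')]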
