
Commit 8f11583

Core: Multiple input/param gradient modification
- change the core Hook to support the modification of multiple inputs and params
  - to this end, each input and parameter that requires a gradient is now hooked, and a backward that is aware of the current 'sink' is called for each of them
  - use a view instead of the custom Identity function to produce a .grad_fn

Note:
- this may be a breaking change for custom hooks based on the old implementation

TODO:
- finish the implementation:
  - parameters have no grad_fn, and we cannot simply overwrite them with a view; hooking them directly with tensor hooks is problematic when the parameters are used in different functions
  - there may be a better approach than calling the backward function once per 'sink', although the current implementation may allow for better modularity
  - multiple outputs are still not supported; it may be worth thinking about how to support them, but it may also be better to do so at a later stage
- implement tests:
  - add new tests for the new functionality: multiple inputs and params in hooks
  - fix old tests that assume the use of Identity and are not sink-aware
- add documentation
1 parent d46f3e7 commit 8f11583
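To make the note about breaking custom hooks concrete: under the old interface a hook's backward received (module, grad_input, grad_output) and returned a tuple aligned with grad_input, whereas after this commit it is called once per gradient sink and additionally receives grad_sink, an input index or a parameter name. Below is a minimal sketch of a sink-aware custom hook; it is not part of the commit, and the class name, the scale parameter, and the assumption of an elementwise module (so the stored output gradient matches the input shape) are illustrative only.

from zennit.core import Hook


class ScaleInputGrad(Hook):
    '''Hypothetical sink-aware hook: scale the relevance of each input, leave parameters untouched.'''

    def __init__(self, scale=0.5):
        super().__init__()
        self.scale = scale

    def backward(self, module, grad_input, grad_output, grad_sink):
        '''Called once per gradient sink; grad_sink is an input index (int) or a parameter name (str).'''
        if isinstance(grad_sink, str):
            # parameter sink, e.g. 'weight'; returning None keeps that gradient unchanged
            return None
        # input sink; a single tensor (rather than a grad_input-aligned tuple) is fine,
        # since the new wrapper normalizes non-tuple results
        return self.scale * grad_output[0]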

File tree

2 files changed: +50 -51 lines


src/zennit/core.py

Lines changed: 43 additions & 44 deletions
@@ -356,19 +356,6 @@ def collect_leaves(module):
         yield module
 
 
-class Identity(torch.autograd.Function):
-    '''Identity to add a grad_fn to a tensor, so a backward hook can be applied.'''
-    @staticmethod
-    def forward(ctx, *inputs):
-        '''Forward identity.'''
-        return inputs
-
-    @staticmethod
-    def backward(ctx, *grad_outputs):
-        '''Backward identity.'''
-        return grad_outputs
-
-
 class Hook:
     '''Base class for hooks to be used to compute layer-wise attributions.'''
     def __init__(self):
@@ -381,29 +368,41 @@ def pre_forward(self, module, input):
         hook_ref = weakref.ref(self)
 
         @functools.wraps(self.backward)
-        def wrapper(grad_input, grad_output):
+        def wrapper(grad_input, grad_output, grad_sink):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
+                result = hook.backward(module, grad_output, hook.stored_tensors['grad_output'], grad_sink=grad_sink)
+                if not isinstance(result, tuple):
+                    result = (result,)
+                if grad_input is None:
+                    return result[0]
+                return result
             return None
 
         if not isinstance(input, tuple):
             input = (input,)
 
-        # only if gradient required
-        if input[0].requires_grad:
-            # add identity to ensure .grad_fn exists
-            post_input = Identity.apply(*input)
+        post_input = tuple(tensor.view_as(tensor) for tensor in input)
+
+        # hook required gradient sinks
+        for grad_sink, tensor in enumerate(post_input):
             # register the input tensor gradient hook
-            self.tensor_handles.append(
-                post_input[0].grad_fn.register_hook(wrapper)
-            )
-            # work around to support in-place operations
-            post_input = tuple(elem.clone() for elem in post_input)
-        else:
-            # no gradient required
-            post_input = input
-        return post_input[0] if len(post_input) == 1 else post_input
+            if tensor.grad_fn is not None:
+                # grad_fn for inputs is here the view function applied above
+                self.tensor_handles.append(
+                    tensor.grad_fn.register_hook(functools.partial(wrapper, grad_sink=grad_sink))
+                )
+        # hook required gradient sinks
+        for grad_sink, tensor in module.named_parameters():
+            if tensor.requires_grad:
+                # TODO: use grad_fn (need to store parameter views for the model...), otherwise the hook could be
+                # called for unrelated gradients
+                self.tensor_handles.append(
+                    tensor.register_hook(functools.partial(wrapper, None, grad_sink=grad_sink))
+                )
+
+        # torch.nn.Module converts single tensors to tuples anyway, so we can always return a tuple here
+        return post_input
 
     def post_forward(self, module, input, output):
         '''Register a backward-hook to the resulting tensor right after the forward.'''
@@ -413,28 +412,28 @@ def post_forward(self, module, input, output):
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.pre_backward(module, grad_input, grad_output)
+                return hook.pre_backward(module, grad_output)
             return None
 
-        if not isinstance(output, tuple):
-            output = (output,)
+        hookable_output = output
+        if not isinstance(hookable_output, tuple):
+            hookable_output = (hookable_output,)
 
         # only if gradient required
-        if output[0].grad_fn is not None:
+        if hookable_output[0].requires_grad:
             # register the output tensor gradient hook
             self.tensor_handles.append(
-                output[0].grad_fn.register_hook(wrapper)
+                hookable_output[0].grad_fn.register_hook(wrapper)
             )
-        return output[0] if len(output) == 1 else output
 
-    def pre_backward(self, module, grad_input, grad_output):
+    def pre_backward(self, module, grad_output):
         '''Store the grad_output for the backward hook'''
         self.stored_tensors['grad_output'] = grad_output
 
     def forward(self, module, input, output):
         '''Hook applied during forward-pass'''
 
-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Hook applied during backward-pass'''
 
     def copy(self):
@@ -522,18 +521,18 @@ def forward(self, module, input, output):
         '''Forward hook to save module in-/outputs.'''
         self.stored_tensors['input'] = input
 
-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
        '''Backward hook to compute LRP based on the class attributes.'''
-        original_input = self.stored_tensors['input'][0].clone()
+        original_inputs = [tensor.view_as(tensor) for tensor in self.stored_tensors['input']]
         inputs = []
         outputs = []
         for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
-            input = in_mod(original_input).requires_grad_()
+            input_args = [in_mod(tensor).requires_grad_() for tensor in original_inputs]
             with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
-                output = modified.forward(input)
-                output = out_mod(output)
-            inputs.append(input)
-            outputs.append(output)
+                output = modified.forward(*input_args)
+                # decide for which argument to compute the relevance
+                inputs.append(input_args[grad_sink] if isinstance(grad_sink, int) else getattr(modified, grad_sink))
+                outputs.append(out_mod(output))
         grad_outputs = self.gradient_mapper(grad_output[0], outputs)
         gradients = torch.autograd.grad(
             outputs,
@@ -542,7 +541,7 @@ def backward(self, module, grad_input, grad_output):
             create_graph=grad_output[0].requires_grad
         )
         relevance = self.reducer(inputs, gradients)
-        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)
+        return relevance
 
     def copy(self):
         '''Return a copy of this hook.
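The pre_forward rewrite above relies on two standard PyTorch facts: a leaf tensor has no .grad_fn to attach a backward hook to, but a view of it does, while parameters remain leaves and therefore fall back to Tensor.register_hook (the TODO in the diff). A self-contained sketch of that mechanism, with a toy stand-in for the wrapper rather than zennit's actual one:

import functools
import torch


def wrapper(grad_input, grad_output, grad_sink):
    '''Toy stand-in for the hook wrapper: report which sink is being processed.'''
    print('sink:', grad_sink)
    return None  # returning None leaves the gradient unchanged


# inputs: a leaf tensor has no grad_fn, but a view of it does, which is why pre_forward
# now applies tensor.view_as(tensor) instead of the removed Identity autograd Function
x = torch.ones(3, requires_grad=True)
post_x = x.view_as(x)
assert x.grad_fn is None and post_x.grad_fn is not None
post_x.grad_fn.register_hook(functools.partial(wrapper, grad_sink=0))

# parameters: leaves without a grad_fn, so the commit falls back to Tensor.register_hook,
# whose hook receives only the gradient itself (hence grad_input is bound to None)
weight = torch.nn.Parameter(torch.full((3,), 2.0))
weight.register_hook(functools.partial(wrapper, None, grad_sink='weight'))

(weight * post_x).sum().backward()  # the wrapper fires once for sink 0 and once for 'weight'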

src/zennit/rules.py

Lines changed: 7 additions & 7 deletions
@@ -322,7 +322,7 @@ class Pass(Hook):
     If the rule of a layer shall not be any other, is elementwise and shall not be the gradient, the `Pass` rule simply
     passes upper layer relevance through to the lower layer.
     '''
-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Pass through the upper gradient, skipping the one for this layer.'''
         return grad_output
 
@@ -399,16 +399,16 @@ def __init__(self, stabilizer=1e-6, zero_params=None):
 
 class ReLUDeconvNet(Hook):
     '''DeconvNet ReLU rule :cite:p:`zeiler2014visualizing`.'''
-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Modify ReLU gradient according to DeconvNet :cite:p:`zeiler2014visualizing`.'''
-        return (grad_output[0].clamp(min=0),)
+        return grad_output[0].clamp(min=0)
 
 
 class ReLUGuidedBackprop(Hook):
     '''GuidedBackprop ReLU rule :cite:p:`springenberg2015striving`.'''
-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Modify ReLU gradient according to GuidedBackprop :cite:p:`springenberg2015striving`.'''
-        return (grad_input[0] * (grad_output[0] > 0.),)
+        return grad_input[0] * (grad_output[0] > 0.)
 
 
 class ReLUBetaSmooth(Hook):
@@ -433,6 +433,6 @@ def forward(self, module, input, output):
         '''Remember the input for the backward pass.'''
         self.stored_tensors['input'] = input
 
-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Modify ReLU gradient to the smooth softplus gradient :cite:p:`dombrowski2019explanations`.'''
-        return (torch.sigmoid(self.beta_smooth * self.stored_tensors['input'][0]) * grad_output[0],)
+        return torch.sigmoid(self.beta_smooth * self.stored_tensors['input'][0]) * grad_output[0]
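For orientation, a sketch of how one of the adapted rules would be attached under the new interface. It assumes zennit's existing Hook.register helper (not part of this diff) and a parameter-free, single-input module, so there is exactly one gradient sink; since the commit is marked as unfinished, treat this as an illustration of the intended usage rather than a guaranteed working example.

import torch
from zennit.rules import ReLUGuidedBackprop

relu = torch.nn.ReLU()
rule = ReLUGuidedBackprop()
handles = rule.register(relu)  # attaches the pre_forward/post_forward hooks from core.Hook

data = torch.randn(8, requires_grad=True)
output = relu(data)
# propagate an upstream gradient with mixed signs; the rule additionally zeroes
# positions where the incoming gradient is negative
output.backward(torch.randn(8))
print(data.grad)

handles.remove()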

0 commit comments
