| 78 | |
| 79 | |
class MultiplyHook(ModelHook):
    """Hook that scales tensor positional arguments by a constant factor.

    ``pre_forward`` multiplies every positional argument for which
    ``torch.is_tensor`` is true by ``value``. Non-tensor positional arguments
    and all keyword arguments are passed through unchanged. ``post_forward``
    returns the module output untouched.
    """

    def __init__(self, value: int):
        super().__init__()
        # Multiplier applied to each tensor positional argument in pre_forward.
        self.value = value

    def pre_forward(self, module, *args, **kwargs):
        """Scale tensor positional arguments before the wrapped forward runs.

        Args:
            module: The module being hooked (not used here).
            *args: Positional arguments for the forward call; tensors are
                multiplied by ``self.value``, other values are kept as-is.
            **kwargs: Keyword arguments for the forward call, returned unchanged.

        Returns:
            A ``(args, kwargs)`` pair where ``args`` is a tuple.
        """
        logger.debug("MultiplyHook pre_forward")
        # Materialize as a tuple instead of returning a generator expression.
        # A generator can be consumed only once, and it would read
        # ``self.value`` lazily, only when the caller iterates it.
        args = tuple(x * self.value if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        """Return ``output`` unchanged; this hook only acts before forward."""
        logger.debug("MultiplyHook post_forward")
        return output

    def __repr__(self):
        return f"MultiplyHook(value={self.value})"
| 96 | |
| 97 | |
| 98 | class StatefulAddHook(ModelHook): |
no outgoing calls
searching dependent graphs…