def pre_hook(_mod, _inputs):
    """Forward pre-hook: step one level deeper into the module tree.

    Bumps the shared depth counter stored in ``nesting[0]``; the matching
    ``post_hook`` steps back out once the module's forward has returned.
    The module and its inputs are not needed here.
    """
    nesting[0] = nesting[0] + 1
def _collect_tensors(obj):
    """Return every ``torch.Tensor`` found in *obj*, in depth-first order.

    Recurses into tuples, lists (including namedtuples) and dict values, so
    nested results such as ``(output, (h_n, c_n))`` are captured in full.
    Any other non-tensor object contributes nothing.
    """
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, dict):
        # Only the values are relevant for the summary; keys are discarded.
        obj = list(obj.values())
    if isinstance(obj, (tuple, list)):
        found = []
        for item in obj:
            found.extend(_collect_tensors(item))
        return found
    return []

def post_hook(mod, module_inputs, outputs):
    """Forward hook: leave one nesting level and record the module's tensors.

    Decrements the shared depth counter ``nesting[0]`` (incremented by
    ``pre_hook``). When the resulting depth is ``<= max_nesting``, appends an
    entry to ``entries`` with the module, the tensors found in
    *module_inputs*, and the tensors found in *outputs*. Tensors nested
    inside tuples, lists or dicts at any depth are included; non-tensor
    values are skipped.
    """
    nesting[0] -= 1
    if nesting[0] <= max_nesting:
        entries.append(dnnlib.EasyDict(
            mod=mod,
            inputs=_collect_tensors(module_inputs),
            outputs=_collect_tensors(outputs),
        ))
# Register the depth-tracking pre-hook on every module yielded by
# module.modules(), then the recording post-hook on the same modules in the
# same order; all returned handles are collected in `hooks`.
_all_submodules = list(module.modules())
hooks = [sub.register_forward_pre_hook(pre_hook) for sub in _all_submodules]
hooks.extend([sub.register_forward_hook(post_hook) for sub in _all_submodules])
| 218 | |