(model: torch.nn.Module, name: str)
| 361 | |
| 362 | |
| 363 | def _find_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]: |
| 364 | if name == "": |
| 365 | return model |
| 366 | first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "") |
| 367 | if first_atom == "*": |
| 368 | if not isinstance(model, torch.nn.ModuleList): |
| 369 | raise ValueError("Wildcard '*' can only be used with ModuleList") |
| 370 | submodules = [] |
| 371 | for submodule in model: |
| 372 | subsubmodules = _find_submodule_by_name(submodule, remaining_name) |
| 373 | if not isinstance(subsubmodules, list): |
| 374 | subsubmodules = [subsubmodules] |
| 375 | submodules.extend(subsubmodules) |
| 376 | return submodules |
| 377 | else: |
| 378 | if hasattr(model, first_atom): |
| 379 | submodule = getattr(model, first_atom) |
| 380 | return _find_submodule_by_name(submodule, remaining_name) |
| 381 | else: |
| 382 | raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'") |
no test coverage detected
searching dependent graphs…