Context manager that globally disables weight initialization to speed up loading large models. To do that, all the torch.nn.init functions are replaced with a no-op (skip) function.
()
| 211 | |
@contextmanager
def no_init_weights():
    """
    Context manager to globally disable weight initialization to speed up loading large models.

    Every function named in `TORCH_INIT_FUNCTIONS` is replaced on the `torch.nn.init` module with a no-op for the
    duration of the `with` block. On exit, including exit via an exception, each of those names is reset to the
    original function stored in `TORCH_INIT_FUNCTIONS`.
    """

    def _skip_init(*args, **kwargs):
        # No-op stand-in for an init function; accepts any arguments and returns None.
        # NOTE(review): the real torch.nn.init functions return their input tensor; callers that use the
        # return value would receive None while this context is active — confirm no caller relies on it.
        pass

    try:
        # Patch inside the `try` so that, if patching fails part-way through, the `finally` clause still
        # restores whichever functions were already replaced instead of leaving torch.nn.init half-patched.
        for name in TORCH_INIT_FUNCTIONS:
            setattr(torch.nn.init, name, _skip_init)
        yield
    finally:
        # Restore the original initialization functions
        for name, init_func in TORCH_INIT_FUNCTIONS.items():
            setattr(torch.nn.init, name, init_func)
| 230 | |
| 231 | |
| 232 | class ModelMixin(torch.nn.Module, PushToHubMixin): |
no outgoing calls
no test coverage detected
searching dependent graphs…