Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original device on exit. Args: device (`str` or `torch.device`): Device to move the `modules` to. offload (`bool`): Flag to enable offloading.
(*modules: torch.nn.Module | DiffusionPipeline, device: str | torch.device, offload: bool = True)
| 420 | |
@contextmanager
def offload_models(*modules: torch.nn.Module | DiffusionPipeline, device: str | torch.device, offload: bool = True):
    """
    Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original
    device on exit.

    Args:
        *modules (`torch.nn.Module` or `DiffusionPipeline`): What to move. Either any number of `torch.nn.Module`
            instances, or exactly one `DiffusionPipeline`.
        device (`str` or `torch.device`): Device to move the `modules` to.
        offload (`bool`): Flag to enable offloading. When `False`, the modules are left untouched.

    Raises:
        ValueError: If offloading is enabled and a `DiffusionPipeline` is passed together with other arguments.

    Note:
        The original device of an `nn.Module` is read from its first parameter (or, if it has none, its first
        buffer). A module with neither is still moved but is not moved back. A module whose tensors are spread
        over several devices is restored entirely to that single recorded device.
    """
    original_devices = []
    if offload:
        if any(isinstance(m, DiffusionPipeline) for m in modules):
            # Use an explicit check rather than `assert`, which is stripped when Python runs with -O.
            if len(modules) != 1:
                raise ValueError(
                    f"offload_models accepts a single DiffusionPipeline on its own; got {len(modules)} arguments."
                )
            # For DiffusionPipeline, wrap the device in a list so it zips with `modules` below.
            original_devices = [modules[0].device]
        else:
            # Record where each module was.
            for m in modules:
                # `next(..., None)` avoids raising StopIteration on tensor-less modules; inside this generator
                # that would surface as an opaque RuntimeError (PEP 479).
                tensor = next(m.parameters(), None)
                if tensor is None:
                    tensor = next(m.buffers(), None)
                original_devices.append(None if tensor is None else tensor.device)

    try:
        if offload:
            # Move inside `try` so that a failure partway through still restores the modules already moved.
            for m in modules:
                m.to(device)
        yield
    finally:
        if offload:
            # Move back to original devices.
            for m, orig_dev in zip(modules, original_devices):
                if orig_dev is not None:
                    m.to(orig_dev)
| 452 | |
| 453 | def parse_buckets_string(buckets_str): |