| 35 | |
| 36 | class AutoWrappedModule(AutoTorchModule): |
def __init__(
    self,
    module: torch.nn.Module,
    offload_dtype,
    offload_device,
    onload_dtype,
    onload_device,
    computation_dtype,
    computation_device,
    vram_limit,
    **kwargs
):
    """Wrap ``module`` and record the dtype/device settings for each placement.

    The wrapped module is converted with
    ``module.to(dtype=offload_dtype, device=offload_device)`` before it is
    stored, so construction leaves it at the offload placement. The onload
    and computation dtype/device pairs and ``vram_limit`` are only recorded
    here; whatever uses them lives outside this constructor.

    Args:
        module: The ``torch.nn.Module`` to wrap.
        offload_dtype: dtype passed to the initial ``.to()`` call.
        offload_device: device passed to the initial ``.to()`` call.
        onload_dtype: Stored as ``self.onload_dtype``.
        onload_device: Stored as ``self.onload_device``.
        computation_dtype: Stored as ``self.computation_dtype``.
        computation_device: Stored as ``self.computation_device``.
        vram_limit: Stored as ``self.vram_limit``.
        **kwargs: Accepted but not used by this constructor.
    """
    # Base initializer runs first; presumably AutoTorchModule is a
    # torch.nn.Module subclass, which must be initialized before a
    # submodule attribute is assigned -- TODO confirm against its definition.
    super().__init__()

    offloaded_module = module.to(dtype=offload_dtype, device=offload_device)
    self.module = offloaded_module

    # Placement settings, one (dtype, device) pair per stage.
    self.offload_dtype, self.offload_device = offload_dtype, offload_device
    self.onload_dtype, self.onload_device = onload_dtype, onload_device
    self.computation_dtype, self.computation_device = computation_dtype, computation_device

    self.vram_limit = vram_limit

    # Initial state value. forward() branches on self.state (it tests
    # `self.state == 2`); the meaning of each value is not visible here --
    # TODO confirm.
    self.state = 0
| 48 | |
| 49 | def forward(self, *args, **kwargs): |
| 50 | if self.state == 2: |