(self, *args, **kwargs)
| 47 | self.state = 0 |
| 48 | |
def forward(self, *args, **kwargs):
    """Invoke the wrapped module, picking which copy of its weights runs.

    The first matching rule decides which module handles the call:

    1. ``self.state == 2``: call ``self.module`` as-is.
       NOTE(review): the meaning of state 2 is defined outside this view;
       presumably "already resident on the computation device". Confirm.
    2. The onload dtype AND onload device already equal the computation
       dtype and device: call ``self.module`` as-is.
    3. A VRAM limit is configured (``self.vram_limit is not None``) and
       ``self.check_free_vram()`` is truthy: call ``self.keep()`` first,
       then call ``self.module``.
    4. Otherwise: call a throwaway ``copy.deepcopy`` of ``self.module``
       that has been moved to the computation dtype/device. The copy is
       not stored, so ``self.module`` itself is left where it was.

    Positional and keyword arguments are forwarded unchanged. The return
    value is whatever the chosen module returns.
    """
    if self.state == 2:
        return self.module(*args, **kwargs)

    # Short-circuit exactly like the original: the device comparison only
    # runs when the dtypes already match.
    dtype_matches = self.onload_dtype == self.computation_dtype
    if dtype_matches and self.onload_device == self.computation_device:
        return self.module(*args, **kwargs)

    # check_free_vram() is only consulted when a VRAM limit exists.
    if self.vram_limit is not None and self.check_free_vram():
        self.keep()
        return self.module(*args, **kwargs)

    staged = copy.deepcopy(self.module)
    staged = staged.to(dtype=self.computation_dtype, device=self.computation_device)
    return staged(*args, **kwargs)
| 61 | |
| 62 | |
| 63 | class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule): |
nothing calls this directly
no test coverage detected