(self, *args, **kwargs)
| 31 | self.state = 1 |
| 32 | |
def forward(self, *args, **kwargs):
    """Call the wrapped module using the configured computation dtype/device.

    If ``onload_dtype``/``onload_device`` already equal
    ``computation_dtype``/``computation_device``, ``self.module`` is invoked
    as-is. Otherwise ``self.module`` is deep-copied, the copy is moved with
    ``.to(dtype=computation_dtype, device=computation_device)``, and the copy
    is invoked. ``self.module`` itself is never converted here.

    Args:
        *args: Positional arguments passed through to the module call.
        **kwargs: Keyword arguments passed through to the module call.

    Returns:
        Whatever the invoked module returns.
    """
    if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
        # Stored module already matches the computation settings: no copy needed.
        return self.module(*args, **kwargs)

    # Convert a deep copy rather than self.module, so the stored module keeps
    # its onload dtype/device; the converted copy is local to this call.
    scratch = copy.deepcopy(self.module).to(
        dtype=self.computation_dtype,
        device=self.computation_device,
    )
    return scratch(*args, **kwargs)
| 39 | |
| 40 | |
| 41 | class AutoWrappedLinear(torch.nn.Linear): |