(self)
| 55 | } |
| 56 | |
def apply_torch_model(self):
    """Turn off checkpoint flags across ``self.unet`` and reset the UNet selection.

    Every module visited by ``self.unet.apply`` has its ``use_checkpoint`` and
    ``checkpoint`` attributes checked, in that order. For a ``torch.nn.Module``,
    ``apply`` visits each submodule recursively plus the module itself. Any
    attribute that compares equal to ``True`` is set to ``False``. Afterwards
    ``self.set_unet("None")`` is called.

    NOTE(review): ``"None"`` is passed as a string, not the ``None`` object.
    Presumably it selects the default/plain UNet; confirm against ``set_unet``.
    """
    flag_names = ("use_checkpoint", "checkpoint")

    def _turn_off_flags(module):
        for flag_name in flag_names:
            # Strict ``== True`` is deliberate. A truthy value that is not
            # equal to True (e.g. a bound method or a non-empty string)
            # stored under the same name is left untouched.
            if getattr(module, flag_name, False) == True:  # noqa: E712
                setattr(module, flag_name, False)

    self.unet.apply(_turn_off_flags)
    self.set_unet("None")
| 66 | |
| 67 | def set_unet(self, ckpt: str): |
| 68 | # TODO test if using this with TRT works |
no test coverage detected