| 96 | self.diffusion.to(self.args.device) |
| 97 | |
def load_cond_fn(self) -> None:
    """Build the guidance function selected by ``self.args`` and store it on ``self.cond_fn``.

    If ``self.args.guidance`` is falsy, ``self.cond_fn`` is set to ``None``
    and no other ``args`` fields are read. Otherwise ``self.args.g_loss``
    picks the guidance class: ``"mse"`` selects ``MSEGuidance`` and
    ``"w_mse"`` selects ``WeightedMSEGuidance``. The chosen class is then
    constructed positionally from ``g_scale``, ``g_start``, ``g_stop``,
    ``g_space`` and ``g_repeat``.

    Raises:
        ValueError: if ``self.args.g_loss`` is not ``"mse"`` or ``"w_mse"``.
    """
    if not self.args.guidance:
        self.cond_fn = None
        return
    g_loss = self.args.g_loss
    if g_loss == "mse":
        cond_fn_cls = MSEGuidance
    elif g_loss == "w_mse":
        cond_fn_cls = WeightedMSEGuidance
    else:
        # Name the offending option and the accepted values; the bare value
        # alone gave no hint which setting was wrong.
        raise ValueError(
            f"unsupported g_loss {g_loss!r}; expected one of 'mse', 'w_mse'"
        )
    # Arguments are passed positionally, so this order must match the
    # guidance classes' constructor (defined elsewhere) -- TODO confirm
    # before reordering.
    self.cond_fn = cond_fn_cls(
        self.args.g_scale,
        self.args.g_start,
        self.args.g_stop,
        self.args.g_space,
        self.args.g_repeat,
    )
| 115 | |
| 116 | @overload |
| 117 | def load_pipeline(self) -> None: ... |