| 159 | self.model.convert_to_fp16() |
| 160 | |
def run_loop(self):
    """Run training steps until the annealing budget (if any) is used up.

    Every iteration draws one ``(batch, cond)`` pair via ``next(self.data)``
    and passes it to ``self.run_step``. ``logger.dumpkvs()`` is called when
    ``self.step`` is a multiple of ``self.log_interval``, and ``self.save()``
    when it is a multiple of ``self.save_interval``. ``self.step`` is then
    advanced by one.

    Stopping rules:
      * If ``self.lr_anneal_steps`` is falsy, the loop condition never ends
        the loop.
      * Otherwise the loop ends once ``self.step + self.resume_step`` reaches
        ``self.lr_anneal_steps``. After that, ``self.save()`` is called once
        more unless the last executed step already saved.
      * If the ``DIFFUSION_TRAINING_TEST`` environment variable is non-empty,
        the method returns immediately after a periodic save at a step > 0.
        It skips the trailing save.

    NOTE(review): the source this was taken from lost its indentation. This
    version assumes the integration-test early exit is nested inside the
    periodic-save branch, which matches the comment placement right after
    ``self.save()``. Confirm against the real file.
    """
    while True:
        # Negated form of "keep going while no budget, or budget not reached".
        if self.lr_anneal_steps and self.step + self.resume_step >= self.lr_anneal_steps:
            break

        batch, cond = next(self.data)
        self.run_step(batch, cond)

        if self.step % self.log_interval == 0:
            logger.dumpkvs()

        if self.step % self.save_interval == 0:
            self.save()
            # Integration tests only need a short, bounded run: stop once a
            # save beyond step 0 has happened.
            if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                return

        self.step += 1

    # self.step was already advanced past the last executed step, so
    # (self.step - 1) is that step. Avoid saving twice for it.
    last_step_saved = (self.step - 1) % self.save_interval == 0
    if not last_step_saved:
        self.save()
| 179 | |
| 180 | def run_step(self, batch, cond): |
| 181 | self.forward_backward(batch, cond) |