| 27 | |
| 28 | |
| 29 | class TrainLoop: |
| 30 | def __init__( |
| 31 | self, |
| 32 | *, |
| 33 | model, |
| 34 | diffusion, |
| 35 | data, |
| 36 | batch_size, |
| 37 | microbatch, |
| 38 | lr, |
| 39 | ema_rate, |
| 40 | log_interval, |
| 41 | save_interval, |
| 42 | resume_checkpoint, |
| 43 | use_fp16=False, |
| 44 | fp16_scale_growth=1e-3, |
| 45 | schedule_sampler=None, |
| 46 | weight_decay=0.0, |
| 47 | lr_anneal_steps=0, |
| 48 | ): |
| 49 | self.model = model |
| 50 | self.diffusion = diffusion |
| 51 | self.data = data |
| 52 | self.batch_size = batch_size |
| 53 | self.microbatch = microbatch if microbatch > 0 else batch_size |
| 54 | self.lr = lr |
| 55 | self.ema_rate = ( |
| 56 | [ema_rate] |
| 57 | if isinstance(ema_rate, float) |
| 58 | else [float(x) for x in ema_rate.split(",")] |
| 59 | ) |
| 60 | self.log_interval = log_interval |
| 61 | self.save_interval = save_interval |
| 62 | self.resume_checkpoint = resume_checkpoint |
| 63 | self.use_fp16 = use_fp16 |
| 64 | self.fp16_scale_growth = fp16_scale_growth |
| 65 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) |
| 66 | self.weight_decay = weight_decay |
| 67 | self.lr_anneal_steps = lr_anneal_steps |
| 68 | |
| 69 | self.step = 0 |
| 70 | self.resume_step = 0 |
| 71 | self.global_batch = self.batch_size * dist.get_world_size() |
| 72 | |
| 73 | self.model_params = list(self.model.parameters()) |
| 74 | self.master_params = self.model_params |
| 75 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE |
| 76 | self.sync_cuda = th.cuda.is_available() |
| 77 | |
| 78 | self._load_and_sync_parameters() |
| 79 | if self.use_fp16: |
| 80 | self._setup_fp16() |
| 81 | |
| 82 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay) |
| 83 | if self.resume_step: |
| 84 | self._load_optimizer_state() |
| 85 | # Model was resumed, either due to a restart or a checkpoint |
| 86 | # being specified at the command line. |