(
self,
*,
model,
diffusion,
data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=1e-3,
schedule_sampler=None,
weight_decay=0.0,
lr_anneal_steps=0,
)
| 21 | |
| 22 | class TrainLoop: |
| 23 | def __init__( |
| 24 | self, |
| 25 | *, |
| 26 | model, |
| 27 | diffusion, |
| 28 | data, |
| 29 | batch_size, |
| 30 | microbatch, |
| 31 | lr, |
| 32 | ema_rate, |
| 33 | log_interval, |
| 34 | save_interval, |
| 35 | resume_checkpoint, |
| 36 | use_fp16=False, |
| 37 | fp16_scale_growth=1e-3, |
| 38 | schedule_sampler=None, |
| 39 | weight_decay=0.0, |
| 40 | lr_anneal_steps=0, |
| 41 | ): |
| 42 | self.model = model |
| 43 | self.diffusion = diffusion |
| 44 | self.data = data |
| 45 | self.batch_size = batch_size |
| 46 | self.microbatch = microbatch if microbatch > 0 else batch_size |
| 47 | self.lr = lr |
| 48 | self.ema_rate = ( |
| 49 | [ema_rate] |
| 50 | if isinstance(ema_rate, float) |
| 51 | else [float(x) for x in ema_rate.split(",")] |
| 52 | ) |
| 53 | self.log_interval = log_interval |
| 54 | self.save_interval = save_interval |
| 55 | self.resume_checkpoint = resume_checkpoint |
| 56 | self.use_fp16 = use_fp16 |
| 57 | self.fp16_scale_growth = fp16_scale_growth |
| 58 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) |
| 59 | self.weight_decay = weight_decay |
| 60 | self.lr_anneal_steps = lr_anneal_steps |
| 61 | |
| 62 | self.step = 0 |
| 63 | self.resume_step = 0 |
| 64 | self.global_batch = self.batch_size * dist.get_world_size() |
| 65 | |
| 66 | self.sync_cuda = th.cuda.is_available() |
| 67 | |
| 68 | self._load_and_sync_parameters() |
| 69 | self.mp_trainer = MixedPrecisionTrainer( |
| 70 | model=self.model, |
| 71 | use_fp16=self.use_fp16, |
| 72 | fp16_scale_growth=fp16_scale_growth, |
| 73 | ) |
| 74 | |
| 75 | self.opt = AdamW( |
| 76 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay |
| 77 | ) |
| 78 | if self.resume_step: |
| 79 | self._load_optimizer_state() |
| 80 | # Model was resumed, either due to a restart or a checkpoint |
nothing calls this directly
no test coverage detected