MCPcopy
hub / github.com/openai/guided-diffusion / __init__

Method __init__

guided_diffusion/train_util.py:23–108  ·  view source on GitHub ↗
(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    )

Source from the content-addressed store, hash-verified

21
22class TrainLoop:
23 def __init__(
24 self,
25 *,
26 model,
27 diffusion,
28 data,
29 batch_size,
30 microbatch,
31 lr,
32 ema_rate,
33 log_interval,
34 save_interval,
35 resume_checkpoint,
36 use_fp16=False,
37 fp16_scale_growth=1e-3,
38 schedule_sampler=None,
39 weight_decay=0.0,
40 lr_anneal_steps=0,
41 ):
42 self.model = model
43 self.diffusion = diffusion
44 self.data = data
45 self.batch_size = batch_size
46 self.microbatch = microbatch if microbatch > 0 else batch_size
47 self.lr = lr
48 self.ema_rate = (
49 [ema_rate]
50 if isinstance(ema_rate, float)
51 else [float(x) for x in ema_rate.split(",")]
52 )
53 self.log_interval = log_interval
54 self.save_interval = save_interval
55 self.resume_checkpoint = resume_checkpoint
56 self.use_fp16 = use_fp16
57 self.fp16_scale_growth = fp16_scale_growth
58 self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
59 self.weight_decay = weight_decay
60 self.lr_anneal_steps = lr_anneal_steps
61
62 self.step = 0
63 self.resume_step = 0
64 self.global_batch = self.batch_size * dist.get_world_size()
65
66 self.sync_cuda = th.cuda.is_available()
67
68 self._load_and_sync_parameters()
69 self.mp_trainer = MixedPrecisionTrainer(
70 model=self.model,
71 use_fp16=self.use_fp16,
72 fp16_scale_growth=fp16_scale_growth,
73 )
74
75 self.opt = AdamW(
76 self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
77 )
78 if self.resume_step:
79 self._load_optimizer_state()
80 # Model was resumed, either due to a restart or a checkpoint

Callers

nothing calls this directly

Calls 5

_load_optimizer_state — Method · 0.95
_load_ema_parameters — Method · 0.95
UniformSampler — Class · 0.85

Tested by

no test coverage detected