MCPcopy Index your code
hub / github.com/openai/improved-diffusion / __init__

Method __init__

improved_diffusion/train_util.py:30–112  ·  view source on GitHub ↗
(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    )

Source from the content-addressed store, hash-verified

28
29class TrainLoop:
30 def __init__(
31 self,
32 *,
33 model,
34 diffusion,
35 data,
36 batch_size,
37 microbatch,
38 lr,
39 ema_rate,
40 log_interval,
41 save_interval,
42 resume_checkpoint,
43 use_fp16=False,
44 fp16_scale_growth=1e-3,
45 schedule_sampler=None,
46 weight_decay=0.0,
47 lr_anneal_steps=0,
48 ):
49 self.model = model
50 self.diffusion = diffusion
51 self.data = data
52 self.batch_size = batch_size
53 self.microbatch = microbatch if microbatch > 0 else batch_size
54 self.lr = lr
55 self.ema_rate = (
56 [ema_rate]
57 if isinstance(ema_rate, float)
58 else [float(x) for x in ema_rate.split(",")]
59 )
60 self.log_interval = log_interval
61 self.save_interval = save_interval
62 self.resume_checkpoint = resume_checkpoint
63 self.use_fp16 = use_fp16
64 self.fp16_scale_growth = fp16_scale_growth
65 self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
66 self.weight_decay = weight_decay
67 self.lr_anneal_steps = lr_anneal_steps
68
69 self.step = 0
70 self.resume_step = 0
71 self.global_batch = self.batch_size * dist.get_world_size()
72
73 self.model_params = list(self.model.parameters())
74 self.master_params = self.model_params
75 self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
76 self.sync_cuda = th.cuda.is_available()
77
78 self._load_and_sync_parameters()
79 if self.use_fp16:
80 self._setup_fp16()
81
82 self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
83 if self.resume_step:
84 self._load_optimizer_state()
85 # Model was resumed, either due to a restart or a checkpoint
86 # being specified at the command line.
87 self.ema_params = [

Callers

nothing calls this directly

Calls 5

_setup_fp16 (Method) · 0.95
_load_optimizer_state (Method) · 0.95
_load_ema_parameters (Method) · 0.95
UniformSampler (Class) · 0.85

Tested by

no test coverage detected