hub / github.com/ermongroup/ddim / train

Method train

runners/diffusion.py:98–190
(self)

Source from the content-addressed store, hash-verified

            self.logvar = posterior_variance.clamp(min=1e-20).log()

    def train(self):
        args, config = self.args, self.config
        tb_logger = self.config.tb_logger
        dataset, test_dataset = get_dataset(args, config)
        train_loader = data.DataLoader(
            dataset,
            batch_size=config.training.batch_size,
            shuffle=True,
            num_workers=config.data.num_workers,
        )
        model = Model(config)

        model = model.to(self.device)
        model = torch.nn.DataParallel(model)

        optimizer = get_optimizer(self.config, model.parameters())

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
        else:
            ema_helper = None

        start_epoch, step = 0, 0
        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log_path, "ckpt.pth"))
            model.load_state_dict(states[0])

            states[1]["param_groups"][0]["eps"] = self.config.optim.eps
            optimizer.load_state_dict(states[1])
            start_epoch = states[2]
            step = states[3]
            if self.config.model.ema:
                ema_helper.load_state_dict(states[4])

        for epoch in range(start_epoch, self.config.training.n_epochs):
            data_start = time.time()
            data_time = 0
            for i, (x, y) in enumerate(train_loader):
                n = x.size(0)
                data_time += time.time() - data_start
                model.train()
                step += 1

                x = x.to(self.device)
                x = data_transform(self.config, x)
                e = torch.randn_like(x)
                b = self.betas

                # antithetic sampling
                t = torch.randint(
                    low=0, high=self.num_timesteps, size=(n // 2 + 1,)
                ).to(self.device)
                t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]
                loss = loss_registry[config.model.type](model, x, t, e, b)

                tb_logger.add_scalar("loss", loss, global_step=step)
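The antithetic sampling step in the listing draws n // 2 + 1 timesteps, mirrors each one as num_timesteps - t - 1, concatenates both halves, and truncates to the batch size n. The sketch below reproduces that index arithmetic with the standard library only, so it runs without PyTorch. The function name sample_antithetic_timesteps and the rng parameter are hypothetical and not part of the repository; the listing itself uses torch.randint and torch.cat.

```python
import random


def sample_antithetic_timesteps(n, num_timesteps, rng=None):
    """Return n timesteps in [0, num_timesteps), drawn as mirrored pairs.

    Hypothetical helper: mirrors the index arithmetic of the listing,
    but with plain Python lists instead of torch tensors.
    """
    rng = rng or random.Random()
    # Draw a little more than half the batch so odd batch sizes are covered.
    half = [rng.randrange(num_timesteps) for _ in range(n // 2 + 1)]
    # Mirror every draw around the middle of the schedule: t -> T - t - 1.
    mirrored = [num_timesteps - t - 1 for t in half]
    # Concatenate both halves and keep the first n entries, like [:n].
    return (half + mirrored)[:n]


timesteps = sample_antithetic_timesteps(n=8, num_timesteps=1000, rng=random.Random(0))
print(timesteps)
```

With n = 8 the first half holds five draws and only three mirrored values survive the truncation, so the last two draws of the first half appear without their mirror in the batch.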

Callers 1

main · Function · 0.95

Calls 9

register · Method · 0.95
load_state_dict · Method · 0.95
update · Method · 0.95
state_dict · Method · 0.95
get_dataset · Function · 0.90
Model · Class · 0.90
get_optimizer · Function · 0.90
EMAHelper · Class · 0.90
data_transform · Function · 0.90
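The register, update, state_dict and load_state_dict calls listed above suggest an exponential moving average helper with a decay rate mu. The source of EMAHelper is not shown here, so the following is only a sketch of the general technique under stated assumptions: the class name SimpleEMA is hypothetical, parameters are plain floats in a dict rather than model tensors, and the update rule shadow = mu * shadow + (1 - mu) * value is the common EMA convention, assumed rather than taken from the repository.

```python
class SimpleEMA:
    """Hypothetical stand-in for an EMA helper; not the repository's EMAHelper."""

    def __init__(self, mu=0.999):
        self.mu = mu          # decay rate: closer to 1 means slower averaging
        self.shadow = {}      # running averages, keyed by parameter name

    def register(self, params):
        # Start each running average at the current parameter value.
        self.shadow = dict(params)

    def update(self, params):
        # Assumed convention: shadow <- mu * shadow + (1 - mu) * value.
        for name, value in params.items():
            self.shadow[name] = self.mu * self.shadow[name] + (1.0 - self.mu) * value

    def state_dict(self):
        # Return a copy so a saved checkpoint is not mutated by later updates.
        return dict(self.shadow)

    def load_state_dict(self, state):
        # Restore running averages, for example when resuming training.
        self.shadow = dict(state)


ema = SimpleEMA(mu=0.5)
ema.register({"w": 0.0})
ema.update({"w": 1.0})
ema.update({"w": 1.0})
print(ema.state_dict())
```

Saving state_dict() alongside the model and optimizer state is what lets a resumed run continue the average instead of restarting it, which matches the role of states[4] in the listing.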

Tested by

no test coverage detected