MCPcopy
hub / github.com/ermongroup/ddim / sample

Method sample

runners/diffusion.py:192–242  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

190 data_start = time.time()
191
def sample(self):
    """Load model weights, then run the sampling routine selected by args.

    Weights come from one of two sources:

    * ``self.args.use_pretrained`` is falsy: a checkpoint is read from
      ``self.args.log_path`` -- ``ckpt.pth`` by default, or
      ``ckpt_<id>.pth`` when ``self.config.sampling.ckpt_id`` is set.
      Element 0 of the loaded object is used as the model state dict; when
      ``self.config.model.ema`` is truthy, the last element is used as the
      EMA helper state.
    * ``self.args.use_pretrained`` is truthy: an ``ema_<name>`` checkpoint
      is resolved through ``get_ckpt_path`` (pretrained DDPM weights, see
      https://github.com/pesser/pytorch_diffusion).

    The model is put in eval mode and handed to exactly one of
    ``sample_fid``, ``sample_interpolation`` or ``sample_sequence``,
    checked in that order against ``self.args.fid``,
    ``self.args.interpolation`` and ``self.args.sequence``.

    Raises:
        ValueError: pretrained weights were requested for a dataset other
            than ``"CIFAR10"`` or ``"LSUN"``.
        NotImplementedError: none of the three sampling flags is set.
    """
    model = Model(self.config)

    if not self.args.use_pretrained:
        ckpt_id = getattr(self.config.sampling, "ckpt_id", None)
        ckpt_name = "ckpt.pth" if ckpt_id is None else f"ckpt_{ckpt_id}.pth"
        # NOTE(security): torch.load unpickles the file; only point
        # log_path at checkpoints you trust.
        states = torch.load(
            os.path.join(self.args.log_path, ckpt_name),
            map_location=self.config.device,
        )
        model = model.to(self.device)
        # The state dict is loaded *after* wrapping, so the checkpoint keys
        # must match the DataParallel-wrapped module (strict=True).
        model = torch.nn.DataParallel(model)
        model.load_state_dict(states[0], strict=True)

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
            ema_helper.load_state_dict(states[-1])
            # Presumably overwrites the model's parameters with the EMA
            # averages -- confirm against EMAHelper.ema.
            ema_helper.ema(model)
    else:
        # Uses the pretrained DDPM model, see
        # https://github.com/pesser/pytorch_diffusion
        dataset = self.config.data.dataset
        if dataset == "CIFAR10":
            name = "cifar10"
        elif dataset == "LSUN":
            name = f"lsun_{self.config.data.category}"
        else:
            raise ValueError(
                f"No pretrained checkpoint for dataset {dataset!r}; "
                "expected 'CIFAR10' or 'LSUN'"
            )
        ckpt = get_ckpt_path(f"ema_{name}")
        print(f"Loading checkpoint {ckpt}")
        # Here the state dict is loaded *before* wrapping in DataParallel,
        # i.e. into the bare model.
        # NOTE(security): torch.load unpickles the file; the checkpoint
        # returned by get_ckpt_path must come from a trusted source.
        model.load_state_dict(torch.load(ckpt, map_location=self.device))
        model.to(self.device)
        model = torch.nn.DataParallel(model)

    model.eval()

    if self.args.fid:
        self.sample_fid(model)
    elif self.args.interpolation:
        self.sample_interpolation(model)
    elif self.args.sequence:
        self.sample_sequence(model)
    else:
        # Message text kept verbatim from the original.
        raise NotImplementedError("Sample procedeure not defined")
243
244 def sample_fid(self, model):
245 config = self.config

Callers 1

mainFunction · 0.95

Calls 9

load_state_dictMethod · 0.95
registerMethod · 0.95
emaMethod · 0.95
sample_fidMethod · 0.95
sample_interpolationMethod · 0.95
sample_sequenceMethod · 0.95
ModelClass · 0.90
EMAHelperClass · 0.90
get_ckpt_pathFunction · 0.90

Tested by

no test coverage detected