(self)
| 190 | data_start = time.time() |
| 191 | |
def sample(self):
    """Load a diffusion model checkpoint and run the sampling mode chosen by the CLI.

    Two checkpoint sources are supported:

    * ``self.args.use_pretrained`` is false: load ``ckpt.pth`` from
      ``self.args.log_path``. If ``config.sampling.ckpt_id`` is set, load
      ``ckpt_<ckpt_id>.pth`` from that directory instead. Element 0 of the loaded
      object is loaded (strictly) as the model state dict. When
      ``config.model.ema`` is enabled, the last element is loaded into an
      ``EMAHelper`` and ``ema_helper.ema(model)`` is called (presumably
      swapping the EMA weights into the model).
    * ``self.args.use_pretrained`` is true: load the ``ema_<name>`` checkpoint
      returned by ``get_ckpt_path``. This is available for CIFAR10 and for an
      LSUN category.

    In both cases the model ends up wrapped in ``torch.nn.DataParallel`` and in
    eval mode. It is then passed to ``sample_fid``, ``sample_interpolation`` or
    ``sample_sequence`` according to ``args.fid``, ``args.interpolation`` and
    ``args.sequence``, checked in that order.

    Raises:
        ValueError: ``use_pretrained`` is set but ``config.data.dataset`` is
            neither ``"CIFAR10"`` nor ``"LSUN"``.
        NotImplementedError: none of the sampling-mode flags is set.
    """
    model = Model(self.config)

    if not self.args.use_pretrained:
        # Pick the checkpoint file: the latest one by default, or a specific
        # ckpt_id when the config asks for it.
        ckpt_id = getattr(self.config.sampling, "ckpt_id", None)
        ckpt_name = "ckpt.pth" if ckpt_id is None else f"ckpt_{ckpt_id}.pth"
        states = torch.load(
            os.path.join(self.args.log_path, ckpt_name),
            map_location=self.config.device,
        )
        model = model.to(self.device)
        # Wrap before loading with strict=True. Presumably the saved keys come
        # from a DataParallel-wrapped model ("module." prefix); confirm against
        # the training loop.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(states[0], strict=True)

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(model)
    else:
        # This uses the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion
        dataset = self.config.data.dataset
        if dataset == "CIFAR10":
            name = "cifar10"
        elif dataset == "LSUN":
            name = f"lsun_{self.config.data.category}"
        else:
            raise ValueError(
                f"No pretrained checkpoint is available for dataset {dataset!r}; "
                "expected 'CIFAR10' or 'LSUN'"
            )
        ckpt = get_ckpt_path(f"ema_{name}")
        print("Loading checkpoint {}".format(ckpt))
        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # obtained from a trusted source.
        # The weights are loaded into the bare model here, before the
        # DataParallel wrap (the reverse of the branch above).
        model.load_state_dict(torch.load(ckpt, map_location=self.device))
        model.to(self.device)
        model = torch.nn.DataParallel(model)

    model.eval()

    if self.args.fid:
        self.sample_fid(model)
    elif self.args.interpolation:
        self.sample_interpolation(model)
    elif self.args.sequence:
        self.sample_sequence(model)
    else:
        raise NotImplementedError("Sample procedure not defined")
| 243 | |
| 244 | def sample_fid(self, model): |
| 245 | config = self.config |
no test coverage detected