(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
log_every_t=None)
| 995 | |
@torch.no_grad()
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                          img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                          score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                          log_every_t=None):
    """Run the reverse sampling chain and collect per-step ``x0`` predictions.

    Starting from ``x_T`` (or fresh Gaussian noise of ``shape``), ``self.p_sample``
    is applied for every timestep from ``timesteps - 1`` down to ``0``.

    Args:
        cond: Conditioning forwarded to ``p_sample``. May be ``None``, a
            sliceable tensor-like, a list of them, or a dict whose values are
            sliceable tensor-likes or lists of them. Every entry is truncated to
            its first ``batch_size`` items.
        shape: Sample shape. Includes the batch dimension when ``batch_size`` is
            ``None``; otherwise ``batch_size`` is prepended to it.
        verbose: Wrap the timestep loop in a ``tqdm`` progress bar.
        callback: Optional ``callback(i)`` invoked after every step.
        quantize_denoised: Forwarded to ``p_sample``.
        img_callback: Optional ``img_callback(img, i)`` invoked after every step.
        mask: Optional blend mask, applied after each step as
            ``img = q_sample(x0, ts) * mask + (1 - mask) * img``.
            Requires ``x0``.
        x0: Reference sample used together with ``mask``.
        temperature: A plain number (int or float) used at every step, or an
            indexable sequence holding one value per timestep.
        noise_dropout: Forwarded to ``p_sample``.
        score_corrector: Forwarded to ``p_sample``.
        corrector_kwargs: Forwarded to ``p_sample``.
        batch_size: Batch size. Defaults to ``shape[0]``.
        x_T: Optional starting sample. Otherwise ``torch.randn`` noise on
            ``self.device`` is used.
        start_T: Optional cap on the number of timesteps run.
        log_every_t: Interval for recording intermediates. Falls back to
            ``self.log_every_t`` when falsy.

    Returns:
        ``(img, intermediates)``: the final sample, and the list of ``x0``
        predictions returned by ``p_sample`` at every step where
        ``i % log_every_t == 0``, plus the first step (``timesteps - 1``).
    """
    if not log_every_t:
        log_every_t = self.log_every_t
    timesteps = self.num_timesteps

    if batch_size is not None:
        # In this case `shape` excludes the batch dimension, so prepend it.
        b = batch_size
        shape = [batch_size] + list(shape)
    else:
        b = batch_size = shape[0]

    img = torch.randn(shape, device=self.device) if x_T is None else x_T
    intermediates = []

    if cond is not None:
        # Trim the conditioning to the batch actually being sampled.
        if isinstance(cond, dict):
            cond = {key: [c[:batch_size] for c in val] if isinstance(val, list) else val[:batch_size]
                    for key, val in cond.items()}
        else:
            cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

    if start_T is not None:
        timesteps = min(timesteps, start_T)

    steps = reversed(range(0, timesteps))
    iterator = tqdm(steps, desc='Progressive Generation', total=timesteps) if verbose else steps

    # Broadcast a scalar temperature to one value per step. isinstance (rather
    # than an exact float type check) lets integers such as `1` through too.
    if isinstance(temperature, (int, float)):
        temperature = [temperature] * timesteps

    for i in iterator:
        ts = torch.full((b,), i, device=self.device, dtype=torch.long)

        step_cond = cond
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != 'hybrid'
            tc = self.cond_ids[ts].to(cond.device)
            # Noise the *clean* conditioning to this step's level. Feeding the
            # previous step's noised output back in would compound the noise
            # on every iteration.
            step_cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img, x0_partial = self.p_sample(img, step_cond, ts,
                                        clip_denoised=self.clip_denoised,
                                        quantize_denoised=quantize_denoised, return_x0=True,
                                        temperature=temperature[i], noise_dropout=noise_dropout,
                                        score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if mask is not None:
            assert x0 is not None
            # Re-noise the reference to the current step and blend it in under the mask.
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1. - mask) * img

        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(x0_partial)
        if callback:
            callback(i)
        if img_callback:
            img_callback(img, i)
    return img, intermediates
| 1051 | |
| 1052 | @torch.no_grad() |
| 1053 | def p_sample_loop(self, cond, shape, return_intermediates=False, |
no test coverage detected