| 86 | |
@torch.no_grad()
def restoration(self, y_cond, y_t=None, y_0=None, mask=None, sample_num=8):
    """Run the full reverse sampling loop conditioned on ``y_cond``.

    Starting from ``y_t`` (or ``torch.randn_like(y_cond)`` when ``y_t`` is
    None), calls ``self.p_sample`` once per timestep, from
    ``self.num_timesteps - 1`` down to ``0``.

    Args:
        y_cond: Conditioning tensor. Its leading dimension is used as ``b``,
            the length of the per-step timestep tensor passed to
            ``self.p_sample``.
        y_t: Optional starting sample; defaults to noise shaped like
            ``y_cond``.
        y_0: Content blended back in after every step when ``mask`` is
            given. Required whenever ``mask`` is not None.
        mask: Optional blending mask. After each step the sample becomes
            ``y_0 * (1 - mask) + mask * y_t``. Where mask is 0 the value
            comes from ``y_0``; where it is 1 it comes from the model.
        sample_num: Sets the snapshot interval
            ``self.num_timesteps // sample_num``. A snapshot is recorded at
            every timestep index divisible by that interval.

    Returns:
        Tuple ``(y_t, ret_arr)``. ``y_t`` is the final sample. ``ret_arr``
        is the starting sample followed by every recorded snapshot,
        concatenated along dim 0.

    Raises:
        ValueError: if ``sample_num < 1``, if ``self.num_timesteps`` is not
            strictly greater than ``sample_num``, or if ``mask`` is given
            without ``y_0``.
    """
    b, *_ = y_cond.shape

    # Explicit checks instead of `assert`: asserts are stripped under -O.
    if sample_num < 1:
        raise ValueError(f'sample_num must be >= 1, got {sample_num}')
    if self.num_timesteps <= sample_num:
        raise ValueError('num_timesteps must be greater than sample_num')
    if mask is not None and y_0 is None:
        # Otherwise this would fail mid-loop on `None * tensor`.
        raise ValueError('y_0 is required when mask is given')
    sample_inter = self.num_timesteps // sample_num

    y_t = default(y_t, lambda: torch.randn_like(y_cond))
    # Collect snapshots and concatenate once at the end. A torch.cat inside
    # the loop would re-copy every earlier snapshot on each append.
    snapshots = [y_t]
    for i in tqdm(reversed(range(self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
        t = torch.full((b,), i, device=y_cond.device, dtype=torch.long)
        y_t = self.p_sample(y_t, t, y_cond=y_cond)
        if mask is not None:
            # Re-impose y_0 wherever mask is 0 after every denoising step.
            y_t = y_0 * (1. - mask) + mask * y_t
        if i % sample_inter == 0:
            snapshots.append(y_t)
    return y_t, torch.cat(snapshots, dim=0)
| 104 | |
| 105 | def forward(self, y_0, y_cond=None, mask=None, noise=None): |
| 106 | # sampling from p(gammas) |