(
self,
model: ControlLDM,
x: torch.Tensor,
model_t: torch.Tensor,
t: torch.Tensor,
cond: Dict[str, torch.Tensor],
uncond: Optional[Dict[str, torch.Tensor]],
cfg_scale: float,
)
| 160 | |
@torch.no_grad()
def p_sample(
    self,
    model: ControlLDM,
    x: torch.Tensor,
    model_t: torch.Tensor,
    t: torch.Tensor,
    cond: Dict[str, torch.Tensor],
    uncond: Optional[Dict[str, torch.Tensor]],
    cfg_scale: float,
) -> torch.Tensor:
    """Take one reverse-diffusion step from ``x`` and return the next state.

    The network output (obtained through ``self.apply_model``) is turned into
    an estimate of x_0. ``self.q_posterior_mean_variance`` then yields a mean
    and variance for the next state. Gaussian noise scaled by
    ``sqrt(variance)`` is added to that mean, except where ``t == 0``.

    Args:
        model: network forwarded to ``self.apply_model``.
        x: current noisy state.
        model_t: timestep value handed to ``self.apply_model``.
        t: timestep handed to the x_0-prediction and posterior helpers;
            entries equal to 0 receive no noise.
        cond: conditioning inputs forwarded to ``self.apply_model``.
        uncond: unconditional inputs forwarded to ``self.apply_model``
            (may be ``None``).
        cfg_scale: guidance scale forwarded to ``self.apply_model``.

    Returns:
        ``mean + mask * sqrt(variance) * noise``, where ``mask`` is 0 for
        entries with ``t == 0`` and 1 otherwise.
    """
    model_output = self.apply_model(model, x, model_t, cond, uncond, cfg_scale)

    # Recover x_0 from the network output. Only "eps" is matched explicitly;
    # every other parameterization value takes the v-prediction path.
    # NOTE(review): an unexpected value (e.g. a typo) would silently be
    # treated as "v" -- confirm the constructor restricts it to eps/v.
    to_xstart = (
        self._predict_xstart_from_eps
        if self.parameterization == "eps"
        else self._predict_xstart_from_v
    )
    pred_x0 = to_xstart(x, t, model_output)

    # Mean and variance of the next state given the x_0 estimate.
    mean, variance = self.q_posterior_mean_variance(pred_x0, x, t)

    # Reshape the "t != 0" indicator to (-1, 1, ..., 1) so it broadcasts
    # against x; rows where t == 0 get the mean back without any noise.
    # Presumably t holds one entry per batch element -- TODO confirm.
    broadcast_shape = (-1,) + (1,) * (x.ndim - 1)
    noise = torch.randn_like(x)
    keep_noise = (t != 0).float().view(broadcast_shape)
    noise_scale = keep_noise * torch.sqrt(variance)
    return mean + noise_scale * noise
| 185 | |
| 186 | @torch.no_grad() |
| 187 | def sample( |
no test coverage detected