Sample x_{t-1} from the model using DDIM. Same usage as p_sample().
(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
)
| 515 | img = out["sample"] |
| 516 | |
def ddim_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Take one DDIM step from x_t to x_{t-1}. Same usage as p_sample().

    :param model: the model, passed straight through to self.p_mean_variance().
    :param x: the current sample x_t.
    :param t: per-item timestep indices; items whose entry is 0 get no added
        noise.
    :param clip_denoised: forwarded to self.p_mean_variance().
    :param denoised_fn: forwarded to self.p_mean_variance().
    :param cond_fn: optional conditioning function; when given, the model output
        is adjusted with self.condition_score() before the step is taken.
    :param model_kwargs: extra keyword arguments forwarded to the helpers above.
    :param eta: multiplier on the noise scale sigma; with eta=0.0 the random
        term is scaled to zero. NOTE(review): for eta > 1 the quantity
        ``1 - alpha_bar_prev - sigma ** 2`` can go negative and produce NaNs —
        confirm callers keep eta within [0, 1].
    :return: a dict with "sample" (the new sample) and "pred_xstart" (the
        x_start prediction the step was built from, after any cond_fn
        adjustment).
    """
    model_out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    if cond_fn is not None:
        model_out = self.condition_score(
            cond_fn, model_out, x, t, model_kwargs=model_kwargs
        )
    x0_hat = model_out["pred_xstart"]

    # Re-derive epsilon from the x_start prediction so the update below works
    # whichever quantity (epsilon, x_start or x_prev) the model outputs.
    eps_hat = self._predict_eps_from_xstart(x, t, x0_hat)

    abar_t = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
    abar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)

    # Noise scale of the step; eta scales the whole stochastic part.
    ratio_term = th.sqrt((1 - abar_prev) / (1 - abar_t))
    decay_term = th.sqrt(1 - abar_t / abar_prev)
    sigma = eta * ratio_term * decay_term

    # Equation 12: predicted x_start rescaled to t-1 plus the direction
    # pointing back towards x_t.
    direction = th.sqrt(1 - abar_prev - sigma ** 2) * eps_hat
    mean_pred = x0_hat * th.sqrt(abar_prev) + direction

    # Per-item 0/1 mask shaped (batch, 1, 1, ...) so it broadcasts over x;
    # it disables the noise for items at t == 0.
    trailing_ones = [1] * (len(x.shape) - 1)
    keep_noise = (t != 0).float().view(-1, *trailing_ones)

    noise = th.randn_like(x)
    sample = mean_pred + keep_noise * sigma * noise
    return {"sample": sample, "pred_xstart": x0_hat}
| 565 | |
| 566 | def ddim_reverse_sample( |
| 567 | self, |
no test coverage detected