(canvas, tokens, ids, H, W)
| 817 | |
| 818 | # Stash the latents and the complement condition |
| 819 | def _stash(canvas, tokens, ids, H, W) -> None: |
| 820 | B, T, C = tokens.shape |
| 821 | ids = ids.to(torch.long) |
| 822 | flat_idx = (ids[:, 1] * W + ids[:, 2]).to(torch.long) |
| 823 | canvas.view(B, -1, C).index_copy_(1, flat_idx, tokens) |
| 824 | |
# Write the image latents first, then the complement condition, onto the same
# canvas in place; where their ids overlap, the complement condition's tokens
# are the ones left in the canvas.
_stash(canvas, tokens=latents, ids=latent_image_ids, H=H, W=W)
_stash(canvas, tokens=comp_latent, ids=comp_ids, H=H, W=W)