(self, img)
| 209 | |
@torch.no_grad()
def get_codebook_indices(self, img):
    """Encode ``img`` and return its codebook indices, one row per batch item.

    The input is rescaled with ``2 * img - 1`` before being passed to
    ``self.model.encode`` (presumably mapping a [0, 1] image range onto
    [-1, 1] -- confirm against the caller). Gradients are disabled for
    the whole call by the ``torch.no_grad()`` decorator.

    Args:
        img: Batched input whose leading dimension is the batch size.

    Returns:
        The indices produced by the encoder, rearranged to two
        dimensions with the batch dimension first.
    """
    batch_size = img.shape[0]
    scaled = (2 * img) - 1

    # ``encode`` yields three items; only the third is needed, and of its
    # three entries only the last one (the indices) is used.
    _, _, quant_info = self.model.encode(scaled)
    _, _, indices = quant_info

    # Gumbel models hand back indices laid out as (b, h, w); otherwise
    # they arrive flattened across the batch as (b * n,).
    pattern = 'b h w -> b (h w)' if self.is_gumbel else '(b n) -> b n'
    return rearrange(indices, pattern, b=batch_size)
| 218 | |
| 219 | def decode(self, img_seq): |
| 220 | b, n = img_seq.shape |