Repository: github.com/lucidrains/DALLE-pytorch

Method get_codebook_indices

Defined in dalle_pytorch/vae.py, lines 211–217
Signature: get_codebook_indices(self, img)


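```python
import torch
from einops import rearrange

@torch.no_grad()
def get_codebook_indices(self, img):
    # img: a batch of images, assumed to hold pixel values in [0, 1]
    b = img.shape[0]
    img = (2 * img) - 1  # rescale [0, 1] -> [-1, 1] before encoding
    # encode returns a 3-tuple; its third element is itself a 3-tuple ending in the code indices
    _, _, [_, _, indices] = self.model.encode(img)
    if self.is_gumbel:
        # Gumbel quantizer: indices arrive as (b, h, w); flatten each grid to h*w tokens
        return rearrange(indices, 'b h w -> b (h w)', b=b)
    # otherwise indices arrive as one flat vector of b*n entries; split it into one row per image
    return rearrange(indices, '(b n) -> b n', b=b)
```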
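Both `rearrange` branches only reshape: neither reorders the indices, and both end with a `(b, h*w)` token sequence per image. The sketch below is a minimal illustration under stated assumptions. It uses NumPy `reshape` as a stand-in for the einops patterns (the library code operates on torch tensors), assumes the flat, non-Gumbel index vector is laid out batch-first and row-major, and introduces `flatten_indices` as a hypothetical helper that is not part of DALLE-pytorch.

```python
import numpy as np


def flatten_indices(indices: np.ndarray, b: int, is_gumbel: bool) -> np.ndarray:
    """Hypothetical mirror of the two rearrange branches (not the library's code)."""
    if is_gumbel:
        # 'b h w -> b (h w)': merge each image's spatial grid into one token axis.
        bb, h, w = indices.shape
        assert bb == b
        return indices.reshape(b, h * w)
    # '(b n) -> b n': split one flat vector of b*n indices into b rows.
    assert indices.ndim == 1 and indices.shape[0] % b == 0
    return indices.reshape(b, -1)


b, h, w = 2, 3, 4
grid = np.arange(b * h * w).reshape(b, h, w)  # Gumbel-style (b, h, w) indices
flat = grid.reshape(-1)                       # assumed non-Gumbel layout: flat (b*h*w,)

seq_a = flatten_indices(grid, b, is_gumbel=True)
seq_b = flatten_indices(flat, b, is_gumbel=False)
print(seq_a.shape)                   # (2, 12)
print(np.array_equal(seq_a, seq_b))  # True
print(seq_a[1, 1 * w + 2] == grid[1, 1, 2])  # True: grid cell (r, c) lands at token r*w + c
```

Because the result is row-major, token position `r*w + c` in each row corresponds to grid cell `(r, c)`. That ordering is what lets a downstream decoder reshape the sequence back into an `h x w` grid.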

Callers (1): train_vae.py (file)

Calls (1): encode (method)
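`get_codebook_indices` relies on two things from its input and from `encode`. First, `img` must arrive in [0, 1], since the method rescales it with `2 * img - 1`. Second, `self.model.encode(img)` must return a 3-tuple whose third element is itself a 3-tuple ending in the indices. The sketch below imitates that contract under stated assumptions. `StubVQModel` and `codebook_indices_sketch` are hypothetical names. The stub skips any encoder network and does a per-pixel nearest-code lookup as a stand-in for a VQGAN's encoder plus quantizer. Only the non-Gumbel path (flat indices) is mirrored, and NumPy replaces torch.

```python
import numpy as np


class StubVQModel:
    """Hypothetical stand-in for a VQGAN: no encoder, just per-pixel nearest-code lookup."""

    def __init__(self, codebook: np.ndarray):
        self.codebook = codebook  # shape (K, c): K codes of dimension c

    def encode(self, x: np.ndarray):
        b, c, h, w = x.shape
        # Channels last, then one vector per pixel: (b*h*w, c), batch-first and row-major.
        z = x.transpose(0, 2, 3, 1).reshape(-1, c)
        # Squared Euclidean distance from every pixel vector to every code: (b*h*w, K).
        d = ((z[:, None, :] - self.codebook[None, :, :]) ** 2).sum(axis=-1)
        indices = d.argmin(axis=1)  # flat (b*h*w,) vector of code indices
        quant = self.codebook[indices].reshape(b, h, w, c).transpose(0, 3, 1, 2)
        # Mimic a (quant, loss, (_, _, indices)) return layout; the loss here is a dummy.
        return quant, 0.0, (None, None, indices)


def codebook_indices_sketch(model: StubVQModel, img: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy mirror of the non-Gumbel path of get_codebook_indices."""
    b = img.shape[0]
    img = (2 * img) - 1                        # [0, 1] -> [-1, 1]
    _, _, [_, _, indices] = model.encode(img)  # nested unpack, as in the original method
    return indices.reshape(b, -1)              # equivalent of '(b n) -> b n'


codebook = np.array([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])  # K=2 codes, c=3 channels
img = np.zeros((2, 3, 2, 2))  # two 2x2 RGB images in [0, 1], all black
img[1] = 1.0                  # make the second image all white
seq = codebook_indices_sketch(StubVQModel(codebook), img)
print(seq)
# [[0 0 0 0]
#  [1 1 1 1]]
```

Without the rescaling step, the black image would sit at 0, equidistant from both codes, so the [0, 1] to [-1, 1] mapping matters whenever the model's codes live in the [-1, 1] range.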

Tested by: no test coverage detected