MCPcopy
hub / github.com/FoundationVision/LlamaGen / get_codebook_entry

Method get_codebook_entry

tokenizer/vqgan/quantize.py:92–107  ·  view source on GitHub ↗
(self, indices, shape)

Source from the content-addressed store, hash-verified

90 return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
91
def get_codebook_entry(self, indices, shape):
    """Look up codebook vectors for the given code indices.

    Args:
        indices: integer tensor of codebook indices, each expected to lie in
            ``[0, self.n_e)``. Callers typically pass a flat 1-D tensor of
            ``batch * height * width`` codes; the lookup itself accepts any
            index shape.
        shape: ``(batch, height, width, channel)`` used to reshape the looked-up
            vectors, or ``None`` to return them without reshaping.

    Returns:
        If ``shape`` is given, a contiguous tensor laid out as
        ``(batch, channel, height, width)``. Otherwise the looked-up vectors
        with shape ``(*indices.shape, channel)``.
    """
    # Gather codebook rows directly. For finite codebook weights this matches
    # the former one-hot @ weight formulation, but it never materialises an
    # (N, n_e) one-hot matrix. It also keeps the codebook's own dtype, so
    # half/bf16 codebooks no longer hit a float32-vs-half matmul dtype error.
    z_q = self.embedding.weight[indices]

    if shape is not None:
        z_q = z_q.view(shape)
        # Reshape back to match the original input layout:
        # (batch, height, width, channel) -> (batch, channel, height, width).
        # This is only valid after the 4-D view above, so it stays in this branch.
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

    return z_q
108
109
110class VectorQuantizer2(nn.Module):

Callers 1

decode_code (method) · 0.45

Calls

no outgoing calls

Tested by

no test coverage detected