github.com/rosinality/vq-vae-2-pytorch

Method encode

vqvae.py:204–221
encode(self, input)


```python
def encode(self, input):
    # Bottom encoder maps the image to a feature map; the top encoder
    # downsamples those features further
    enc_b = self.enc_b(input)
    enc_t = self.enc_t(enc_b)

    # Top level: project, then NCHW -> NHWC so the quantizer sees
    # channels as the last dimension
    quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
    quant_t, diff_t, id_t = self.quantize_t(quant_t)
    quant_t = quant_t.permute(0, 3, 1, 2)  # back to NCHW
    diff_t = diff_t.unsqueeze(0)  # add a leading dim to the latent loss term

    # Decode the top code and condition the bottom level on it by
    # concatenating along the channel dimension (dim 1)
    dec_t = self.dec_t(quant_t)
    enc_b = torch.cat([dec_t, enc_b], 1)

    # Bottom level: same project -> quantize -> restore-layout steps
    quant_b = self.quantize_conv_b(enc_b).permute(0, 2, 3, 1)
    quant_b, diff_b, id_b = self.quantize_b(quant_b)
    quant_b = quant_b.permute(0, 3, 1, 2)
    diff_b = diff_b.unsqueeze(0)

    # Quantized maps, summed latent loss, and code indices for each level
    return quant_t, quant_b, diff_t + diff_b, id_t, id_b
```
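Each level of encode follows the same pattern: project, move channels last, look each spatial vector up in a codebook, then move channels back. The NumPy sketch below reproduces that pattern with a plain nearest-neighbour quantizer. It assumes `quantize_t`/`quantize_b` do nearest-neighbour codebook lookup and return `(quantized, mean squared latent loss, indices)` in that order, as the unpacking in encode suggests; the codebook size (512 x 64) and the `vq_quantize` helper are hypothetical, and the codebook update and gradient tricks a trainable PyTorch module would need are omitted.

```python
import numpy as np


def vq_quantize(z, codebook):
    """Hypothetical nearest-neighbour vector quantizer (channels-last input).

    z:        (N, H, W, D) continuous features
    codebook: (K, D) embedding vectors
    Returns (quantized, diff, indices), mirroring quantize_t / quantize_b.
    """
    flat = z.reshape(-1, z.shape[-1])  # (N*H*W, D)
    # Squared Euclidean distance from every vector to every code: (N*H*W, K)
    dist = (
        (flat ** 2).sum(axis=1, keepdims=True)
        - 2 * flat @ codebook.T
        + (codebook ** 2).sum(axis=1)
    )
    indices = dist.argmin(axis=1).reshape(z.shape[:-1])  # (N, H, W)
    quantized = codebook[indices]  # (N, H, W, D)
    # Latent loss: mean squared distance between features and chosen codes
    diff = np.mean((quantized - z) ** 2)
    return quantized, diff, indices


rng = np.random.default_rng(0)
codebook = rng.normal(size=(512, 64))  # assumed: 512 codes of dimension 64
feat_nchw = rng.normal(size=(2, 64, 8, 8))  # stand-in for a projected feature map

# Same layout steps as encode: NCHW -> NHWC, quantize, NHWC -> NCHW
quant, diff, ids = vq_quantize(feat_nchw.transpose(0, 2, 3, 1), codebook)
quant = quant.transpose(0, 3, 1, 2)
diff = np.expand_dims(diff, 0)  # analogue of diff.unsqueeze(0): shape (1,)

print(quant.shape, diff.shape, ids.shape)
```

The channels-last detour lets the quantizer treat every spatial position as one D-dimensional vector with a single reshape; the indices keep the spatial grid, which is why `id_t` and `id_b` have no channel axis.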

Callers (4)

forward · method · 0.95
__init__ · method · 0.80
__getitem__ · method · 0.80
extract · function · 0.80
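Callers consume different parts of the five-tuple `(quant_t, quant_b, diff, id_t, id_b)`: the continuous maps and loss for reconstruction, the integer codes for extraction. The sketch below traces the shape of each piece for an image batch. The hyperparameters (`channel=128`, `embed_dim=64`, bottom encoder downsampling by 4, top encoder by 2) are assumptions for illustration, not read from this page; check `VQVAE.__init__` in vqvae.py for the actual values. The one constraint taken directly from the code is that `torch.cat([dec_t, enc_b], 1)` requires `dec_t` to come back at the bottom grid size.

```python
def encode_shapes(n, h, w, channel=128, embed_dim=64, bottom_stride=4, top_stride=2):
    """Trace output shapes of encode under assumed (hypothetical) hyperparameters."""
    hb, wb = h // bottom_stride, w // bottom_stride  # enc_b output grid
    ht, wt = hb // top_stride, wb // top_stride  # enc_t output grid
    return {
        "enc_b": (n, channel, hb, wb),
        "enc_t": (n, channel, ht, wt),
        "quant_t": (n, embed_dim, ht, wt),
        # assumed: dec_t emits embed_dim channels at the bottom grid size,
        # so the channel-wise concat adds embed_dim + channel channels
        "cat([dec_t, enc_b], 1)": (n, embed_dim + channel, hb, wb),
        "quant_b": (n, embed_dim, hb, wb),
        "id_t": (n, ht, wt),
        "id_b": (n, hb, wb),
    }


for name, shape in encode_shapes(n=16, h=256, w=256).items():
    print(f"{name:24s} {shape}")
```

Under these assumptions a 256 x 256 input yields a 32 x 32 top code grid and a 64 x 64 bottom code grid.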

Calls

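no outgoing calls resolved by the indexer (the body calls into submodules self.enc_b, self.enc_t, self.quantize_conv_t, self.quantize_t, self.dec_t, self.quantize_conv_b, self.quantize_b, and torch.cat)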

Tested by

no test coverage detected