(self, input)
| 202 | return dec, diff |
| 203 | |
def encode(self, input):
    """Run the two-level encoder and quantize both levels.

    The bottom encoder (``self.enc_b``) is applied to ``input`` and the
    top encoder (``self.enc_t``) to the bottom features. The top features
    are quantized first; the quantized top code is then passed through
    ``self.dec_t`` and concatenated (along dim 1) in front of the bottom
    features before the bottom level is quantized.

    Both quantizers are called with axis 1 moved to the last position
    (``permute(0, 2, 3, 1)``), and their first output is permuted back
    (``permute(0, 3, 1, 2)``), so the projected features must be 4-D.
    NOTE(review): axis 1 is presumably the channel axis -- confirm.

    Returns:
        A 5-tuple ``(quant_t, quant_b, diff, id_t, id_b)`` where ``diff``
        is the sum of the two quantizer diff terms, each given a new
        leading dimension via ``unsqueeze(0)``; ``id_t`` / ``id_b`` are
        the third outputs of the top / bottom quantizers, passed through
        unchanged.
    """
    bottom_features = self.enc_b(input)
    top_features = self.enc_t(bottom_features)

    # Top level: project, move axis 1 last for the quantizer, restore after.
    top_projected = self.quantize_conv_t(top_features)
    top_code, top_diff, id_t = self.quantize_t(top_projected.permute(0, 2, 3, 1))
    top_code = top_code.permute(0, 3, 1, 2)

    # Bottom level is conditioned on the decoded top code, placed before
    # the bottom features along dim 1.
    conditioned = torch.cat([self.dec_t(top_code), bottom_features], 1)
    bottom_projected = self.quantize_conv_b(conditioned)
    bottom_code, bottom_diff, id_b = self.quantize_b(
        bottom_projected.permute(0, 2, 3, 1)
    )
    bottom_code = bottom_code.permute(0, 3, 1, 2)

    # Each diff gains a leading dimension before the two are added.
    total_diff = top_diff.unsqueeze(0) + bottom_diff.unsqueeze(0)

    return top_code, bottom_code, total_diff, id_t, id_b
| 222 | |
| 223 | def decode(self, quant_t, quant_b): |
| 224 | upsample_t = self.upsample_t(quant_t) |
no outgoing calls
no test coverage detected