(self, pixel_values, return_loss=False)
| 550 | ) |
| 551 | |
def encode(self, pixel_values, return_loss=False):
    """Encode ``pixel_values`` into quantized latents and codebook indices.

    The input is passed through ``self.encoder``, then ``self.quant_conv``,
    and the result is handed to ``self.quantize`` together with
    ``return_loss``.

    Args:
        pixel_values: Input forwarded unchanged to ``self.encoder``.
        return_loss: When truthy, the codebook loss produced by
            ``self.quantize`` is appended to the returned tuple.

    Returns:
        ``(quantized_states, codebook_indices)``, or
        ``(quantized_states, codebook_indices, codebook_loss)`` when
        ``return_loss`` is truthy.
    """
    latents = self.quant_conv(self.encoder(pixel_values))
    # ``quantize`` always yields three values. The loss is only surfaced to
    # the caller on request. NOTE(review): whether ``quantize`` skips computing
    # the loss when ``return_loss`` is falsy is not visible here.
    quantized, indices, loss = self.quantize(latents, return_loss)
    if not return_loss:
        return (quantized, indices)
    return (quantized, indices, loss)
| 560 | |
| 561 | def decode(self, quantized_states): |
| 562 | hidden_states = self.post_quant_conv(quantized_states) |