x: (B, S, D)
(self, x, return_latent=False, known_latent=None)
| 42 | |
@T.no_grad()
def forward(self, x, return_latent=False, known_latent=None):
    """Pass ``x`` through input -> ffnn -> compressor, or decode given indices.

    Runs with gradient tracking disabled (``T.no_grad()`` decorator).

    Args:
        x: input tensor; the original docstring documents its shape as
            (B, S, D) (presumably batch, sequence, feature -- TODO confirm).
            Ignored when ``known_latent`` is supplied.
        return_latent: if truthy, also return the second value produced by
            ``self.compressor`` (named ``tokens`` in the original code).
        known_latent: if ``exists(known_latent)`` holds, ``x`` is not used and
            the result of ``self.compressor.indices_to_codes(known_latent)``
            is returned directly.

    Returns:
        The first output of ``self.compressor`` (or of ``indices_to_codes``);
        a ``(codes, tokens)`` pair when ``return_latent`` is truthy and no
        ``known_latent`` was given.
    """
    if not exists(known_latent):
        hidden = self.ffnn(self.input(x))
        codes, tokens = self.compressor(hidden)
        return (codes, tokens) if return_latent else codes

    # Shortcut path: skip the encoder entirely and map the supplied latent
    # back to codes.
    # NOTE(review): this path returns a single value even when
    # return_latent is True -- confirm callers never combine both flags.
    return self.compressor.indices_to_codes(known_latent)
| 58 | |
| 59 | |
| 60 | @si_module |
nothing calls this directly
no test coverage detected