(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1)))
| 328 | |
@T.no_grad()
def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))):
    """Sample the next token(s) and pass them to ``self.resynthesizer``.

    ``self.forward(model_input)`` is evaluated, the logits at the last index
    along dim 1 are divided by ``temps[0]``, turned into probabilities with a
    softmax over the final dim, and one index per row is drawn with
    ``multinomial(1)``.

    When ``self.c.split`` is truthy, ``self.forward`` is expected to return
    two logits tensors; one token is drawn from each (both with ``temps[0]``)
    and the pair is passed on as a tuple. Otherwise a single logits tensor is
    expected and a single token tensor is passed on.

    Args:
        model_input: Tensor given to ``self.forward`` and forwarded unchanged
            to ``self.resynthesizer``.
            NOTE(review): presumably dim 1 is the sequence axis -- confirm.
        temps: Indexable pair. ``temps[0]`` scales the logits before
            sampling; ``temps[1]`` is passed as ``temps=`` to
            ``self.resynthesizer``.

    Returns:
        Whatever ``self.resynthesizer`` returns.
    """

    def draw(last_step_logits):
        # Temperature-scale, normalise, then sample one index per row.
        probabilities = F.softmax(last_step_logits / temps[0], dim=-1)
        return probabilities.multinomial(1)

    if not self.c.split:
        last_step = self.forward(model_input)[:, -1]
        sampled = draw(last_step)
        return self.resynthesizer(model_input, next_tokens=sampled, temps=temps[1])

    head_a, head_b = self.forward(model_input)
    # Slice both heads before sampling so the RNG is touched only after both
    # slices have succeeded.
    last_a = head_a[:, -1]
    last_b = head_b[:, -1]
    sampled_a = draw(last_a)
    sampled_b = draw(last_b)
    return self.resynthesizer(
        model_input, next_tokens=(sampled_a, sampled_b), temps=temps[1]
    )
| 348 | |
| 349 | |
| 350 | @T.no_grad() |
no test coverage detected