Only accepts latent-space data.
(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True)
| 349 | |
@T.no_grad()
def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor:
    """
    Autoregressively extend ``data`` and return the prompt plus generated steps.

    Only accepts latent-space data.

    Args:
        data: prompt tensor. ``data.shape[0]`` is passed to ``init_cache``
            (presumably the batch size — confirm) and dim 1 is the sequence
            axis that new steps are concatenated along.
        temps: sampling temperatures, forwarded unchanged to
            ``self.next_latent``. NOTE(review): the meaning of the nested
            ``(scalar, (a, b))`` structure is defined by ``next_latent``.
        gen_len: number of steps to generate; when None, ``default`` falls
            back to ``data.shape[1]`` (the prompt length). The total length
            is capped at ``self.c.stack_config.seq_len``.
        use_cache: if True, call ``self.init_cache`` (bfloat16, on
            ``data.device``) before generating and ``self.deinit_cache``
            afterwards. The first step then receives the full prompt and
            later steps receive only the previous step's output. If False,
            the whole sequence generated so far is re-fed every step.

    Returns:
        ``data`` concatenated along dim 1 with every output of
        ``self.next_latent``, one per loop iteration.
    """
    if use_cache:
        self.init_cache(data.shape[0], data.device, T.bfloat16)

    # try/finally so the cache set up above is always torn down, even if
    # generation raises or is interrupted (e.g. Ctrl-C during the progress
    # bar). Previously an exception skipped deinit_cache() entirely.
    try:
        prompt_len = data.shape[1]
        next_input = generated = data

        target_len = min(prompt_len + default(gen_len, prompt_len), self.c.stack_config.seq_len)

        for _ in tqdm0(range(prompt_len, target_len)):
            # With the cache, only the newest step needs to be fed after the
            # first iteration; without it, the model must see everything.
            model_input = next_input if use_cache else generated

            next_input = self.next_latent(model_input, temps)

            generated = T.cat([generated, next_input], dim=1)
    finally:
        if use_cache:
            self.deinit_cache()

    return generated
| 372 | |
| 373 | |
| 374 |
nothing calls this directly
no test coverage detected