(
self,
target: nn.Module,
input_ids: torch.LongTensor,
max_new_tokens: int,
stop_token_ids: list[int],
temperature: float,
)
| 348 | |
@torch.inference_mode()
def spec_generate(
    self,
    target: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int,
    stop_token_ids: list[int],
    temperature: float,
):
    """Generate tokens by delegating to ``dflash_generate``.

    The whole call runs under ``torch.inference_mode()``. Before
    delegating, this object is switched to evaluation mode via
    ``self.eval()``; the previous mode is not restored afterwards.

    Args:
        target: Module forwarded unchanged as ``target``.
            NOTE(review): presumably the model whose output is being
            drafted/verified -- confirm against ``dflash_generate``.
        input_ids: Token ids forwarded unchanged as ``input_ids``.
            Shape is not established here -- TODO confirm.
        max_new_tokens: Forwarded unchanged as ``max_new_tokens``.
        stop_token_ids: Forwarded unchanged as ``stop_token_ids``.
        temperature: Forwarded unchanged as ``temperature``.

    Returns:
        Whatever ``dflash_generate`` returns, passed through untouched.
    """
    # Evaluation mode is set first so the helper sees this object in
    # eval mode; it stays in eval mode after the call returns.
    self.eval()

    # ``self`` is the helper's only positional argument; everything else
    # is passed by keyword under the same names this method receives.
    forwarded = {
        "target": target,
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "stop_token_ids": stop_token_ids,
        "temperature": temperature,
    }
    generated = dflash_generate(self, **forwarded)
    return generated
Note: nothing calls `spec_generate` directly.
Note: no test coverage was detected for `spec_generate`.