To device: return a copy of the sampling inputs with every tensor field moved to the given device.
(self, device: str, non_blocking: bool = False)
| 272 | max_repetition_ngram_size: int = 0 |
| 273 | |
def to_device(self, device: str, non_blocking: bool = False):
    """Return a copy of these inputs with every tensor field moved to ``device``.

    If ``generated_ids`` has not been materialized yet but
    ``generated_ids_cpu`` is available, a tensor is first built from a copy
    of that CPU array and stored on ``self.generated_ids``. Note that this
    step mutates ``self``.

    Args:
        device: Target device, forwarded to ``torch.Tensor.to``.
        non_blocking: Forwarded to ``torch.Tensor.to``.

    Returns:
        A new ``SamplingInputs`` built from every dataclass field of ``self``.
        Fields holding a ``torch.Tensor`` are moved to ``device``; all other
        fields are passed through unchanged.
    """
    ids_missing = self.generated_ids is None
    if ids_missing and self.generated_ids_cpu is not None:
        # torch.from_numpy shares memory with its source array; copying first
        # keeps the new tensor independent of generated_ids_cpu.
        self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy())

    def _on_device(value):
        # Only tensors are moved; everything else is kept as-is.
        if not isinstance(value, torch.Tensor):
            return value
        return value.to(device, non_blocking=non_blocking)

    relocated = {fld.name: _on_device(getattr(self, fld.name)) for fld in fields(self)}
    return SamplingInputs(**relocated)
| 287 | |
| 288 | def get_delta(self) -> SamplingInputsDelta: |
| 289 | """Get delta.""" |
no test coverage detected