| 241 | |
| 242 | @dataclass |
| 243 | class SamplingInputs: |
| 244 | temperature: torch.Tensor = None |
| 245 | bad_words: torch.LongTensor = None |
| 246 | bad_mask: torch.BoolTensor = None |
| 247 | stop_words: torch.LongTensor = None |
| 248 | stop_mask: torch.BoolTensor = None |
| 249 | repetition_penalty: torch.Tensor = None |
| 250 | top_k: torch.LongTensor = None |
| 251 | top_p: torch.Tensor = None |
| 252 | min_p: torch.Tensor = None |
| 253 | random_seeds: torch.Tensor = None |
| 254 | random_offsets: torch.Tensor = None |
| 255 | max_top_k: int = 1 |
| 256 | min_top_p: float = 1.0 |
| 257 | response_formats: list[str, ...] = () |
| 258 | logits_processors: list[list[LogitsProcessor]] = None |
| 259 | max_num_logprobs: None | int = None |
| 260 | all_ids: None | torch.Tensor = None |
| 261 | num_ignore_eos: torch.Tensor = None |
| 262 | batch_size: int = 0 |
| 263 | session_ctx: None | list[dict[str, Any]] = None |
| 264 | session_to_cleanup: None | list[int] = None |
| 265 | # for repetition_penalty and ngram |
| 266 | generated_ids: torch.Tensor | None = None |
| 267 | generated_ids_cpu: np.ndarray | None = None |
| 268 | |
| 269 | # n gram |
| 270 | repetition_ngram_size: torch.Tensor | None = None |
| 271 | repetition_ngram_threshold: torch.Tensor | None = None |
| 272 | max_repetition_ngram_size: int = 0 |
| 273 | |
def to_device(self, device: str, non_blocking: bool = False):
    """Return a new SamplingInputs whose tensor fields live on ``device``.

    Every dataclass field is carried over. Values that are ``torch.Tensor``
    are moved with ``.to(device, non_blocking=non_blocking)``; all other
    values (ints, floats, lists, numpy arrays, None) are passed through
    by reference.

    Side effect: when ``generated_ids`` is unset but ``generated_ids_cpu``
    is present, a CPU tensor is built from a *copy* of the numpy array and
    stored on ``self.generated_ids`` before the transfer, so this object is
    mutated as well as a new one being returned.

    NOTE(review): because the cached ``self.generated_ids`` is only built
    while it is None, later edits to ``generated_ids_cpu`` on this same
    object are not reflected by subsequent calls — confirm callers never
    call ``to_device`` twice after mutating ``generated_ids_cpu``.
    """
    if self.generated_ids_cpu is not None and self.generated_ids is None:
        # Copy first: torch.from_numpy shares memory with its source array.
        snapshot = self.generated_ids_cpu.copy()
        self.generated_ids = torch.from_numpy(snapshot)

    def _move(value):
        # Only tensors are transferred; everything else is reused as-is.
        if isinstance(value, torch.Tensor):
            return value.to(device, non_blocking=non_blocking)
        return value

    moved = {field.name: _move(getattr(self, field.name)) for field in fields(self)}
    return SamplingInputs(**moved)
| 287 | |
def get_delta(self) -> SamplingInputsDelta:
    """Collect this object's tensor-valued fields into a new SamplingInputsDelta.

    A default ``SamplingInputsDelta()`` is constructed, then every dataclass
    field of ``self`` whose current value is a ``torch.Tensor`` is assigned
    onto it under the same attribute name. Tensors are shared by reference
    (no clone). Fields holding None or non-tensor values are not assigned
    and keep whatever default ``SamplingInputsDelta()`` gave them.
    """
    result = SamplingInputsDelta()
    named_values = ((field.name, getattr(self, field.name)) for field in fields(self))
    for attr_name, attr_value in named_values:
        if not isinstance(attr_value, torch.Tensor):
            continue
        setattr(result, attr_name, attr_value)
    return result
| 297 | |
| 298 | def update_delta(self, delta: SamplingInputsDelta): |
| 299 | """Update from delta.""" |
| 300 | for f in fields(delta): |
no outgoing calls