MCPcopy
hub / github.com/InternLM/lmdeploy / SamplingInputs

Class SamplingInputs

lmdeploy/pytorch/engine/logits_process.py:243–304  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

241
@dataclass
class SamplingInputs:
    """Sampling parameters and auxiliary state for a batch of sequences.

    ``to_device`` moves every ``torch.Tensor`` field to a device and passes
    all other fields through as-is; ``get_delta`` copies only the
    tensor-valued fields into a ``SamplingInputsDelta``.
    """
    temperature: torch.Tensor | None = None
    # NOTE(review): bad_/stop_ words look like padded token-id tensors paired
    # with a mask marking valid entries; confirm at the construction site.
    bad_words: torch.LongTensor | None = None
    bad_mask: torch.BoolTensor | None = None
    stop_words: torch.LongTensor | None = None
    stop_mask: torch.BoolTensor | None = None
    repetition_penalty: torch.Tensor | None = None
    top_k: torch.LongTensor | None = None
    top_p: torch.Tensor | None = None
    min_p: torch.Tensor | None = None
    random_seeds: torch.Tensor | None = None
    random_offsets: torch.Tensor | None = None
    # presumably batch-wide aggregates of top_k / top_p; verify where filled
    max_top_k: int = 1
    min_top_p: float = 1.0
    # Annotated as a tuple to match the immutable ``()`` default
    # (``list[str, ...]`` is not a valid type; only tuple takes ``...``).
    # NOTE(review): confirm whether callers may store None entries here.
    response_formats: tuple[str, ...] = ()
    logits_processors: list[list[LogitsProcessor]] | None = None
    max_num_logprobs: int | None = None
    all_ids: torch.Tensor | None = None
    num_ignore_eos: torch.Tensor | None = None
    batch_size: int = 0
    session_ctx: list[dict[str, Any]] | None = None
    session_to_cleanup: list[int] | None = None
    # for repetition_penalty and ngram
    generated_ids: torch.Tensor | None = None
    # Host-side array; to_device() builds ``generated_ids`` from a copy of it
    # when ``generated_ids`` is still None.
    generated_ids_cpu: np.ndarray | None = None

    # n gram
    repetition_ngram_size: torch.Tensor | None = None
    repetition_ngram_threshold: torch.Tensor | None = None
    max_repetition_ngram_size: int = 0
273
274 def to_device(self, device: str, non_blocking: bool = False):
275 """To device."""
276 out_dict = dict()
277 if self.generated_ids is None and self.generated_ids_cpu is not None:
278 self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy())
279 for f in fields(self):
280 k = f.name
281 v = getattr(self, k)
282 if isinstance(v, torch.Tensor):
283 v = v.to(device, non_blocking=non_blocking)
284 out_dict[k] = v
285
286 return SamplingInputs(**out_dict)
287
288 def get_delta(self) -> SamplingInputsDelta:
289 """Get delta."""
290 delta = SamplingInputsDelta()
291 for f in fields(self):
292 k = f.name
293 v = getattr(self, k)
294 if isinstance(v, torch.Tensor):
295 setattr(delta, k, v)
296 return delta
297
298 def update_delta(self, delta: SamplingInputsDelta):
299 """Update from delta."""
300 for f in fields(delta):

Calls

no outgoing calls