(
temperature: float, repetition_penalty: float, top_p: float, top_k: int
)
| 43 | |
| 44 | |
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """Assemble the logits processors/warpers selected by the sampling settings.

    Each stage is included only when its setting passes the range check below;
    otherwise it is omitted entirely. Included stages appear in this order:
    temperature, repetition penalty, top-p, top-k.

    Args:
        temperature: Included only when ``temperature >= 1e-5`` and it is not
            exactly ``1.0``.
        repetition_penalty: Included only when strictly greater than ``1.0``.
        top_p: Included only when ``1e-8 <= top_p < 1.0``.
        top_k: Included only when strictly positive.

    Returns:
        A ``LogitsProcessorList`` holding the enabled stages; it is empty when
        no setting passes its check.
    """

    def _selected_stages():
        # Yield each enabled stage lazily and in application order, so a warper
        # is only constructed after its own setting has passed its check.

        # TemperatureLogitsWarper doesn't accept 0.0, and 1.0 makes it a no-op,
        # so both of those cases are filtered out here.
        if temperature >= 1e-5 and temperature != 1.0:
            yield TemperatureLogitsWarper(temperature)
        # Penalties of 1.0 or below are not applied at all.
        if repetition_penalty > 1.0:
            yield RepetitionPenaltyLogitsProcessor(repetition_penalty)
        if 1e-8 <= top_p < 1.0:
            yield TopPLogitsWarper(top_p)
        if top_k > 0:
            yield TopKLogitsWarper(top_k)

    processors = LogitsProcessorList()
    for stage in _selected_stages():
        processors.append(stage)
    return processors
| 59 | |
| 60 | |
| 61 | @torch.inference_mode() |
no outgoing calls
no test coverage detected
searching dependent graphs…