| 318 | _top_k: int = -1 # -1 is for all |
| 319 | |
| 320 | def init(self, tokenizer, **kwargs): |
| 321 | super().__init__() |
| 322 | self.best_of = kwargs.get("best_of", 1) |
| 323 | self.n = kwargs.get("n", self.best_of) |
| 324 | self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample) |
| 325 | self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) |
| 326 | self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) |
| 327 | self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) |
| 328 | self.temperature = kwargs.get("temperature", SamplingParams._temperature) |
| 329 | self.top_p = kwargs.get("top_p", SamplingParams._top_p) |
| 330 | self.top_k = kwargs.get("top_k", SamplingParams._top_k) |
| 331 | self.ignore_eos = kwargs.get("ignore_eos", False) |
| 332 | self.image_max_patch_num = kwargs.get("image_max_patch_num", -1) |
| 333 | self.max_new_tokens = kwargs.get("max_new_tokens", 16) |
| 334 | self.min_new_tokens = kwargs.get("min_new_tokens", 1) |
| 335 | self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY) |
| 336 | self.group_request_id = kwargs.get("group_request_id", -1) |
| 337 | self.suggested_dp_index = kwargs.get("suggested_dp_index", -1) |
| 338 | |
| 339 | self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS) |
| 340 | |
| 341 | self.add_special_tokens = kwargs.get("add_special_tokens", True) |
| 342 | self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True) |
| 343 | self.print_eos_token = kwargs.get("print_eos_token", False) |
| 344 | |
| 345 | self.exponential_decay_length_penalty = ExponentialDecayLengthPenalty() |
| 346 | self.exponential_decay_length_penalty.initialize(kwargs.get("exponential_decay_length_penalty", (1, 1.0))) |
| 347 | |
| 348 | self.move_kv_to_decode_node = DecodeNode() |
| 349 | self.move_kv_to_decode_node.initialize(kwargs.get("move_kv_to_decode_node", None)) |
| 350 | |
| 351 | # Initialize regular_constraint |
| 352 | regular_constraint = kwargs.get("regular_constraint", "") |
| 353 | self.regular_constraint = RegularConstraint() |
| 354 | self.regular_constraint.initialize(regular_constraint) |
| 355 | |
| 356 | # Initialize guided_grammar |
| 357 | guided_grammar = kwargs.get("guided_grammar", "") |
| 358 | self.guided_grammar = GuidedGrammar() |
| 359 | self.guided_grammar.initialize(guided_grammar, tokenizer) |
| 360 | |
| 361 | # Initialize guided_json |
| 362 | guided_json = kwargs.get("guided_json", "") |
| 363 | self.guided_json = GuidedJsonSchema() |
| 364 | self.guided_json.initialize(guided_json, tokenizer) |
| 365 | |
| 366 | # Initialize stop_sequence_groups |
| 367 | stop_sequences = kwargs.get("stop_sequences", []) |
| 368 | self.stop_sequences = StopSequenceGroups() |
| 369 | self.stop_sequences.initialize(stop_sequences, tokenizer) |
| 370 | |
| 371 | # Initialize allowed_token_ids |
| 372 | allowed_token_ids = kwargs.get("allowed_token_ids", []) |
| 373 | self.allowed_token_ids = AllowedTokenIds() |
| 374 | self.allowed_token_ids.initialize(allowed_token_ids) |
| 375 | |
| 376 | if self.do_sample is False: |
| 377 | self.temperature = 1.0 |