MCPcopy
hub / github.com/lm-sys/FastChat / get_generate_stream_function

Function get_generate_stream_function

fastchat/model/model_adapter.py:384–459  ·  view source on GitHub ↗

Get the generate_stream function for inference.

(model: torch.nn.Module, model_path: str)

Source from the content-addressed store, hash-verified

382
383
384def get_generate_stream_function(model: torch.nn.Module, model_path: str):
385 """Get the generate_stream function for inference."""
386 from fastchat.serve.inference import generate_stream
387
388 model_type = str(type(model)).lower()
389 is_peft = "peft" in model_type
390 is_chatglm = "chatglm" in model_type
391 is_falcon = "rwforcausallm" in model_type
392 is_codet5p = "codet5p" in model_type
393 is_exllama = "exllama" in model_type
394 is_xft = "xft" in model_type
395 is_yuan = "yuan" in model_type
396
397 if is_chatglm:
398 return generate_stream_chatglm
399 elif is_falcon:
400 return generate_stream_falcon
401 elif is_codet5p:
402 return generate_stream_codet5p
403 elif is_exllama:
404 return generate_stream_exllama
405 elif is_xft:
406 return generate_stream_xft
407 elif is_yuan:
408 return generate_stream_yuan2
409
410 elif peft_share_base_weights and is_peft:
411 # Return a curried stream function that loads the right adapter
412 # according to the model_name available in this context. This ensures
413 # the right weights are available.
414 @torch.inference_mode()
415 def generate_stream_peft(
416 model,
417 tokenizer,
418 params: Dict,
419 device: str,
420 context_len: int,
421 stream_interval: int = 2,
422 judge_sent_end: bool = False,
423 ):
424 model.set_adapter(model_path)
425 base_model_type = str(type(model.base_model.model))
426 is_chatglm = "chatglm" in base_model_type
427 is_falcon = "rwforcausallm" in base_model_type
428 is_codet5p = "codet5p" in base_model_type
429 is_exllama = "exllama" in base_model_type
430 is_xft = "xft" in base_model_type
431 is_yuan = "yuan" in base_model_type
432
433 generate_stream_function = generate_stream
434 if is_chatglm:
435 generate_stream_function = generate_stream_chatglm
436 elif is_falcon:
437 generate_stream_function = generate_stream_falcon
438 elif is_codet5p:
439 generate_stream_function = generate_stream_codet5p
440 elif is_exllama:
441 generate_stream_function = generate_stream_exllama

Callers (2)

__init__ (Method) · 0.90
chat_loop (Function) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…