MCPcopy
hub / github.com/lm-sys/FastChat / chat_loop

Function chat_loop

fastchat/serve/inference.py:337–555  ·  view source on GitHub ↗
(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str,
    dtype: Optional[torch.dtype],
    load_8bit: bool,
    cpu_offloading: bool,
    conv_template: Optional[str],
    conv_system_msg: Optional[str],
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int,
    chatio: ChatIO,
    gptq_config: Optional[GptqConfig] = None,
    awq_config: Optional[AWQConfig] = None,
    exllama_config: Optional[ExllamaConfig] = None,
    xft_config: Optional[XftConfig] = None,
    revision: str = "main",
    judge_sent_end: bool = True,
    debug: bool = True,
    history: bool = True,
)

Source from the content-addressed store, hash-verified

335
336
337def chat_loop(
338 model_path: str,
339 device: str,
340 num_gpus: int,
341 max_gpu_memory: str,
342 dtype: Optional[torch.dtype],
343 load_8bit: bool,
344 cpu_offloading: bool,
345 conv_template: Optional[str],
346 conv_system_msg: Optional[str],
347 temperature: float,
348 repetition_penalty: float,
349 max_new_tokens: int,
350 chatio: ChatIO,
351 gptq_config: Optional[GptqConfig] = None,
352 awq_config: Optional[AWQConfig] = None,
353 exllama_config: Optional[ExllamaConfig] = None,
354 xft_config: Optional[XftConfig] = None,
355 revision: str = "main",
356 judge_sent_end: bool = True,
357 debug: bool = True,
358 history: bool = True,
359):
360 # Model
361 model, tokenizer = load_model(
362 model_path,
363 device=device,
364 num_gpus=num_gpus,
365 max_gpu_memory=max_gpu_memory,
366 dtype=dtype,
367 load_8bit=load_8bit,
368 cpu_offloading=cpu_offloading,
369 gptq_config=gptq_config,
370 awq_config=awq_config,
371 exllama_config=exllama_config,
372 xft_config=xft_config,
373 revision=revision,
374 debug=debug,
375 )
376 generate_stream_func = get_generate_stream_function(model, model_path)
377
378 model_type = str(type(model)).lower()
379 is_t5 = "t5" in model_type
380 is_codet5p = "codet5p" in model_type
381 is_xft = "xft" in model_type
382
383 # Hardcode T5's default repetition penalty to be 1.2
384 if is_t5 and repetition_penalty == 1.0:
385 repetition_penalty = 1.2
386
387 # Set context length
388 context_len = get_context_length(model.config)
389
390 # Chat
391 def new_chat():
392 if conv_template:
393 conv = get_conv_template(conv_template)
394 else:

Callers 1

main · Function · 0.90

Calls 14

load_model · Function · 0.90
get_context_length · Function · 0.90
get_conv_template · Function · 0.90
new_chat · Function · 0.85
reload_conv · Function · 0.85
set_system_message · Method · 0.80
append_message · Method · 0.80
get_prompt · Method · 0.80
update_last_message · Method · 0.80
prompt_for_input · Method · 0.45
dict · Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…