(
model_path: str,
device: str,
num_gpus: int,
max_gpu_memory: str,
dtype: Optional[torch.dtype],
load_8bit: bool,
cpu_offloading: bool,
conv_template: Optional[str],
conv_system_msg: Optional[str],
temperature: float,
repetition_penalty: float,
max_new_tokens: int,
chatio: ChatIO,
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
exllama_config: Optional[ExllamaConfig] = None,
xft_config: Optional[XftConfig] = None,
revision: str = "main",
judge_sent_end: bool = True,
debug: bool = True,
history: bool = True,
)
| 335 | |
| 336 | |
| 337 | def chat_loop( |
| 338 | model_path: str, |
| 339 | device: str, |
| 340 | num_gpus: int, |
| 341 | max_gpu_memory: str, |
| 342 | dtype: Optional[torch.dtype], |
| 343 | load_8bit: bool, |
| 344 | cpu_offloading: bool, |
| 345 | conv_template: Optional[str], |
| 346 | conv_system_msg: Optional[str], |
| 347 | temperature: float, |
| 348 | repetition_penalty: float, |
| 349 | max_new_tokens: int, |
| 350 | chatio: ChatIO, |
| 351 | gptq_config: Optional[GptqConfig] = None, |
| 352 | awq_config: Optional[AWQConfig] = None, |
| 353 | exllama_config: Optional[ExllamaConfig] = None, |
| 354 | xft_config: Optional[XftConfig] = None, |
| 355 | revision: str = "main", |
| 356 | judge_sent_end: bool = True, |
| 357 | debug: bool = True, |
| 358 | history: bool = True, |
| 359 | ): |
| 360 | # Model |
| 361 | model, tokenizer = load_model( |
| 362 | model_path, |
| 363 | device=device, |
| 364 | num_gpus=num_gpus, |
| 365 | max_gpu_memory=max_gpu_memory, |
| 366 | dtype=dtype, |
| 367 | load_8bit=load_8bit, |
| 368 | cpu_offloading=cpu_offloading, |
| 369 | gptq_config=gptq_config, |
| 370 | awq_config=awq_config, |
| 371 | exllama_config=exllama_config, |
| 372 | xft_config=xft_config, |
| 373 | revision=revision, |
| 374 | debug=debug, |
| 375 | ) |
| 376 | generate_stream_func = get_generate_stream_function(model, model_path) |
| 377 | |
| 378 | model_type = str(type(model)).lower() |
| 379 | is_t5 = "t5" in model_type |
| 380 | is_codet5p = "codet5p" in model_type |
| 381 | is_xft = "xft" in model_type |
| 382 | |
| 383 | # Hardcode T5's default repetition penalty to be 1.2 |
| 384 | if is_t5 and repetition_penalty == 1.0: |
| 385 | repetition_penalty = 1.2 |
| 386 | |
| 387 | # Set context length |
| 388 | context_len = get_context_length(model.config) |
| 389 | |
| 390 | # Chat |
| 391 | def new_chat(): |
| 392 | if conv_template: |
| 393 | conv = get_conv_template(conv_template) |
| 394 | else: |
no test coverage detected
searching dependent graphs…