(args)
| 190 | |
| 191 | |
| 192 | def main(args): |
| 193 | if args.gpus: |
| 194 | if len(args.gpus.split(",")) < args.num_gpus: |
| 195 | raise ValueError( |
| 196 | f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" |
| 197 | ) |
| 198 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus |
| 199 | os.environ["XPU_VISIBLE_DEVICES"] = args.gpus |
| 200 | if args.enable_exllama: |
| 201 | exllama_config = ExllamaConfig( |
| 202 | max_seq_len=args.exllama_max_seq_len, |
| 203 | gpu_split=args.exllama_gpu_split, |
| 204 | cache_8bit=args.exllama_cache_8bit, |
| 205 | ) |
| 206 | else: |
| 207 | exllama_config = None |
| 208 | if args.enable_xft: |
| 209 | xft_config = XftConfig( |
| 210 | max_seq_len=args.xft_max_seq_len, |
| 211 | data_type=args.xft_dtype, |
| 212 | ) |
| 213 | if args.device != "cpu": |
| 214 | print("xFasterTransformer now is only support CPUs. Reset device to CPU") |
| 215 | args.device = "cpu" |
| 216 | else: |
| 217 | xft_config = None |
| 218 | if args.style == "simple": |
| 219 | chatio = SimpleChatIO(args.multiline) |
| 220 | elif args.style == "rich": |
| 221 | chatio = RichChatIO(args.multiline, args.mouse) |
| 222 | elif args.style == "programmatic": |
| 223 | chatio = ProgrammaticChatIO() |
| 224 | else: |
| 225 | raise ValueError(f"Invalid style for console: {args.style}") |
| 226 | try: |
| 227 | chat_loop( |
| 228 | args.model_path, |
| 229 | args.device, |
| 230 | args.num_gpus, |
| 231 | args.max_gpu_memory, |
| 232 | str_to_torch_dtype(args.dtype), |
| 233 | args.load_8bit, |
| 234 | args.cpu_offloading, |
| 235 | args.conv_template, |
| 236 | args.conv_system_msg, |
| 237 | args.temperature, |
| 238 | args.repetition_penalty, |
| 239 | args.max_new_tokens, |
| 240 | chatio, |
| 241 | gptq_config=GptqConfig( |
| 242 | ckpt=args.gptq_ckpt or args.model_path, |
| 243 | wbits=args.gptq_wbits, |
| 244 | groupsize=args.gptq_groupsize, |
| 245 | act_order=args.gptq_act_order, |
| 246 | ), |
| 247 | awq_config=AWQConfig( |
| 248 | ckpt=args.awq_ckpt or args.model_path, |
| 249 | wbits=args.awq_wbits, |
no test coverage detected
searching dependent graphs…