()
| 30 | return model.half().eval().to(args.device), tokenizer |
| 31 | |
| 32 | def main(): |
| 33 | parser = argparse.ArgumentParser(description="MiniMind模型推理与对话") |
| 34 | parser.add_argument('--load_from', default='model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)") |
| 35 | parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录") |
| 36 | parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)") |
| 37 | parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)") |
| 38 | parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度") |
| 39 | parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") |
| 40 | parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)") |
| 41 | parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="启用RoPE位置编码外推(4倍,仅解决位置编码问题)") |
| 42 | parser.add_argument('--max_new_tokens', default=8192, type=int, help="最大生成长度(注意:并非模型实际长文本能力)") |
| 43 | parser.add_argument('--temperature', default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)") |
| 44 | parser.add_argument('--top_p', default=0.95, type=float, help="nucleus采样阈值(0-1)") |
| 45 | parser.add_argument('--open_thinking', default=0, type=int, help="是否开启自适应思考(0=否,1=是)") |
| 46 | parser.add_argument('--historys', default=0, type=int, help="携带历史对话轮数(需为偶数,0表示不携带历史)") |
| 47 | parser.add_argument('--show_speed', default=1, type=int, help="显示decode速度(tokens/s)") |
| 48 | parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备") |
| 49 | args = parser.parse_args() |
| 50 | |
| 51 | prompts = [ |
| 52 | '你有什么特长?', |
| 53 | '为什么天空是蓝色的', |
| 54 | '请用Python写一个计算斐波那契数列的函数', |
| 55 | '解释一下"光合作用"的基本过程', |
| 56 | '如果明天下雨,我应该如何出门', |
| 57 | '比较一下猫和狗作为宠物的优缺点', |
| 58 | '解释什么是机器学习', |
| 59 | '推荐一些中国的美食' |
| 60 | ] |
| 61 | |
| 62 | conversation = [] |
| 63 | model, tokenizer = init_model(args) |
| 64 | input_mode = int(input('[0] 自动测试\n[1] 手动输入\n')) |
| 65 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) |
| 66 | |
| 67 | prompt_iter = prompts if input_mode == 0 else iter(lambda: input('💬: '), '') |
| 68 | for prompt in prompt_iter: |
| 69 | setup_seed(random.randint(0, 31415926)) |
| 70 | if input_mode == 0: print(f'💬: {prompt}') |
| 71 | conversation = conversation[-args.historys:] if args.historys else [] |
| 72 | conversation.append({"role": "user", "content": prompt}) |
| 73 | if 'pretrain' in args.weight: |
| 74 | inputs = tokenizer.bos_token + prompt |
| 75 | else: |
| 76 | inputs = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, open_thinking=bool(args.open_thinking)) |
| 77 | |
| 78 | inputs = tokenizer(inputs, return_tensors="pt", truncation=True).to(args.device) |
| 79 | |
| 80 | print('🧠: ', end='') |
| 81 | st = time.time() |
| 82 | generated_ids = model.generate( |
| 83 | inputs=inputs["input_ids"], attention_mask=inputs["attention_mask"], |
| 84 | max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer, |
| 85 | pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, |
| 86 | top_p=args.top_p, temperature=args.temperature, repetition_penalty=1 |
| 87 | ) |
| 88 | response = tokenizer.decode(generated_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True) |
| 89 | conversation.append({"role": "assistant", "content": response}) |
no test coverage detected