()
| 200 | |
| 201 | |
| 202 | def main(): |
| 203 | parser = argparse.ArgumentParser(description="MiniMind ToolCall评测") |
| 204 | parser.add_argument('--backend', default='local', choices=['local', 'api'], type=str, help="推理后端(local=本地模型,api=OpenAI兼容接口)") |
| 205 | parser.add_argument('--load_from', default='../model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)") |
| 206 | parser.add_argument('--save_dir', default='../out', type=str, help="模型权重目录") |
| 207 | parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)") |
| 208 | parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度") |
| 209 | parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") |
| 210 | parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)") |
| 211 | parser.add_argument('--max_new_tokens', default=512, type=int, help="最大生成长度") |
| 212 | parser.add_argument('--temperature', default=0.9, type=float, help="生成温度,控制随机性(0-1,越大越随机)") |
| 213 | parser.add_argument('--top_p', default=0.9, type=float, help="nucleus采样阈值(0-1)") |
| 214 | parser.add_argument('--show_speed', default=0, type=int, help="显示decode速度(tokens/s)") |
| 215 | parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备") |
| 216 | parser.add_argument('--api_base_url', default="http://localhost:11434/v1", type=str, help="OpenAI兼容接口的base_url") |
| 217 | parser.add_argument('--api_key', default='sk-123', type=str, help="OpenAI兼容接口的api_key") |
| 218 | parser.add_argument('--api_model', default='jingyaogong/minimind-3:latest', type=str, help="API请求时使用的模型名称") |
| 219 | parser.add_argument('--stream', default=1, type=int, help="API模式下是否流式输出(0=否,1=是)") |
| 220 | args = parser.parse_args() |
| 221 | |
| 222 | model = tokenizer = client = None |
| 223 | if args.backend == 'local': model, tokenizer = init_model(args) |
| 224 | else: client = OpenAI(api_key=args.api_key, base_url=args.api_base_url) |
| 225 | |
| 226 | input_mode = int(input('[0] 自动测试\n[1] 手动输入\n')) |
| 227 | |
| 228 | cases = [{"prompt": case["prompt"], "tools": get_tools(case["tools"]), "tool_names": case["tools"]} for case in TEST_CASES] if input_mode == 0 else iter(lambda: {"prompt": input('💬: '), "tools": TOOLS, "tool_names": [t["function"]["name"] for t in TOOLS]}, {"prompt": "", "tools": TOOLS, "tool_names": []}) |
| 229 | for case in cases: |
| 230 | if not case["prompt"]: break |
| 231 | setup_seed(random.randint(0, 31415926)) |
| 232 | if input_mode == 0: |
| 233 | print(f'📦 可用工具: {case["tool_names"]}\n') |
| 234 | print(f'💬: {case["prompt"]}') |
| 235 | run_case(case["prompt"], case["tools"], args, model=model, tokenizer=tokenizer, client=client) |
| 236 | print('\n' + '-' * 50 + '\n') |
| 237 | |
| 238 | |
| 239 | if __name__ == "__main__": |
no test coverage detected