MCPcopy
hub / github.com/ModelTC/LightLLM / test_model_inference

Function test_model_inference

test/benchmark/static_inference/model_infer.py:18–59  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

16
17
def _build_model_kvargs(args, rank_id, dp_size):
    """Assemble the keyword-argument dict handed to one worker process.

    Most entries are copied straight from ``args``; ``load_way``,
    ``max_req_num``, ``batch_max_tokens`` and ``run_mode`` are fixed values.

    Args:
        args: Run configuration; read via attribute access.
        rank_id: Rank assigned to the worker that receives this dict.
        dp_size: Value stored under the ``"dp_size"`` key.

    Returns:
        A new dict; a separate one is built for every rank.
    """
    return {
        "args": args,
        "nccl_host": args.nccl_host,
        "data_type": args.data_type,
        "nccl_port": args.nccl_port,
        "rank_id": rank_id,
        "world_size": args.tp,
        "dp_size": dp_size,
        "weight_dir": args.model_dir,
        "quant_type": args.quant_type,
        "load_way": "HF",
        "max_total_token_num": args.max_total_token_num,
        # The same request-length limit feeds both the graph length and
        # the maximum sequence length entries.
        "graph_max_len_in_batch": args.max_req_total_len,
        "graph_max_batch_size": args.graph_max_batch_size,
        "mem_fraction": args.mem_fraction,
        "max_req_num": 2048,
        "batch_max_tokens": 1024,
        "run_mode": "normal",
        "max_seq_length": args.max_req_total_len,
        "disable_cudagraph": args.disable_cudagraph,
        "mode": args.mode,
    }


def test_model_inference(args):
    """Run ``tppart_model_infer`` in one process per rank and check results.

    A process is started for every rank in
    ``[args.node_rank * args.tp, (args.node_rank + 1) * args.tp)``. All
    processes share a single queue; once they have all been joined, the
    queue must be non-empty and every item in it must be truthy.

    Args:
        args: Run configuration. Read both through attribute access and
            through ``args.get("dp", 1)``, so it must support both.
            NOTE(review): presumably a dict-like attribute container;
            confirm against the caller.

    Raises:
        AssertionError: If no result was queued or a queued result is falsy.
    """
    result_queue = Queue()
    procs = []
    dp_size = args.get("dp", 1)

    tp = args.tp
    first_rank = args.node_rank * tp
    for rank_id in range(first_rank, first_rank + tp):
        worker = multiprocessing.Process(
            target=tppart_model_infer,
            args=(
                args,
                _build_model_kvargs(args, rank_id, dp_size),
                args.batch_size,
                args.input_len,
                args.output_len,
                result_queue,
            ),
        )
        worker.start()
        procs.append(worker)

    # Wait for every worker to exit before inspecting what they reported.
    for worker in procs:
        worker.join()

    # At least one result must be present, and each one must be truthy.
    assert not result_queue.empty()
    while not result_queue.empty():
        assert result_queue.get()
60
61
62def overlap_prefill(

Callers 1

test_model_inferMethod · 0.90

Calls 3

startMethod · 0.80
emptyMethod · 0.80
getMethod · 0.45

Tested by

no test coverage detected