hub / github.com/hpcaitech/ColossalAI / benchmark_inference

Function benchmark_inference

examples/inference/llama/benchmark_llama3.py:111–191
(args)

Source from the content-addressed store, hash-verified

def benchmark_inference(args):
    coordinator = DistCoordinator()

    torch_dtype = TORCH_DTYPE_MAP.get(args.dtype, None)
    config = CONFIG_MAP[args.model]
    config.torch_dtype = torch_dtype
    config.pad_token_id = config.eos_token_id

    if args.model_path is not None:
        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path, torch_dtype=torch_dtype)
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    else:
        # Random weights
        model = transformers.LlamaForCausalLM(config)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    if args.dtype == "fp16":
        model = model.half()
    elif args.dtype == "bf16":
        model = model.to(torch.bfloat16)

    inference_config = InferenceConfig(
        dtype=args.dtype,
        max_batch_size=args.batch_size,
        max_input_len=args.max_seq_len,
        max_output_len=args.max_output_len,
        prefill_ratio=1.2,
        block_size=32,
        tp_size=args.tp_size,
        use_cuda_kernel=True,
    )
    engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

    data = data_gen(args.batch_size, args.max_seq_len)
    generation_config = GenerationConfig(
        pad_token_id=tokenizer.pad_token_id,
        max_length=args.max_seq_len + args.max_output_len,
        # max_new_tokens=args.max_output_len,
    )
    coordinator.print_on_master(f"Generation Config: \n{generation_config.to_dict()}")

    ctx = (
        torch.profiler.profile(
            record_shapes=True,
            with_stack=True,
            with_modules=True,
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                f"./tb_log_{args.batch_size}_{args.max_seq_len}_{args.max_output_len}"
            ),
        )
        if args.profile
        else nullcontext()
    )
    with ctx:
        # (excerpt ends here; the remainder of the function is not shown)
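
The listing selects its context manager at runtime: a `torch.profiler.profile(...)` instance when `args.profile` is set, otherwise `contextlib.nullcontext()`, so the same `with ctx:` block runs either way. The sketch below illustrates only that selection pattern using the standard library. The `timed` helper, the `run` function, and the `sink` dictionaries are hypothetical stand-ins introduced for illustration; they are not part of ColossalAI or PyTorch.

```python
import time
from contextlib import contextmanager, nullcontext


@contextmanager
def timed(label, sink):
    """Hypothetical stand-in for a profiler: stores elapsed seconds in `sink`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        sink[label] = time.perf_counter() - start


def run(workload, profile, sink):
    # Same shape as the listing: choose an instrumenting context manager
    # or a no-op one, then execute the workload inside a single `with`.
    ctx = timed("workload", sink) if profile else nullcontext()
    with ctx:
        return workload()


plain_records, profiled_records = {}, {}
result_plain = run(lambda: sum(range(1000)), profile=False, sink=plain_records)
result_profiled = run(lambda: sum(range(1000)), profile=True, sink=profiled_records)
```

With `profile=False` nothing is recorded and the workload result is unchanged, which is the property that lets the benchmark body stay free of `if args.profile` branches.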

Callers 1

inference · Function · 0.70

Calls 15

print_on_master · Method · 0.95
generate · Method · 0.95
DistCoordinator · Class · 0.90
InferenceConfig · Class · 0.90
InferenceEngine · Class · 0.90
half · Method · 0.80
data_gen · Function · 0.70
print_details_info · Function · 0.70
get · Method · 0.45
from_pretrained · Method · 0.45
to · Method · 0.45
to_dict · Method · 0.45

Tested by

no test coverage detected
