(args)
| 109 | |
| 110 | |
| 111 | def benchmark_inference(args): |
| 112 | coordinator = DistCoordinator() |
| 113 | |
| 114 | torch_dtype = TORCH_DTYPE_MAP.get(args.dtype, None) |
| 115 | config = CONFIG_MAP[args.model] |
| 116 | config.torch_dtype = torch_dtype |
| 117 | config.pad_token_id = config.eos_token_id |
| 118 | |
| 119 | if args.model_path is not None: |
| 120 | model = transformers.LlamaForCausalLM.from_pretrained(args.model_path, torch_dtype=torch_dtype) |
| 121 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) |
| 122 | else: |
| 123 | # Random weights |
| 124 | model = transformers.LlamaForCausalLM(config) |
| 125 | tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") |
| 126 | if args.dtype == "fp16": |
| 127 | model = model.half() |
| 128 | elif args.dtype == "bf16": |
| 129 | model = model.to(torch.bfloat16) |
| 130 | |
| 131 | inference_config = InferenceConfig( |
| 132 | dtype=args.dtype, |
| 133 | max_batch_size=args.batch_size, |
| 134 | max_input_len=args.max_seq_len, |
| 135 | max_output_len=args.max_output_len, |
| 136 | prefill_ratio=1.2, |
| 137 | block_size=32, |
| 138 | tp_size=args.tp_size, |
| 139 | use_cuda_kernel=True, |
| 140 | ) |
| 141 | engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) |
| 142 | |
| 143 | data = data_gen(args.batch_size, args.max_seq_len) |
| 144 | generation_config = GenerationConfig( |
| 145 | pad_token_id=tokenizer.pad_token_id, |
| 146 | max_length=args.max_seq_len + args.max_output_len, |
| 147 | # max_new_tokens=args.max_output_len, |
| 148 | ) |
| 149 | coordinator.print_on_master(f"Generation Config: \n{generation_config.to_dict()}") |
| 150 | |
| 151 | ctx = ( |
| 152 | torch.profiler.profile( |
| 153 | record_shapes=True, |
| 154 | with_stack=True, |
| 155 | with_modules=True, |
| 156 | activities=[ |
| 157 | torch.profiler.ProfilerActivity.CPU, |
| 158 | torch.profiler.ProfilerActivity.CUDA, |
| 159 | ], |
| 160 | schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1), |
| 161 | on_trace_ready=torch.profiler.tensorboard_trace_handler( |
| 162 | f"./tb_log_{args.batch_size}_{args.max_seq_len}_{args.max_output_len}" |
| 163 | ), |
| 164 | ) |
| 165 | if args.profile |
| 166 | else nullcontext() |
| 167 | ) |
| 168 | with ctx: |
no test coverage detected
searching dependent graphs…