hub / github.com/hpcaitech/Open-Sora / run_inference

Function run_inference

tools/caption/caption_llava_next.py:226–370

Run inference on the ActivityNet QA dataset using the Video-ChatGPT model.

Args:
    args: Command-line arguments.

(rank, world_size, args)

Source from the content-addressed store, hash-verified

def run_inference(rank, world_size, args):
    """
    Run inference on ActivityNet QA DataSet using the Video-ChatGPT model.

    Args:
        args: Command-line arguments.
    """
    setup(rank, world_size)

    device = torch.device(f"cuda:{rank}")
    # Initialize the model
    model_name = get_model_name_from_path(args.model_path)
    # Set model configuration parameters if they exist
    if args.overwrite == True:
        overwrite_config = {}
        overwrite_config["mm_spatial_pool_mode"] = args.mm_spatial_pool_mode
        overwrite_config["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
        overwrite_config["mm_newline_position"] = args.mm_newline_position

        cfg_pretrained = AutoConfig.from_pretrained(args.model_path)

        # import pdb;pdb.set_trace()
        if "qwen" not in args.model_path.lower():
            if "224" in cfg_pretrained.mm_vision_tower:
                # suppose the length of text tokens is around 1000, from bo's report
                least_token_number = args.for_get_frames_num * (16 // args.mm_spatial_pool_stride) ** 2 + 1000
            else:
                least_token_number = args.for_get_frames_num * (24 // args.mm_spatial_pool_stride) ** 2 + 1000

            scaling_factor = math.ceil(least_token_number / 4096)
            if scaling_factor >= 2:
                if "vicuna" in cfg_pretrained._name_or_path.lower():
                    print(float(scaling_factor))
                    overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"}
                overwrite_config["max_sequence_length"] = 4096 * scaling_factor
                overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor
        if args.load_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True, bnb_8bit_compute_dtype=torch.bfloat16)
            tokenizer, model, image_processor, context_len = load_pretrained_model(
                args.model_path,
                args.model_base,
                model_name,
                device_map=device,
                quantization_config=quantization_config,
                overwrite_config=overwrite_config,
            )
        else:
            tokenizer, model, image_processor, context_len = load_pretrained_model(
                args.model_path, args.model_base, model_name, device_map=device, overwrite_config=overwrite_config
            )
    else:
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            args.model_path, args.model_base, model_name, device_map=device
        )

    if tokenizer.pad_token_id is None:
        if "qwen" in tokenizer.name_or_path.lower():
            # print("Setting pad token to bos token for qwen model.")
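The context-length arithmetic in the excerpt is easier to follow in isolation. The sketch below is a standalone illustration, not part of the repository: `estimate_rope_scaling` and `build_overrides` are hypothetical helper names, and the nesting of the `"vicuna"` check is an assumption based on the excerpt's likely indentation. It reproduces the estimate of `frames * (patches_per_side // stride) ** 2 + 1000` tokens and the resulting multiple of the 4096-token base context.

```python
import math

BASE_CONTEXT = 4096       # base context length used in the excerpt
TEXT_TOKEN_BUDGET = 1000  # rough allowance for text tokens, per the source comment


def estimate_rope_scaling(frames, pool_stride, patches_per_side):
    """Return (least_token_number, scaling_factor) for a given frame count.

    patches_per_side is 16 when the vision tower name contains "224",
    and 24 otherwise, mirroring the two branches in the excerpt.
    """
    tokens_per_frame = (patches_per_side // pool_stride) ** 2
    least_token_number = frames * tokens_per_frame + TEXT_TOKEN_BUDGET
    scaling_factor = math.ceil(least_token_number / BASE_CONTEXT)
    return least_token_number, scaling_factor


def build_overrides(frames, pool_stride, patches_per_side, is_vicuna):
    """Build only the context-related overrides (hypothetical helper)."""
    overrides = {}
    _, factor = estimate_rope_scaling(frames, pool_stride, patches_per_side)
    if factor >= 2:
        if is_vicuna:
            # Linear RoPE scaling is only set for Vicuna-based checkpoints.
            overrides["rope_scaling"] = {"factor": float(factor), "type": "linear"}
        overrides["max_sequence_length"] = BASE_CONTEXT * factor
        overrides["tokenizer_model_max_length"] = BASE_CONTEXT * factor
    return overrides


if __name__ == "__main__":
    # 32 frames, stride 2, 24 patches per side: 32 * 12**2 + 1000 = 5608 tokens
    print(estimate_rope_scaling(32, 2, 24))
    print(build_overrides(32, 2, 24, is_vicuna=True))
```

With 32 frames, a pooling stride of 2, and 24 patches per side, the estimate is 5608 tokens, which exceeds 4096 and yields a scaling factor of 2. With 8 frames, a stride of 2, and 16 patches per side, the estimate is 1512 tokens, so the factor stays at 1 and no overrides are added.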

Callers

nothing calls this directly

Calls 7

setup · Function · 0.85
create_dataloader · Function · 0.85
tqdm · Function · 0.85
cleanup · Function · 0.85
from_pretrained · Method · 0.80
to · Method · 0.80
device · Method · 0.45

Tested by

no test coverage detected