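Run inference on the ActivityNet QA dataset using the Video-ChatGPT model. Args: rank: passed to `setup` and used to select the CUDA device (`cuda:{rank}`). world_size: passed to `setup` together with `rank`. args: Command-line arguments.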
(rank, world_size, args)
| 224 | |
| 225 | |
| 226 | def run_inference(rank, world_size, args): |
| 227 | """ |
| 228 | Run inference on ActivityNet QA DataSet using the Video-ChatGPT model. |
| 229 | |
| 230 | Args: |
| 231 | args: Command-line arguments. |
| 232 | """ |
| 233 | setup(rank, world_size) |
| 234 | |
| 235 | device = torch.device(f"cuda:{rank}") |
| 236 | # Initialize the model |
| 237 | model_name = get_model_name_from_path(args.model_path) |
| 238 | # Set model configuration parameters if they exist |
| 239 | if args.overwrite == True: |
| 240 | overwrite_config = {} |
| 241 | overwrite_config["mm_spatial_pool_mode"] = args.mm_spatial_pool_mode |
| 242 | overwrite_config["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride |
| 243 | overwrite_config["mm_newline_position"] = args.mm_newline_position |
| 244 | |
| 245 | cfg_pretrained = AutoConfig.from_pretrained(args.model_path) |
| 246 | |
| 247 | # import pdb;pdb.set_trace() |
| 248 | if "qwen" not in args.model_path.lower(): |
| 249 | if "224" in cfg_pretrained.mm_vision_tower: |
| 250 | # suppose the length of text tokens is around 1000, from bo's report |
| 251 | least_token_number = args.for_get_frames_num * (16 // args.mm_spatial_pool_stride) ** 2 + 1000 |
| 252 | else: |
| 253 | least_token_number = args.for_get_frames_num * (24 // args.mm_spatial_pool_stride) ** 2 + 1000 |
| 254 | |
| 255 | scaling_factor = math.ceil(least_token_number / 4096) |
| 256 | if scaling_factor >= 2: |
| 257 | if "vicuna" in cfg_pretrained._name_or_path.lower(): |
| 258 | print(float(scaling_factor)) |
| 259 | overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"} |
| 260 | overwrite_config["max_sequence_length"] = 4096 * scaling_factor |
| 261 | overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor |
| 262 | if args.load_8bit: |
| 263 | quantization_config = BitsAndBytesConfig(load_in_8bit=True, bnb_8bit_compute_dtype=torch.bfloat16) |
| 264 | tokenizer, model, image_processor, context_len = load_pretrained_model( |
| 265 | args.model_path, |
| 266 | args.model_base, |
| 267 | model_name, |
| 268 | device_map=device, |
| 269 | quantization_config=quantization_config, |
| 270 | overwrite_config=overwrite_config, |
| 271 | ) |
| 272 | else: |
| 273 | tokenizer, model, image_processor, context_len = load_pretrained_model( |
| 274 | args.model_path, args.model_base, model_name, device_map=device, overwrite_config=overwrite_config |
| 275 | ) |
| 276 | else: |
| 277 | tokenizer, model, image_processor, context_len = load_pretrained_model( |
| 278 | args.model_path, args.model_base, model_name, device_map=device |
| 279 | ) |
| 280 | |
| 281 | if tokenizer.pad_token_id is None: |
| 282 | if "qwen" in tokenizer.name_or_path.lower(): |
| 283 | # print("Setting pad token to bos token for qwen model.") |
Nothing calls `run_inference` directly.
No test coverage was detected for it.