(model, work_dir, model_name, dataset, samples_dict={}, api_nproc=4)
| 19 | |
# Only API model is accepted
def infer_data_api(model, work_dir, model_name, dataset, samples_dict=None, api_nproc=4):
    """Run an API model over ``samples_dict``, caching answers in a resumable pickle.

    Answers are persisted under ``work_dir`` in a ``*_supp.pkl`` file whose name
    encodes the model, dataset, frame setting (``nframe`` or ``fps``) and pack
    mode. Samples that already have a cached answer that is not ``FAIL_MSG`` are
    skipped, so an interrupted run can be resumed.

    Args:
        model: An API model instance, or a string. For a string,
            ``supported_VLM[model_name]`` is instantiated instead.
        work_dir: Directory holding the cache file.
        model_name: Model name used in the cache filename. It is also the key
            into ``supported_VLM`` when ``model`` is a string.
        dataset: Dataset object. Reads ``dataset_name``, ``nframe``, ``fps``,
            the optional ``pack`` flag, and calls ``build_prompt``.
        samples_dict: Mapping from sample index to the sample passed to
            ``dataset.build_prompt``. ``None`` (the default) means no samples.
        api_nproc: Number of parallel API workers. Also used as the chunk size.

    Returns:
        The cached results keyed by sample index. This is the contents of the
        cache file, or ``{}`` when there was nothing to infer and no cache file
        exists yet.

    Raises:
        AssertionError: If not running as a single process (rank 0, world
            size 1), or if the model is not an API model.
    """
    if samples_dict is None:
        samples_dict = {}

    rank, world_size = get_rank_and_world_size()
    assert rank == 0 and world_size == 1, 'infer_data_api must run with rank 0 and world size 1'
    dataset_name = dataset.dataset_name
    model = supported_VLM[model_name]() if isinstance(model, str) else model
    assert getattr(model, 'is_api', False), f'{model_name} is not an API model'

    packstr = 'pack' if getattr(dataset, 'pack', False) else 'nopack'
    if dataset.nframe > 0:
        out_file = f'{work_dir}/{model_name}_{dataset_name}_{dataset.nframe}frame_{packstr}_supp.pkl'
    else:
        out_file = f'{work_dir}/{model_name}_{dataset_name}_{dataset.fps}fps_{packstr}_supp.pkl'
    res = load(out_file) if osp.exists(out_file) else {}

    # Only (re)run samples that are missing from the cache or previously failed.
    # Filtering before building prompts avoids wasted build_prompt calls on
    # samples whose answers are already cached.
    indices = [idx for idx in samples_dict if idx not in res or res[idx] == FAIL_MSG]
    if not indices:
        # Nothing to infer: return what we have rather than calling load() on a
        # cache file that may never have been written.
        return res

    video_llm = getattr(model, 'VIDEO_LLM', False)
    structs = [
        dict(message=dataset.build_prompt(samples_dict[idx], video_llm=video_llm), dataset=dataset_name)
        for idx in indices
    ]
    track_progress_rich(
        model.generate, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices
    )
    return load(out_file)
| 49 | |
| 50 | |
| 51 | def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4): |
no test coverage detected