(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4)
| 49 | |
| 50 | |
| 51 | def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4): |
| 52 | res = load(out_file) if osp.exists(out_file) else {} |
| 53 | rank, world_size = get_rank_and_world_size() |
| 54 | dataset_name = dataset.dataset_name |
| 55 | |
| 56 | sample_indices = list(dataset.videos) if getattr(dataset, 'pack', False) else list(dataset.data['index']) |
| 57 | samples = list(dataset.videos) if getattr(dataset, 'pack', False) else list(range(len(dataset.data))) |
| 58 | sample_map = {i: s for i, s in zip(sample_indices, samples)} |
| 59 | |
| 60 | sample_indices_sub = sample_indices[rank::world_size] |
| 61 | if np.all([idx in res for idx in sample_indices_sub]): |
| 62 | return model |
| 63 | sample_indices_subrem = [x for x in sample_indices_sub if x not in res] |
| 64 | |
| 65 | model = supported_VLM[model_name]() if isinstance(model, str) else model |
| 66 | |
| 67 | is_api = getattr(model, 'is_api', False) |
| 68 | if is_api: |
| 69 | assert world_size == 1 |
| 70 | supp = infer_data_api( |
| 71 | model=model, |
| 72 | work_dir=work_dir, |
| 73 | model_name=model_name, |
| 74 | dataset=dataset, |
| 75 | samples_dict={k: sample_map[k] for k in sample_indices_subrem}, |
| 76 | api_nproc=api_nproc) |
| 77 | for k in sample_indices_subrem: |
| 78 | assert k in supp |
| 79 | res.update(supp) |
| 80 | dump(res, out_file) |
| 81 | return model |
| 82 | |
| 83 | assert not getattr(dataset, 'pack', False), 'Current model not supported pack mode!' |
| 84 | for i, idx in tqdm(enumerate(sample_indices_subrem)): |
| 85 | if idx in res: |
| 86 | continue |
| 87 | if getattr(model, 'nframe', None) is not None and getattr(model, 'nframe', 0) > 0: |
| 88 | if dataset.nframe > 0: |
| 89 | if getattr(model, 'nframe', 0) != dataset.nframe: |
| 90 | print(f'{model_name} is a video-llm model, nframe is set to {dataset.nframe}, not using default') |
| 91 | setattr(model, 'nframe', dataset.nframe) |
| 92 | elif getattr(model, 'fps', 0) == 0: |
| 93 | raise ValueError(f'fps is not suitable for {model_name}') |
| 94 | else: |
| 95 | setattr(model, 'nframe', None) |
| 96 | if getattr(model, 'fps', None) is not None and getattr(model, 'fps', 0) > 0: |
| 97 | if dataset.fps > 0: |
| 98 | if getattr(model, 'fps', 0) != dataset.fps: |
| 99 | print(f'{model_name} is a video-llm model, fps is set to {dataset.fps}, not using default') |
| 100 | setattr(model, 'fps', dataset.fps) |
| 101 | elif getattr(model, 'nframe', 0) == 0: |
| 102 | raise ValueError(f'nframe is not suitable for {model_name}') |
| 103 | else: |
| 104 | setattr(model, 'fps', None) |
| 105 | if 'SUB_DATASET' in dataset.data.iloc[sample_map[idx]]: |
| 106 | dataset_name = dataset.data.iloc[sample_map[idx]]['SUB_DATASET'] |
| 107 | if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name): |
| 108 | if dataset.nframe == 0: |
no test coverage detected