(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4)
| 75 | |
| 76 | |
| 77 | def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4): |
| 78 | dataset_name = dataset.dataset_name |
| 79 | res = {} |
| 80 | if osp.exists(out_file): |
| 81 | res.update(load(out_file)) |
| 82 | |
| 83 | rank, world_size = get_rank_and_world_size() |
| 84 | sheet_indices = list(range(rank, len(dataset), world_size)) |
| 85 | lt = len(sheet_indices) |
| 86 | data = dataset.data.iloc[sheet_indices] |
| 87 | data_indices = [i for i in data['index']] |
| 88 | |
| 89 | # If finished, will exit without building the model |
| 90 | all_finished = True |
| 91 | for i in range(lt): |
| 92 | idx = data.iloc[i]['index'] |
| 93 | if idx not in res: |
| 94 | all_finished = False |
| 95 | if all_finished: |
| 96 | res = {k: res[k] for k in data_indices} |
| 97 | dump(res, out_file) |
| 98 | return |
| 99 | |
| 100 | # Data need to be inferred |
| 101 | data = data[~data['index'].isin(res)] |
| 102 | lt = len(data) |
| 103 | |
| 104 | model = supported_VLM[model_name]() if isinstance(model, str) else model |
| 105 | assert hasattr(model, 'chat_inner') |
| 106 | |
| 107 | is_api = getattr(model, 'is_api', False) |
| 108 | if is_api: |
| 109 | lt, indices = len(data), list(data['index']) |
| 110 | supp = infer_data_api( |
| 111 | model=model, |
| 112 | work_dir=work_dir, |
| 113 | model_name=model_name, |
| 114 | dataset=dataset, |
| 115 | index_set=set(indices), |
| 116 | api_nproc=api_nproc) |
| 117 | for idx in indices: |
| 118 | assert idx in supp |
| 119 | res.update(supp) |
| 120 | res = {k: res[k] for k in data_indices} |
| 121 | dump(res, out_file) |
| 122 | return model |
| 123 | else: |
| 124 | model.set_dump_image(dataset.dump_image) |
| 125 | |
| 126 | for i in tqdm(range(lt)): |
| 127 | idx = data.iloc[i]['index'] |
| 128 | if idx in res: |
| 129 | continue |
| 130 | |
| 131 | if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name): |
| 132 | struct = model.build_prompt(data.iloc[i], dataset=dataset_name) |
| 133 | else: |
| 134 | struct = dataset.build_prompt(data.iloc[i]) |
no test coverage detected