(model, work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False)
| 38 | |
| 39 | # Only API model is accepted |
def infer_data_api(model, work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
    """Run inference for an API model over a dataset and return the results.

    Results already stored in the supplementary pickle
    ``{work_dir}/{model_name}_{dataset_name}_supp.pkl`` are reused; only the
    samples whose index is missing from it are sent to the model.

    Args:
        model: A model instance, or a string, in which case the model is
            built via ``supported_VLM[model_name]()``.
        work_dir: Directory holding the supplementary pickle.
        model_name: Key into ``supported_VLM``; also part of the pickle name.
        dataset: Object exposing ``dataset_name``, ``data`` (a table with an
            ``'index'`` column) and ``build_prompt(row)``.
        index_set: Optional collection of indices; when given, only rows
            whose ``'index'`` is in it are processed and returned.
        api_nproc: Passed to ``track_progress_rich`` as both ``nproc`` and
            ``chunksize``.
        ignore_failed: When true, cached entries containing ``FAIL_MSG`` are
            dropped so that they are inferred again.

    Returns:
        dict: Mapping from sample index to the stored result, restricted to
        ``index_set`` when it is provided.

    Raises:
        AssertionError: If not running as a single process (rank 0, world
            size 1), or if the model is not an API model with ``chat_inner``.
    """
    rank, world_size = get_rank_and_world_size()
    assert rank == 0 and world_size == 1
    dataset_name = dataset.dataset_name
    data = dataset.data
    if index_set is not None:
        data = data[data['index'].isin(index_set)]

    model = supported_VLM[model_name]() if isinstance(model, str) else model
    assert getattr(model, 'is_api', False)
    assert hasattr(model, 'chat_inner')

    lt, indices = len(data), list(data['index'])
    structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]

    out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
    res = {}
    if osp.exists(out_file):
        res = load(out_file)
        if ignore_failed:
            # Forget failed answers so the corresponding samples are retried.
            res = {k: v for k, v in res.items() if FAIL_MSG not in v}

    # Keep only the samples that still lack a cached result.
    structs = [s for i, s in zip(indices, structs) if i not in res]
    indices = [i for i in indices if i not in res]

    structs = [dict(model=model, messages=struct, dataset_name=dataset_name) for struct in structs]

    if len(structs):
        # NOTE(review): track_progress_rich is expected to persist results to
        # `save`, keyed by `keys` -- confirm against its implementation.
        track_progress_rich(chat_mt, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)

    # If nothing was pending and no cache file existed, out_file was never
    # written; loading or removing it unconditionally would raise. When work
    # was dispatched we still load unconditionally so that a missing output
    # file is reported rather than silently ignored.
    if len(structs) or osp.exists(out_file):
        res = load(out_file)
        os.remove(out_file)

    if index_set is not None:
        # Build the lookup set once instead of scanning index_set per key.
        wanted = set(index_set)
        res = {k: v for k, v in res.items() if k in wanted}
    return res
| 75 | |
| 76 | |
| 77 | def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4): |
no test coverage detected