(imdb_path, out_path, group_id=0, n_groups=1)
| 31 | |
| 32 | |
def extract_bert(imdb_path, out_path, group_id=0, n_groups=1):
    """Compute BERT embeddings for imdb questions and save one file per question.

    The imdb array is loaded with ``np.load``; its first entry is skipped and
    the remaining entries are split round-robin across ``n_groups`` workers.
    This call handles only the entries whose position (after the skip)
    satisfies ``position % n_groups == group_id``.

    Args:
        imdb_path: Path of the imdb file passed to ``np.load``.
        out_path: Prefix for the output files. The question id is appended
            directly to it (plain string concatenation), so pass a trailing
            path separator if ``out_path`` is meant to be a directory.
        group_id: Index of this worker's share, in ``[0, n_groups)``.
        n_groups: Total number of shares the work is split into (>= 1).

    Raises:
        ValueError: If ``n_groups`` is smaller than 1 or ``group_id`` is not
            in ``[0, n_groups)``.
    """
    # Local import so this block does not depend on the file's import section.
    import os

    # Fail loudly on a bad split: n_groups == 0 would otherwise raise an
    # opaque ZeroDivisionError, and an out-of-range group_id would silently
    # process nothing.
    if n_groups < 1:
        raise ValueError(f"n_groups must be >= 1, got {n_groups}")
    if not 0 <= group_id < n_groups:
        raise ValueError(
            f"group_id must be in [0, {n_groups}), got {group_id}"
        )

    # NOTE(review): if the imdb stores Python dicts in an object array, recent
    # NumPy versions need allow_pickle=True here -- confirm the file format
    # (and that the file is trusted) before enabling it.
    imdb = np.load(imdb_path)

    # np.save does not create missing directories; make sure the directory
    # part of the output prefix exists before the expensive extraction starts.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    feat_extractor = BertFeatExtractor("bert-base-uncased")

    # The first imdb entry is skipped (presumably a header/metadata record --
    # TODO confirm against the imdb format).
    entries = imdb[1:]

    # Only group 0 shows a progress bar, so parallel workers do not each
    # print their own.
    iterator_obj = tqdm(entries) if group_id == 0 else entries

    for idx, el in enumerate(iterator_obj):
        if idx % n_groups != group_id:
            continue
        emb = feat_extractor.get_bert_embedding(el["question_str"])
        # np.save appends ".npy" when the name does not already end with it.
        save_path = out_path + str(el["question_id"])
        np.save(save_path, emb.cpu().numpy())
| 49 | |
| 50 | |
| 51 | if __name__ == "__main__": |
no test coverage detected