hub / github.com/facebookresearch/mmf / extract_bert

Function extract_bert

tools/bert/extract_bert_embeddings.py:33–48
(imdb_path, out_path, group_id=0, n_groups=1)

Source from the content-addressed store, hash-verified

def extract_bert(imdb_path, out_path, group_id=0, n_groups=1):
    imdb = np.load(imdb_path)

    feat_extractor = BertFeatExtractor("bert-base-uncased")

    if group_id == 0:
        iterator_obj = tqdm(imdb[1:])
    else:
        iterator_obj = imdb[1:]

    for idx, el in enumerate(iterator_obj):
        if idx % n_groups != group_id:
            continue
        emb = feat_extractor.get_bert_embedding(el["question_str"])
        save_path = out_path + str(el["question_id"])
        np.save(save_path, emb.cpu().numpy())


if __name__ == "__main__":
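
The `idx % n_groups != group_id` check in the loop splits the work round-robin: every worker walks the full list (`imdb[1:]`, so the first element is skipped) but only processes positions whose remainder matches its own `group_id`. A minimal sketch of that partitioning follows. The helper `shard_indices` is hypothetical and not part of the repository; it leaves out the BERT model entirely and only shows which positions each group would handle.

```python
def shard_indices(n_items, group_id=0, n_groups=1):
    """Return the positions a given group would process.

    Hypothetical helper for illustration only; it mirrors the
    `idx % n_groups != group_id: continue` filter in extract_bert.
    """
    return [idx for idx in range(n_items) if idx % n_groups == group_id]


if __name__ == "__main__":
    n_items, n_groups = 10, 3
    shards = [shard_indices(n_items, g, n_groups) for g in range(n_groups)]
    for g, shard in enumerate(shards):
        print(g, shard)
    # Each position is handled by exactly one group.
    assert sorted(i for s in shards for i in s) == list(range(n_items))
```

Under this scheme, covering the whole file means running the function once for each `group_id` from 0 to `n_groups - 1`. Only the `group_id == 0` run wraps the iterator in `tqdm`, so only that run shows a progress bar.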

Callers 1

Calls 4

get_bert_embedding · Method · 0.95
BertFeatExtractor · Class · 0.85
save · Method · 0.80
load · Method · 0.45
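
A note on the `save` call: the function builds the output path by plain string concatenation, `out_path + str(el["question_id"])`, so `out_path` appears to need a trailing path separator for the files to land inside a directory. `np.save` then appends `.npy` when the filename does not already end in it. The sketch below contrasts the concatenation with `os.path.join`; both helper names are hypothetical, and the `os.path.join` variant is an alternative, not what the source does.

```python
import os


def concat_path(out_path, question_id):
    # Same construction as in extract_bert: plain string concatenation.
    return out_path + str(question_id)


def joined_path(out_path, question_id):
    # Alternative (not used by the source): let os.path add the separator.
    return os.path.join(out_path, str(question_id))


if __name__ == "__main__":
    print(concat_path("feats", 42))    # no separator is inserted
    print(concat_path("feats/", 42))   # caller supplied the separator
    print(joined_path("feats", 42))    # separator added by os.path.join
```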

Tested by

no test coverage detected