MCPcopy
hub / github.com/facebookresearch/mmf / BertFeatExtractor

Class BertFeatExtractor

tools/bert/extract_bert_embeddings.py:12–30  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

10
11
class BertFeatExtractor:
    """Compute a sentence-level feature from the BERT [CLS] position.

    Wraps a tokenizer/model pair loaded via ``from_pretrained`` and returns,
    for a single input string, the model's hidden state at the ``[CLS]``
    position (the model is called with ``output_all_encoded_layers=False``).
    """

    def __init__(self, model_name, device="cuda"):
        """Load the tokenizer and model and move the model to ``device``.

        Args:
            model_name: Name or path passed to ``from_pretrained`` for both
                the tokenizer and the model.
            device: Torch device used for inference. Defaults to ``"cuda"``,
                which matches the previous hard-coded ``.cuda()`` call; pass
                ``"cpu"`` to run without a GPU.
        """
        self.device = torch.device(device)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        # eval() switches the model to inference mode (e.g. disables dropout).
        self.model = BertModel.from_pretrained(model_name).eval()
        self.model.to(self.device)

    def get_bert_embedding(self, text):
        """Return the hidden state at the ``[CLS]`` position for ``text``.

        The text is tokenized with ``self.tokenizer``, wrapped as
        ``[CLS] ... [SEP]``, and run through the model as a single segment
        (all token-type ids 0) with gradients disabled.

        Args:
            text: Input string.

        Returns:
            The output at batch item 0, position 0 (the ``[CLS]`` token),
            left on the model's device. Presumably a 1-D tensor of size
            ``hidden_size``; this assumes the model returns
            ``(batch, seq_len, hidden)``. TODO: confirm against the BERT
            package in use.
        """
        tokens = ["[CLS]"] + self.tokenizer.tokenize(text) + ["[SEP]"]
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        # Build ids directly as int64 on the target device; the legacy
        # torch.Tensor constructor would go through float32 first.
        input_ids = torch.tensor([token_ids], dtype=torch.long, device=self.device)
        # Same (1, seq_len) shape as input_ids instead of relying on
        # broadcasting a 1-D segment tensor inside the model.
        token_type_ids = torch.zeros_like(input_ids)
        with torch.no_grad():
            encoded_layers, _ = self.model(
                input_ids,
                token_type_ids,
                output_all_encoded_layers=False,
            )
        # Explicit [batch 0, position 0] rather than squeeze()[0], whose
        # result depends on which dimensions happen to be size 1.
        return encoded_layers[0, 0]
31
32
33def extract_bert(imdb_path, out_path, group_id=0, n_groups=1):

Callers 1

extract_bert — Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected