| 10 | |
| 11 | |
class BertFeatExtractor(object):
    """Compute a sentence-level BERT feature for a piece of text.

    The feature is the final-layer hidden state at the ``[CLS]`` position
    (the first token of the encoded sequence).
    """

    def __init__(self, model_name, device=None):
        """Load the tokenizer and model identified by *model_name*.

        Args:
            model_name: Name or path handed to ``from_pretrained`` for both
                the tokenizer and the model.
            device: Optional ``torch.device`` (or device string) to run on.
                When omitted, CUDA is used if available and the CPU
                otherwise, so existing GPU callers keep their behavior.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        # eval() switches off dropout so repeated calls are deterministic.
        self.model = BertModel.from_pretrained(model_name).eval()
        self.model.to(self.device)

    def get_bert_embedding(self, text):
        """Return the final-layer hidden state of the ``[CLS]`` token.

        Args:
            text: Raw string to embed. It is word-piece tokenized and wrapped
                in ``[CLS]`` ... ``[SEP]`` before being fed to the model.

        Returns:
            A 1-D tensor (hidden size) living on ``self.device``.

        Note:
            No truncation is applied here; very long inputs are passed to
            the model as-is.
        """
        tokens = ["[CLS]"] + self.tokenizer.tokenize(text) + ["[SEP]"]
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        # Build the ids directly as int64 on the target device instead of
        # round-tripping through a float tensor.
        tokens_tensor = torch.tensor(
            [token_ids], dtype=torch.long, device=self.device
        )
        # Single-sentence input: every token belongs to segment 0. Keeping
        # the same (1, seq_len) shape as the ids avoids relying on
        # broadcasting inside the model.
        segments_tensor = torch.zeros_like(tokens_tensor)
        with torch.no_grad():
            encoded_layers, _ = self.model(
                tokens_tensor,
                segments_tensor,
                output_all_encoded_layers=False,
            )
        # Batch item 0, token 0 ([CLS]); explicit indexing rather than a
        # blanket squeeze() that would drop any size-1 axis.
        return encoded_layers[0, 0]
| 31 | |
| 32 | |
| 33 | def extract_bert(imdb_path, out_path, group_id=0, n_groups=1): |