(databunch, pretrained_path, finetuned_wgts_path, device)
| 26 | |
| 27 | |
def load_model(databunch, pretrained_path, finetuned_wgts_path, device):
    """Build a token-classification model, optionally restoring fine-tuned weights.

    Args:
        databunch: Data container; this function reads its ``model_type``,
            ``labels`` and ``label_map`` attributes.
        pretrained_path: Location of the pretrained model; converted with
            ``str()`` and handed to both ``from_pretrained`` calls.
        finetuned_wgts_path: Optional path to a checkpoint readable by
            ``torch.load``. When falsy, no state dict is passed and the
            weights come from ``pretrained_path`` alone.
        device: Device the checkpoint tensors are mapped onto while loading.
            Presumably a ``torch.device`` or device string -- confirm against
            callers. ``None`` selects CUDA when available, otherwise CPU.

    Returns:
        The model returned by
        ``AutoModelForTokenClassification.from_pretrained``.
    """
    model_state_dict = None

    if finetuned_wgts_path:
        cuda_ok = torch.cuda.is_available()
        if device is not None and (
            cuda_ok or torch.device(device).type != "cuda"
        ):
            # Honour the caller's device instead of always using the default GPU.
            map_location = torch.device(device)
        elif cuda_ok:
            map_location = lambda storage, loc: storage.cuda()
        else:
            # Also reached when a CUDA device was requested on a CUDA-less host,
            # so a GPU-saved checkpoint still loads.
            map_location = "cpu"

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        model_state_dict = torch.load(
            finetuned_wgts_path, map_location=map_location
        )

    labels = databunch.labels
    config = AutoConfig.from_pretrained(
        str(pretrained_path),
        num_labels=len(labels),
        model_type=databunch.model_type,
        id2label=databunch.label_map,
        label2id={label: i for i, label in enumerate(labels)},
    )

    return AutoModelForTokenClassification.from_pretrained(
        str(pretrained_path), config=config, state_dict=model_state_dict
    )
| 56 | |
| 57 | |
| 58 | class BertNERLearner(Learner): |
no test coverage detected