MCPcopy
hub / github.com/appvision-ai/fast-bert / load_model

Function load_model

fast_bert/learner_ner.py:28–55  ·  view source on GitHub ↗
(databunch, pretrained_path, finetuned_wgts_path, device)

Source from the content-addressed store, hash-verified

26
27
def load_model(databunch, pretrained_path, finetuned_wgts_path, device):
    """Build a token-classification (NER) model from a pretrained checkpoint.

    Optionally overrides the pretrained weights with a fine-tuned state dict.

    Args:
        databunch: Object providing ``model_type``, ``labels`` (ordered label
            list, used for ``num_labels`` and ``label2id``) and ``label_map``
            (passed through as ``id2label``).
        pretrained_path: Path or model identifier; passed as ``str`` to both
            ``AutoConfig.from_pretrained`` and
            ``AutoModelForTokenClassification.from_pretrained``.
        finetuned_wgts_path: Path to a saved state dict. When falsy, no
            state dict is passed and the pretrained weights are used.
        device: Device used as ``map_location`` when loading the fine-tuned
            weights (``torch.device`` or device string). When ``None``, the
            weights go to the default CUDA device if CUDA is available, and
            to CPU otherwise.

    Returns:
        The model returned by ``AutoModelForTokenClassification.from_pretrained``.
    """
    model_state_dict = None
    if finetuned_wgts_path:
        if device is not None:
            # Honour the caller's device; previously this argument was ignored
            # and weights always landed on the default GPU when one existed.
            map_location = device
        elif torch.cuda.is_available():
            map_location = "cuda"
        else:
            map_location = "cpu"
        # NOTE(review): torch.load unpickles the file; only load weights from
        # trusted sources.
        model_state_dict = torch.load(finetuned_wgts_path, map_location=map_location)

    # NOTE(review): id2label comes from databunch.label_map while label2id is
    # derived from the order of databunch.labels; presumably these agree
    # (label_map[i] == labels[i]) -- confirm against the databunch builder.
    config = AutoConfig.from_pretrained(
        str(pretrained_path),
        num_labels=len(databunch.labels),
        model_type=databunch.model_type,
        id2label=databunch.label_map,
        label2id={label: i for i, label in enumerate(databunch.labels)},
    )

    return AutoModelForTokenClassification.from_pretrained(
        str(pretrained_path), config=config, state_dict=model_state_dict
    )
56
57
58class BertNERLearner(Learner):

Callers 3

from_pretrained_model (Method) · 0.70
__init__ (Method) · 0.70
__init__ (Method) · 0.70

Calls 1

load (Method) · 0.80

Tested by

no test coverage detected