(
databunch,
pretrained_path,
device,
logger,
metrics=None,
finetuned_wgts_path=None,
multi_gpu=True,
is_fp16=True,
loss_scale=0,
warmup_steps=0,
fp16_opt_level="O1",
grad_accumulation_steps=1,
max_grad_norm=1.0,
adam_epsilon=1e-8,
logging_steps=100,
alpha=0.95,
beam_size=5,
min_length=50,
max_length=200,
block_trigram=True,
)
| 23 | class BertAbsLearner(Learner): |
| 24 | @staticmethod |
| 25 | def from_pretrained_model( |
| 26 | databunch, |
| 27 | pretrained_path, |
| 28 | device, |
| 29 | logger, |
| 30 | metrics=None, |
| 31 | finetuned_wgts_path=None, |
| 32 | multi_gpu=True, |
| 33 | is_fp16=True, |
| 34 | loss_scale=0, |
| 35 | warmup_steps=0, |
| 36 | fp16_opt_level="O1", |
| 37 | grad_accumulation_steps=1, |
| 38 | max_grad_norm=1.0, |
| 39 | adam_epsilon=1e-8, |
| 40 | logging_steps=100, |
| 41 | alpha=0.95, |
| 42 | beam_size=5, |
| 43 | min_length=50, |
| 44 | max_length=200, |
| 45 | block_trigram=True, |
| 46 | ): |
| 47 | |
| 48 | model_state_dict = None |
| 49 | |
| 50 | model_type = databunch.model_type |
| 51 | |
| 52 | config_class, model_class = MODEL_CLASSES[model_type] |
| 53 | |
| 54 | if torch.cuda.is_available(): |
| 55 | map_location = lambda storage, loc: storage.cuda() |
| 56 | else: |
| 57 | map_location = 'cpu' |
| 58 | |
| 59 | if finetuned_wgts_path: |
| 60 | model_state_dict = torch.load(finetuned_wgts_path, map_location=map_location) |
| 61 | else: |
| 62 | model_state_dict = None |
| 63 | |
| 64 | model = model_class.from_pretrained( |
| 65 | str(pretrained_path), state_dict=model_state_dict |
| 66 | ) |
| 67 | |
| 68 | model.to(device) |
| 69 | |
| 70 | return BertAbsLearner( |
| 71 | databunch, |
| 72 | model, |
| 73 | str(pretrained_path), |
| 74 | device, |
| 75 | logger, |
| 76 | metrics, |
| 77 | multi_gpu, |
| 78 | is_fp16, |
| 79 | loss_scale, |
| 80 | warmup_steps, |
| 81 | fp16_opt_level, |
| 82 | grad_accumulation_steps, |
Nothing calls this method directly.
No test coverage was detected for it.