| 90 | return logits |
| 91 | |
class BertMultiTask:
    """Build a BERT model from run arguments and forward common lifecycle calls.

    Depending on ``args.use_pretrain`` the constructor creates exactly one of:

    * ``self.network`` -- a freshly initialised ``BertForPreTraining`` built
      from ``config["bert_model_config"]`` (when ``use_pretrain`` is falsy), or
    * ``self.bert_encoder`` -- a pretrained ``BertModel`` loaded from
      ``config["bert_model_file"]`` (when ``use_pretrain`` is truthy).

    The network-based methods (``save``, ``load``, ``eval``, ``train``, ``to``,
    ``half``) therefore only work on the non-pretrained path, and ``save_bert``
    only works on the pretrained path. Calling the other set raises
    ``AttributeError``.
    """

    def __init__(self, args):
        """Create the model described by ``args``.

        Args:
            args: run arguments. Reads ``config`` and ``use_pretrain``. Also
                reads ``tokenizer`` on the non-pretrained path and
                ``local_rank`` on the pretrained path.
        """
        self.config = args.config

        if not args.use_pretrain:
            bert_config = BertConfig(**self.config["bert_model_config"])
            bert_config.vocab_size = len(args.tokenizer.vocab)

            # Pad the vocabulary size up to the next multiple of 8.
            # Presumably this gives fp16/Tensor Core-friendly embedding
            # shapes -- TODO confirm.
            remainder = bert_config.vocab_size % 8
            if remainder != 0:
                bert_config.vocab_size += 8 - remainder
            print("VOCAB SIZE:", bert_config.vocab_size)

            self.network = BertForPreTraining(bert_config, args)
        else:
            # Use pretrained BERT weights. Each local rank uses its own cache
            # sub-directory.
            self.bert_encoder = BertModel.from_pretrained(
                self.config['bert_model_file'],
                cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
            )

        # Target device for ``move_batch``; assigned later via ``set_device``.
        self.device = None

    def set_device(self, device):
        """Record the device that ``move_batch`` transfers batches to."""
        self.device = device

    def save(self, filename: str):
        """Save the state dict of ``self.network.module`` to ``filename``.

        NOTE(review): this assumes ``self.network`` has been wrapped elsewhere
        in an object exposing ``.module`` (e.g. a distributed wrapper). An
        unwrapped network raises ``AttributeError`` here.
        """
        network = self.network.module
        return torch.save(network.state_dict(), filename)

    def load(self, model_state_dict: str):
        """Load a state dict from the file at ``model_state_dict`` into ``self.network.module``.

        Despite its name, ``model_state_dict`` is a path passed to
        ``torch.load``. The ``map_location`` callable returns each storage
        unchanged, so tensors are deserialised onto the CPU regardless of the
        device they were saved from.

        Returns:
            Whatever ``load_state_dict`` returns.
        """
        return self.network.module.load_state_dict(
            torch.load(model_state_dict, map_location=lambda storage, loc: storage))

    def move_batch(self, batch: TorchTuple, non_blocking=False):
        """Return ``batch`` moved to ``self.device`` via ``batch.to``."""
        # ``non_blocking`` is passed positionally, matching ``TorchTuple.to``'s
        # signature -- presumably ``to(device, non_blocking)``; verify before
        # passing a plain tensor.
        return batch.to(self.device, non_blocking)

    def eval(self):
        """Put ``self.network`` into evaluation mode."""
        self.network.eval()

    def train(self):
        """Put ``self.network`` into training mode."""
        self.network.train()

    def save_bert(self, filename: str):
        """Save the state dict of the pretrained ``self.bert_encoder`` to ``filename``."""
        return torch.save(self.bert_encoder.state_dict(), filename)

    def to(self, device):
        """Move ``self.network`` to ``device``.

        Raises:
            TypeError: if ``device`` is not a ``torch.device``.
        """
        # An explicit check rather than ``assert``, so the validation is not
        # stripped when Python runs with ``-O``.
        if not isinstance(device, torch.device):
            raise TypeError(
                "device must be a torch.device, got {}".format(type(device).__name__))
        self.network.to(device)

    def half(self):
        """Convert ``self.network``'s floating-point parameters to half precision."""
        self.network.half()
no outgoing calls
no test coverage detected