MCPcopy Index your code
hub / github.com/deepspeedai/DeepSpeedExamples / BertMultiTask

Class BertMultiTask

bing_bert/turing/models.py:92–141  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

90 return logits
91
class BertMultiTask:
    """Wrapper around a BERT network that handles device moves and checkpoint I/O.

    Depending on ``args.use_pretrain`` the constructor either builds a fresh
    ``BertForPreTraining`` network (stored as ``self.network``) or loads a
    pretrained ``BertModel`` encoder (stored as ``self.bert_encoder``).
    """

    def __init__(self, args):
        """Build or load the underlying BERT model.

        Args:
            args: run configuration. Reads ``args.config`` (a dict) and
                ``args.use_pretrain``. The fresh-model branch also reads
                ``args.tokenizer`` and passes ``args`` to ``BertForPreTraining``.
                The pretrained branch reads ``args.local_rank``.
        """
        self.config = args.config

        if not args.use_pretrain:
            bert_config = BertConfig(**self.config["bert_model_config"])
            bert_config.vocab_size = len(args.tokenizer.vocab)

            # Padding for divisibility by 8: round vocab_size up to the next multiple of 8.
            if bert_config.vocab_size % 8 != 0:
                bert_config.vocab_size += 8 - (bert_config.vocab_size % 8)
            print("VOCAB SIZE:", bert_config.vocab_size)

            self.network = BertForPreTraining(bert_config, args)
        else:
            # Use pretrained bert weights, cached in a per-local-rank directory.
            # NOTE(review): this branch never sets ``self.network``. eval/train/to/
            # half/save/load will raise AttributeError unless a caller assigns it.
            self.bert_encoder = BertModel.from_pretrained(
                self.config['bert_model_file'],
                cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))

        # Target device for move_batch(); assigned via set_device().
        self.device = None

    def set_device(self, device):
        """Record ``device`` as the destination used by move_batch()."""
        self.device = device

    def _unwrapped_network(self):
        """Return ``self.network.module`` if the network is wrapped, else ``self.network``."""
        # Wrappers such as DataParallel expose the real model as ``.module``.
        # A plain module has no such attribute, so fall back to it directly.
        return getattr(self.network, "module", self.network)

    def save(self, filename: str):
        """Save the (unwrapped) network's state_dict to ``filename`` with torch.save."""
        return torch.save(self._unwrapped_network().state_dict(), filename)

    def load(self, model_state_dict: str):
        """Load a state_dict from the checkpoint file path ``model_state_dict``.

        The identity ``map_location`` keeps every storage where torch.load
        deserializes it (CPU), regardless of the device it was saved from.

        Returns:
            Whatever ``load_state_dict`` returns for the (unwrapped) network.
        """
        return self._unwrapped_network().load_state_dict(
            torch.load(model_state_dict, map_location=lambda storage, loc: storage))

    def move_batch(self, batch: "TorchTuple", non_blocking=False):
        """Move ``batch`` to ``self.device`` via its ``.to`` method and return the result."""
        # Pass non_blocking by keyword: torch.Tensor.to's second positional
        # parameter is ``dtype``, so a positional bool would be misinterpreted.
        return batch.to(self.device, non_blocking=non_blocking)

    def eval(self):
        """Put the network in evaluation mode."""
        self.network.eval()

    def train(self):
        """Put the network in training mode."""
        self.network.train()

    def save_bert(self, filename: str):
        """Save the pretrained encoder's state_dict (only set when use_pretrain is true)."""
        return torch.save(self.bert_encoder.state_dict(), filename)

    def to(self, device):
        """Move the network to ``device``, which must be a ``torch.device``.

        Does not update ``self.device``; call set_device() for that.
        """
        assert isinstance(device, torch.device)
        self.network.to(device)

    def half(self):
        """Convert the network's floating-point parameters to half precision."""
        self.network.half()

Callers 1

prepare_model_optimizer (Function) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected