Repository: github.com/Turing-Project/WriteGPT

Method train

LanguageNetwork/BERT/train.py, lines 154–171
Signature: train(self)


    def train(self):
        # Build the Summarizer; load_pretrained_bert=True requests pretrained BERT weights.
        model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=True)

        if self.args.train_from:
            # Resume from a checkpoint. The map_location callable returns each
            # storage unchanged, so tensors stay on the CPU instead of being
            # moved back to the device they were saved from.
            logger.info('Loading checkpoint from %s' % self.args.train_from)
            checkpoint = torch.load(self.args.train_from, map_location=lambda storage, loc: storage)
            # Copy saved options into self.args, but only keys listed in
            # self.model_flags. This runs after `model` was built above, so it
            # does not change the architecture of that already-constructed model.
            opt = vars(checkpoint['opt'])
            for k in opt.keys():
                if k in self.model_flags:
                    setattr(self.args, k, opt[k])
            model.load_cp(checkpoint)  # restore the model's saved state
            optimizer = model_builder.build_optim(self.args, model, checkpoint)
        else:
            # No checkpoint: build the optimizer without any restored state.
            optimizer = model_builder.build_optim(self.args, model, None)

        logger.info(model)
        trainer = build_trainer(self.args, self.device_id, model, optimizer)
        trainer.train(self.train_iter, self.args.train_steps)
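When resuming, the checkpoint branch copies only the options named in `self.model_flags` from the saved `opt` (presumably an `argparse.Namespace`, given the `vars()` call) into `self.args`. The following is a minimal pure-Python sketch of that merge, under stated assumptions: `merge_model_flags`, `MODEL_FLAGS`, and all option names and values below are hypothetical and are not taken from train.py.

```python
from argparse import Namespace

# Hypothetical whitelist; in train.py the real list is self.model_flags.
MODEL_FLAGS = {"hidden_size", "heads", "encoder"}


def merge_model_flags(args, saved_opt, model_flags=MODEL_FLAGS):
    """Copy whitelisted options from a saved Namespace into args (in place).

    Keys outside model_flags (e.g. learning rate, step budget) keep the
    values from the current run, mirroring the loop in train().
    """
    saved = vars(saved_opt)  # Namespace -> dict, as train() does
    for key in saved:
        if key in model_flags:
            setattr(args, key, saved[key])
    return args


# Illustrative values only.
current = Namespace(hidden_size=256, heads=4, encoder="classifier", lr=2e-3, train_steps=1000)
saved = Namespace(hidden_size=768, heads=8, encoder="transformer", lr=1e-4, train_steps=50000)

merged = merge_model_flags(current, saved)
print(merged.hidden_size, merged.encoder, merged.lr)  # 768 transformer 0.002
```

In the listing, `model` is constructed before this merge runs, so the merged values do not reshape that model. A resume flow that needs the rebuilt architecture to match the checkpoint would typically merge first and construct the model afterwards; that ordering is a suggestion, not what train.py does.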

Callers (6)

multi_card_train (method) · 0.95
main (function) · 0.45
train (function) · 0.45
train.py (file) · 0.45
main (function) · 0.45
main (function) · 0.45

Calls (3)

load_cp (method) · 0.95
build_trainer (function) · 0.90
load (method) · 0.45

Tested by: no test coverage detected