hub / github.com/Turing-Project/WriteGPT / validate

Method validate

LanguageNetwork/BERT/train.py:173–193
(self, step)

        trainer.train(self.train_iter, self.args.train_steps)

    def validate(self, step):

        logger.info('Loading checkpoint from %s' % self.args.validate_from)
        checkpoint = torch.load(self.args.validate_from, map_location=lambda storage, loc: storage)

        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if k in self.model_flags:
                setattr(self.args, k, opt[k])
        print(self.args)

        config = BertConfig.from_json_file(self.args.bert_config_path)
        model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=False, bert_config=config)
        model.load_cp(checkpoint)
        model.eval()

        valid_iter = data_loader.DataLoader(self.args, data_loader.load_dataset(self.args, 'valid', shuffle=False),
                                            self.args.batch_size, self.device, shuffle=False, is_test=False)
        trainer = build_trainer(self.args, self.device_id, model, None)
        stats = trainer.validate(valid_iter, step)
        return stats.xent()

    def wait_and_validate(self):
        time_step = 0
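
The flag-restoring loop near the top of `validate` can be tried on its own. The following is a minimal sketch, not the repository's code: `restore_model_flags`, the flag names in `MODEL_FLAGS`, and the option values are all hypothetical stand-ins, since the real `self.model_flags` list is defined outside the lines shown here. It illustrates how options saved with a checkpoint overwrite only the architecture-related arguments, while run-specific ones such as the batch size keep their current values.

```python
from argparse import Namespace

# Hypothetical flag names, chosen for illustration only; the real
# `self.model_flags` list is defined elsewhere in train.py.
MODEL_FLAGS = ["hidden_size", "ff_size", "heads"]


def restore_model_flags(args, checkpoint_opt, model_flags):
    """Copy architecture flags saved in a checkpoint onto the current args.

    Mirrors the loop in `validate`: only keys listed in `model_flags`
    are overwritten, so run-specific options are left untouched.
    """
    for key, value in vars(checkpoint_opt).items():
        if key in model_flags:
            setattr(args, key, value)
    return args


# Options saved at training time (stand-in for checkpoint['opt']).
saved_opt = Namespace(hidden_size=768, ff_size=2048, heads=8, batch_size=3000)
# Options supplied at validation time.
current = Namespace(hidden_size=128, ff_size=512, heads=4, batch_size=16)

restore_model_flags(current, saved_opt, MODEL_FLAGS)
print(current)
```

Restoring these flags before the model is constructed matters because `load_cp` can only load the saved weights into a model whose shape matches the one that was trained.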

Callers 2

wait_and_validate · Method · 0.95
main · Function · 0.45

Calls 5

load_cp · Method · 0.95
build_trainer · Function · 0.90
xent · Method · 0.80
load · Method · 0.45
from_json_file · Method · 0.45

Tested by

no test coverage detected