hub / github.com/Turing-Project/WriteGPT / gen_features_vector

Method gen_features_vector

LanguageNetwork/BERT/train.py:263–285
(self, step=None)

Source from the content-addressed store, hash-verified

def gen_features_vector(self, step=None):
    if not step:
        try:
            step = int(self.args.test_from.split('.')[-2].split('_')[-1])
        except IndexError:
            step = 0

    logger.info('Loading checkpoint from %s' % self.args.test_from)
    checkpoint = torch.load(self.args.test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if k in self.model_flags:
            setattr(self.args, k, opt[k])

    config = BertConfig.from_json_file(self.args.bert_config_path)
    model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=False, bert_config=config)
    model.load_cp(checkpoint)
    model.eval()
    # logger.info(model)
    trainer = build_trainer(self.args, self.device_id, model, None)
    test_iter = data_loader.DataLoader(self.args, data_loader.load_dataset(self.args, 'test', shuffle=False),
                                       self.args.batch_size, self.device, shuffle=False, is_test=True)
    trainer.gen_features_vector(test_iter, step)
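
When step is not passed, the method recovers it from the checkpoint filename in self.args.test_from. The sketch below isolates that parsing so its behavior is easier to see. The helper name step_from_checkpoint_path and the example filename model_step_50000.pt are assumptions for illustration and do not appear in the repository. Unlike the original, the sketch also catches ValueError: the method above only catches IndexError, so a name such as best.pt would raise there.

```python
def step_from_checkpoint_path(path: str) -> int:
    """Hypothetical helper mirroring the filename parsing in gen_features_vector.

    Takes the text between the last two dots, then the token after the last
    underscore, and converts it to an int. Falls back to 0 when the name has
    no dot (IndexError) or the token is not numeric (ValueError, which the
    original method does not catch).
    """
    try:
        return int(path.split('.')[-2].split('_')[-1])
    except (IndexError, ValueError):
        return 0


if __name__ == '__main__':
    # Assumed naming scheme: <prefix>_<step>.pt
    print(step_from_checkpoint_path('models/model_step_50000.pt'))
    # No dot in the name, so the [-2] index fails and the fallback is used.
    print(step_from_checkpoint_path('checkpoint'))
```

Because the parsing keys on the last two dots, a relative path such as ../models/model_step_100.pt still resolves to 100; only the final stem and extension matter.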

Callers 1

train.py · File · 0.80

Calls 4

load_cp · Method · 0.95
build_trainer · Function · 0.90
load · Method · 0.45
from_json_file · Method · 0.45

Tested by

no test coverage detected