MCPcopy Index your code
hub / github.com/Turing-Project/WriteGPT / Running

Class Running

LanguageNetwork/BERT/predict.py:20–66  ·  view source on GitHub ↗

Run Model

Source from the content-addressed store, hash-verified

18
19
class Running(object):
    """Load a trained summarizer checkpoint and run prediction on the test split."""

    def __init__(self, args, device_id):
        """Configure devices/seeds, load the checkpoint in ``args.test_from`` and build the model.

        :param args: parser.parse_args() namespace; read here: visible_gpus, seed,
            log_file, test_from, bert_config_path (plus whatever the model builder reads).
            The architecture options listed in ``self.model_flags`` are overwritten
            in-place with the values stored in the checkpoint.
        :param device_id: GPU index (e.g. 0), or a negative value such as -1 to skip
            ``torch.cuda.set_device``.
        """
        self.args = args
        self.device_id = device_id
        # Architecture options that must match the trained weights; they are copied
        # from the checkpoint's saved options over the command-line values below.
        self.model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv', 'use_interval',
                            'rnn_size']

        # Configure logging before the first log call so the device messages below
        # go through the handlers init_logger sets up (previously they were logged
        # before initialisation).
        init_logger(args.log_file)

        self.device = "cpu" if self.args.visible_gpus == '-1' else "cuda"
        logger.info('Device ID %d', self.device_id)
        logger.info('Device %s', self.device)
        torch.manual_seed(self.args.seed)
        random.seed(self.args.seed)

        if self.device_id >= 0:
            torch.cuda.set_device(self.device_id)

        # Training step encoded in the checkpoint name (e.g. model_step_1000.pt -> 1000);
        # passed to trainer.predict. Falls back to 0 for names without a step suffix.
        self.step = self._step_from_checkpoint_path(self.args.test_from)

        logger.info('Loading checkpoint from %s', self.args.test_from)
        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
        # On newer PyTorch releases where weights_only defaults to True, loading the
        # pickled 'opt' namespace may require weights_only=False — confirm against the
        # installed torch version.
        # map_location keeps every tensor on the storage it is deserialised to (CPU),
        # independent of the device it was saved from.
        checkpoint = torch.load(self.args.test_from, map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in self.model_flags:
            if k in opt:
                setattr(self.args, k, opt[k])

        config = BertConfig.from_json_file(self.args.bert_config_path)
        # Pretrained BERT weights are not loaded separately: the checkpoint supplies them.
        self.model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=False, bert_config=config)
        self.model.load_cp(checkpoint)
        self.model.eval()

    @staticmethod
    def _step_from_checkpoint_path(path):
        """Return the integer after the last '_' in the file stem of *path*, or 0.

        ``"../models/model_step_1000.pt"`` -> 1000; ``"model.pt"`` or ``""`` -> 0.
        Only the file name is inspected, so dots in directory names or a missing
        extension no longer cause an uncaught ValueError.
        """
        import os

        stem = os.path.splitext(os.path.basename(path))[0]
        try:
            return int(stem.rsplit('_', 1)[-1])
        except ValueError:
            return 0

    def predict(self):
        """Run the trainer's prediction loop over the 'test' split with the loaded model.

        The test data is loaded without shuffling; ``self.step`` is forwarded to
        ``trainer.predict``.
        """
        test_iter = data_loader.DataLoader(self.args, data_loader.load_dataset(self.args, 'test', shuffle=False),
                                           self.args.batch_size, self.device, shuffle=False, is_test=True)
        # No optimizer is needed for inference, hence None.
        trainer = build_trainer(self.args, self.device_id, self.model, None)
        trainer.predict(test_iter, self.step)
67
68
69if __name__ == '__main__':

Callers 1

predict.py · File · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected