Run Model
| 18 | |
| 19 | |
class Running(object):
    """Load a trained Summarizer checkpoint and run prediction on the test split."""

    def __init__(self, args, device_id):
        """
        Rebuild the Summarizer from the checkpoint at ``args.test_from`` in eval mode.

        :param args: parser.parse_args() namespace. Reads ``visible_gpus``, ``seed``,
            ``log_file``, ``test_from`` and ``bert_config_path`` here (``batch_size`` in
            ``predict``). Attributes named in ``self.model_flags`` are overwritten in
            place with the values stored in the checkpoint's ``opt``.
        :param device_id: 0 or -1 (a CUDA device index, or -1 for no device selection)
        """
        self.args = args
        self.device_id = device_id
        self.model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv', 'use_interval',
                            'rnn_size']

        # Configure logging before the first logger.info call; previously the two
        # device messages below were emitted before init_logger had run.
        init_logger(self.args.log_file)

        self.device = "cpu" if self.args.visible_gpus == '-1' else "cuda"
        logger.info('Device ID %d' % self.device_id)
        logger.info('Device %s' % self.device)
        torch.manual_seed(self.args.seed)
        random.seed(self.args.seed)

        if self.device_id >= 0:
            torch.cuda.set_device(self.device_id)

        self.step = self._step_from_checkpoint_path(self.args.test_from)

        logger.info('Loading checkpoint from %s' % self.args.test_from)
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        # The identity map_location keeps all tensors on CPU storage on load.
        checkpoint = torch.load(self.args.test_from, map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        # Overwrite the architecture flags in args with the values saved in the
        # checkpoint, so the Summarizer built below is constructed with them.
        for k in self.model_flags:
            if k in opt:
                setattr(self.args, k, opt[k])

        config = BertConfig.from_json_file(self.args.bert_config_path)
        self.model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=False, bert_config=config)
        self.model.load_cp(checkpoint)
        self.model.eval()

    @staticmethod
    def _step_from_checkpoint_path(path):
        """Return the step number encoded in a checkpoint filename, or 0.

        ``'model_step_1000.pt'`` -> 1000. A path with no ``.`` suffix, or whose last
        ``_``-separated token before the suffix is not an integer, yields 0 instead
        of raising.
        """
        try:
            return int(path.split('.')[-2].split('_')[-1])
        except (IndexError, ValueError):
            # IndexError: no '.' in the path; ValueError: token is not an integer.
            return 0

    def predict(self):
        """Run ``trainer.predict`` over the unshuffled 'test' dataset split.

        ``self.step`` (parsed from the checkpoint filename) is passed through to
        ``trainer.predict``.
        """
        test_iter = data_loader.DataLoader(self.args, data_loader.load_dataset(self.args, 'test', shuffle=False),
                                           self.args.batch_size, self.device, shuffle=False, is_test=True)
        trainer = build_trainer(self.args, self.device_id, self.model, None)
        trainer.predict(test_iter, self.step)
| 67 | |
| 68 | |
| 69 | if __name__ == '__main__': |