:param args: parser.parse_args(); :param device_id: 0 or -1
(self, args, device_id)
| 21 | """Run Model""" |
| 22 | |
def __init__(self, args, device_id):
    """Set up logging and seeds, then build the summarizer from a checkpoint for inference.

    :param args: parser.parse_args() result. Reads ``log_file``, ``visible_gpus``,
        ``seed``, ``test_from`` (checkpoint path) and ``bert_config_path``. The
        architecture flags listed in ``self.model_flags`` are overwritten on
        ``args`` with the values saved in the checkpoint's ``opt``.
    :param device_id: 0 or -1. When >= 0, ``torch.cuda.set_device(device_id)``
        is called; the CPU/CUDA choice stored in ``self.device`` is decided
        separately from ``args.visible_gpus``.

    Sets ``self.step`` to the step number parsed from the checkpoint file name,
    or 0 when the name does not end in ``_<int>.<ext>``.
    """
    self.args = args
    self.device_id = device_id
    # Architecture hyper-parameters that must match the saved weights; they are
    # taken from the checkpoint rather than from the command line.
    self.model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv', 'use_interval',
                        'rnn_size']

    # Initialise logging before the first logger.info call. It used to run after
    # the device messages, so those were emitted before the log handlers/file
    # were set up (assumes init_logger configures them -- confirm).
    init_logger(args.log_file)

    self.device = "cpu" if self.args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d', self.device_id)
    logger.info('Device %s', self.device)
    torch.manual_seed(self.args.seed)
    random.seed(self.args.seed)

    if self.device_id >= 0:
        torch.cuda.set_device(self.device_id)

    # Recover the step number from names of the form "<prefix>_<N>.<ext>".
    # A path with no '.' raises IndexError; a non-numeric suffix such as
    # "model.pt" raises ValueError. Both mean "step unknown", so fall back to 0
    # instead of crashing.
    try:
        self.step = int(self.args.test_from.split('.')[-2].split('_')[-1])
    except (IndexError, ValueError):
        self.step = 0

    logger.info('Loading checkpoint from %s', self.args.test_from)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    # The map_location lambda keeps tensors on the storage they are deserialised to (CPU).
    checkpoint = torch.load(self.args.test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for flag in self.model_flags:
        if flag in opt:
            setattr(self.args, flag, opt[flag])

    config = BertConfig.from_json_file(self.args.bert_config_path)
    self.model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=False, bert_config=config)
    self.model.load_cp(checkpoint)
    # Put the model in evaluation mode for inference.
    self.model.eval()
| 61 | def predict(self): |
| 62 |
nothing calls this directly
no test coverage detected