(self, step=None)
| 237 | time.sleep(300) |
| 238 | |
def test(self, step=None):
    """Load the checkpoint at ``self.args.test_from`` and run the trainer's test pass.

    The model is rebuilt from ``self.args.bert_config_path`` (without
    pretrained BERT weights), its weights are restored from the checkpoint,
    and ``trainer.test`` is run over the ``'test'`` split.

    Args:
        step: Step number forwarded to ``trainer.test``. When falsy (``None``
            or ``0``), it is parsed from the checkpoint path: the text after
            the last ``_`` in the second-to-last ``.``-separated component,
            e.g. ``model_step_1000.pt`` -> ``1000``. If the path does not
            match that pattern, ``0`` is used.
    """
    if not step:
        try:
            step = int(self.args.test_from.split('.')[-2].split('_')[-1])
        except (IndexError, ValueError):
            # IndexError: the path contains no '.'.
            # ValueError: the trailing token is not an integer (e.g. 'best.pt').
            # In both cases fall back to step 0 instead of aborting evaluation.
            step = 0

    logger.info('Loading checkpoint from %s', self.args.test_from)
    # This map_location callable loads every tensor onto the CPU, regardless
    # of the device it was saved from.
    checkpoint = torch.load(self.args.test_from, map_location=lambda storage, loc: storage)

    # Copy the checkpoint's saved values for the flags in self.model_flags onto
    # self.args, so the model below is built with the checkpoint's settings
    # rather than the current command-line values.
    for key, value in vars(checkpoint['opt']).items():
        if key in self.model_flags:
            setattr(self.args, key, value)

    config = BertConfig.from_json_file(self.args.bert_config_path)
    # Pretrained BERT weights are not loaded because load_cp restores all
    # weights from the checkpoint immediately afterwards.
    model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=False, bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    trainer = build_trainer(self.args, self.device_id, model, None)
    test_iter = data_loader.DataLoader(self.args, data_loader.load_dataset(self.args, 'test', shuffle=False),
                                       self.args.batch_size, self.device, shuffle=False, is_test=True)
    trainer.test(test_iter, step)
| 262 | |
| 263 | def gen_features_vector(self, step=None): |
| 264 | if not step: |
no test coverage detected