(self, step=None)
| 261 | trainer.test(test_iter, step) |
| 262 | |
def gen_features_vector(self, step=None):
    """Restore a trained checkpoint and run feature-vector generation on the test split.

    Loads the checkpoint at ``self.args.test_from``. It copies every option stored
    in the checkpoint whose name is in ``self.model_flags`` back onto ``self.args``,
    so the model is rebuilt with the architecture it was trained with. It then
    rebuilds the ``Summarizer`` from ``self.args.bert_config_path``, loads the
    weights and switches the model to eval mode. Finally it hands a non-shuffled
    'test' DataLoader to ``trainer.gen_features_vector``.

    Args:
        step: Step number forwarded to ``trainer.gen_features_vector``. When falsy
            (``None`` or ``0``), it is parsed from the checkpoint filename as the
            integer after the last ``_`` before the extension, e.g.
            ``model_step_1000.pt`` -> ``1000``. If the filename does not follow
            that pattern, 0 is used.
    """
    if not step:
        try:
            step = int(self.args.test_from.split('.')[-2].split('_')[-1])
        except (IndexError, ValueError):
            # IndexError: the path has no '.' extension.
            # ValueError: the stem does not end in '_<int>' (e.g. 'bert.pt').
            # In both cases fall back to step 0 instead of aborting the run.
            step = 0

    logger.info('Loading checkpoint from %s', self.args.test_from)
    # map_location keeps tensors on the storage they were deserialized to
    # (CPU), independent of the device they were saved from.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(self.args.test_from, map_location=lambda storage, loc: storage)

    # Override model-structure flags with the values the checkpoint was trained with.
    opt = vars(checkpoint['opt'])
    for key, value in opt.items():
        if key in self.model_flags:
            setattr(self.args, key, value)

    config = BertConfig.from_json_file(self.args.bert_config_path)
    model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=False, bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    # The 4th argument is passed as None; presumably no optimizer is needed for
    # inference. TODO confirm against build_trainer's signature.
    trainer = build_trainer(self.args, self.device_id, model, None)
    test_iter = data_loader.DataLoader(self.args, data_loader.load_dataset(self.args, 'test', shuffle=False),
                                       self.args.batch_size, self.device, shuffle=False, is_test=True)
    trainer.gen_features_vector(test_iter, step)
| 286 | |
| 287 | |
| 288 | if __name__ == '__main__': |
no test coverage detected