(self, step)
| 171 | trainer.train(self.train_iter, self.args.train_steps) |
| 172 | |
def validate(self, step):
    """Evaluate the checkpoint at ``self.args.validate_from`` on the 'valid' split.

    Options saved in the checkpoint under ``'opt'`` whose names appear in
    ``self.model_flags`` are copied onto ``self.args`` before the model is
    rebuilt. The reconstructed ``Summarizer`` is therefore configured like the
    one that produced the checkpoint. This mutates ``self.args`` in place.

    Args:
        step: Step number associated with this evaluation; forwarded
            unchanged to ``trainer.validate``.

    Returns:
        The value of ``stats.xent()`` from the validation statistics
        (presumably the validation cross-entropy — confirm against the
        trainer's statistics class).
    """
    checkpoint_path = self.args.validate_from
    logger.info('Loading checkpoint from %s', checkpoint_path)
    # The identity map_location keeps every storage where it was
    # deserialized, i.e. on CPU, regardless of the device it was saved from.
    # NOTE(review): torch.load unpickles the file (the checkpoint carries a
    # non-tensor options object under 'opt'), so only load trusted
    # checkpoints. PyTorch releases that default to weights_only=True may
    # refuse that object — confirm against the torch version in use.
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

    opt = vars(checkpoint['opt'])
    for k in opt:
        if k in self.model_flags:
            setattr(self.args, k, opt[k])
    # Report the effective arguments through the logger instead of stdout.
    logger.info('%s', self.args)

    config = BertConfig.from_json_file(self.args.bert_config_path)
    model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=False, bert_config=config)
    model.load_cp(checkpoint)
    # The weights have been handed to the model; drop this function's
    # reference to the raw checkpoint so it need not stay alive for the
    # whole validation pass (unless load_cp itself keeps a reference).
    del checkpoint
    model.eval()

    valid_iter = data_loader.DataLoader(self.args, data_loader.load_dataset(self.args, 'valid', shuffle=False),
                                        self.args.batch_size, self.device, shuffle=False, is_test=False)
    trainer = build_trainer(self.args, self.device_id, model, None)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
| 194 | |
| 195 | def wait_and_validate(self): |
| 196 | time_step = 0 |
no test coverage detected