| 152 | self.args.batch_size, self.device, shuffle=True, is_test=False) |
| 153 | |
def train(self):
    """Build the summarizer and its optimizer, then run the training loop.

    If ``self.args.train_from`` is set, the checkpoint at that path is
    loaded and every entry of its saved ``opt`` namespace whose name is in
    ``self.model_flags`` is copied onto ``self.args``.  This happens
    *before* the model is constructed so that the network is built from
    the checkpoint's settings rather than from the current command line;
    the checkpoint is then loaded into the model and handed to the
    optimizer builder.  Without ``train_from`` the model and optimizer are
    created fresh (the optimizer builder receives ``None``).

    Training runs over ``self.train_iter`` for ``self.args.train_steps``.
    """
    checkpoint = None
    if self.args.train_from:
        logger.info('Loading checkpoint from %s', self.args.train_from)
        # The identity map_location leaves storages where they were
        # deserialized instead of moving them to the device they were
        # saved from.
        # NOTE(review): torch.load unpickles the file -- only resume from
        # checkpoints you trust.
        checkpoint = torch.load(self.args.train_from, map_location=lambda storage, loc: storage)
        # Restore the saved model flags first: they must be in place
        # before the model is built (presumably Summarizer reads them to
        # size its layers -- confirm against model_builder).
        for flag, value in vars(checkpoint['opt']).items():
            if flag in self.model_flags:
                setattr(self.args, flag, value)

    model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=True)
    if checkpoint is not None:
        model.load_cp(checkpoint)
    # checkpoint is None for a fresh run, matching the original else-branch.
    optimizer = model_builder.build_optim(self.args, model, checkpoint)

    logger.info(model)
    trainer = build_trainer(self.args, self.device_id, model, optimizer)
    trainer.train(self.train_iter, self.args.train_steps)
| 172 | |
| 173 | def validate(self, step): |
| 174 | |