(
self, model, tokenizer, step=None, epoch=None, best=False, last=False
)
| 269 | logger.info(f" Total optimization steps = {total_training_steps}") |
| 270 | |
def _save_model_checkpoint(
    self, model, tokenizer, step=None, epoch=None, best=False, last=False
):
    """Save ``model`` into a checkpoint sub-directory of the training output dir.

    The sub-directory name is derived from the selector arguments. When more
    than one selector is truthy the precedence is
    ``last`` > ``best`` > ``epoch`` > ``step``.

    Args:
        model: Model to save. A ``torch.nn.DataParallel`` wrapper is unwrapped
            to its ``.module`` before saving.
        tokenizer: Saved next to the model via ``save_pretrained`` only when
            the model is a ``transformers.PreTrainedModel``.
        step: If truthy, save under ``checkpoint-step-{step}``.
        epoch: If truthy, save under ``checkpoint-epoch-{epoch}``.
        best: If truthy, save under ``best_model``.
        last: If truthy, save under ``last_model``.

    Raises:
        ValueError: If none of ``step``, ``epoch``, ``best`` or ``last`` is
            truthy, so no checkpoint directory name can be chosen.
    """
    # Written as a single chain so the precedence is explicit; this matches the
    # effective result of the former sequence of independent overriding `if`s.
    if last:
        dir_name = "last_model"
    elif best:
        dir_name = "best_model"
    elif epoch:
        dir_name = f"checkpoint-epoch-{epoch}"
    elif step:
        dir_name = f"checkpoint-step-{step}"
    else:
        # Previously this case surfaced later as an UnboundLocalError.
        raise ValueError(
            "Cannot save checkpoint: one of `step`, `epoch`, `best` or `last` "
            "must be set."
        )

    output_dir = os.path.join(self.training_args.output_dir, dir_name)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    # Save the underlying module, not the DataParallel wrapper.
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    if isinstance(model, (WordCNNForClassification, LSTMForClassification)):
        model.save_pretrained(output_dir)
    elif isinstance(model, transformers.PreTrainedModel):
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
    else:
        # Any other model: move the weights to CPU and write a raw state dict.
        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        torch.save(
            state_dict,
            os.path.join(output_dir, "pytorch_model.bin"),
        )
| 302 | |
| 303 | def _tb_log(self, log, step): |
| 304 | if not hasattr(self, "_tb_writer"): |
no test coverage detected