| 90 | self.save_model(args, state, kwargs) |
| 91 | |
def get_last_checkpoint(checkpoint_dir):
    """Return the path of the newest step checkpoint inside *checkpoint_dir*.

    A checkpoint is a sub-directory named ``<PREFIX_CHECKPOINT_DIR>-<step>``
    where ``<step>`` consists only of decimal digits. The sub-directory with
    the highest step greater than zero is selected.

    Args:
        checkpoint_dir: Directory that is searched for checkpoint
            sub-directories.

    Returns:
        The path of the selected checkpoint directory, or ``None`` when
        *checkpoint_dir* is not a directory (first training), when it
        contains an entry named ``completed`` (already finished), or when it
        holds no checkpoint with a step greater than zero.
    """
    if not os.path.isdir(checkpoint_dir):
        return None  # first training

    if os.path.exists(os.path.join(checkpoint_dir, 'completed')):
        return None  # already finished

    prefix = PREFIX_CHECKPOINT_DIR + '-'
    latest_step = 0
    latest_name = None
    for entry in os.listdir(checkpoint_dir):
        if not entry.startswith(prefix):
            continue
        step_text = entry[len(prefix):]
        # Skip look-alike entries such as "<prefix>-best" instead of letting
        # int() raise ValueError and abort the whole lookup.
        if not step_text.isdecimal():
            continue
        if not os.path.isdir(os.path.join(checkpoint_dir, entry)):
            continue
        step = int(step_text)
        # Strict ">" keeps the original rule that step 0 never counts as a
        # resumable checkpoint.
        if step > latest_step:
            latest_step = step
            latest_name = entry

    if latest_name is None:
        return None

    # Join the real entry name rather than rebuilding it from the integer so
    # that the returned path always exists (e.g. zero-padded step numbers).
    latest_ckpt_dir = os.path.join(checkpoint_dir, latest_name)
    logger.info("Found a previous checkpoint at: %s", latest_ckpt_dir)
    return latest_ckpt_dir
| 105 | |
| 106 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): |
| 107 | """Collects the state dict and dump to disk.""" |