(mp_trainer, opt, step)
| 174 | |
| 175 | |
def save_model(mp_trainer, opt, step):
    """Write model and optimizer checkpoints for the given step.

    Nothing is written unless ``dist.get_rank()`` is 0. On that rank two
    files are saved into the directory returned by ``logger.get_dir()``:

    * ``model{step:06d}.pt`` -- the result of
      ``mp_trainer.master_params_to_state_dict(mp_trainer.master_params)``
    * ``opt{step:06d}.pt`` -- ``opt.state_dict()``

    Args:
        mp_trainer: Object exposing ``master_params`` and
            ``master_params_to_state_dict``.
        opt: Object exposing ``state_dict()``.
        step: Integer step, zero-padded to six digits in the file names.
    """
    # Every rank other than 0 skips checkpointing entirely.
    if dist.get_rank() != 0:
        return

    model_state = mp_trainer.master_params_to_state_dict(mp_trainer.master_params)
    model_path = os.path.join(logger.get_dir(), f"model{step:06d}.pt")
    th.save(model_state, model_path)

    opt_state = opt.state_dict()
    opt_path = os.path.join(logger.get_dir(), f"opt{step:06d}.pt")
    th.save(opt_state, opt_path)
| 183 | |
| 184 | |
| 185 | def compute_top_k(logits, labels, k, reduction="mean"): |
no test coverage detected