(self)
| 400 | return self.step + self.resume_step |
| 401 | |
def save(self):
    """Write the model checkpoint and the optimizer state to ``self.save_dir``.

    Two files are written:

    * the model checkpoint, named by ``self.ckpt_file_name()``;
    * the optimizer state, named ``opt{total_step:09d}.pt`` where
      ``total_step`` is ``self.total_step()``.

    CLIP weights (state-dict keys starting with ``'clip_model.'``) are never
    saved.

    Model file contents:

    * ``self.use_fp16`` set: the state dict comes from
      ``self.model.state_dict()``;
    * otherwise: it is rebuilt from ``self.mp_trainer.master_params``;
    * ``self.args.use_ema`` set: the file holds
      ``{'model': ..., 'model_avg': ...}`` instead of a bare state dict.

    Optimizer file contents: when ``self.use_fp16`` is set, the file holds
    ``{'opt': ..., 'scaler': ...}`` instead of the bare optimizer state dict.
    """

    def del_clip(state_dict):
        # Drop CLIP weights. Keys are deleted in place (collected first so the
        # mapping is not mutated while being iterated) rather than copied into
        # a new dict, so the original mapping object -- and anything attached
        # to it, e.g. torch's ``_metadata`` on Module.state_dict() results --
        # is what gets saved.
        clip_weights = [k for k in state_dict if k.startswith('clip_model.')]
        for k in clip_weights:
            del state_dict[k]

    def save_checkpoint():
        if self.use_fp16:
            state_dict = self.model.state_dict()
        else:
            state_dict = self.mp_trainer.master_params_to_state_dict(
                self.mp_trainer.master_params)
        del_clip(state_dict)

        if self.args.use_ema:
            # Save both the model and the averaged model (self.model_avg).
            state_dict_avg = self.model_avg.state_dict()
            del_clip(state_dict_avg)
            state_dict = {'model': state_dict, 'model_avg': state_dict_avg}

        logger.log("saving model...")
        filename = self.ckpt_file_name()
        with bf.BlobFile(bf.join(self.save_dir, filename), "wb") as f:
            torch.save(state_dict, f)

    save_checkpoint()

    # Gather the optimizer state *before* opening the output file, so that a
    # failure while collecting it does not leave an empty/truncated file.
    opt_state = self.opt.state_dict()
    if self.use_fp16:
        # With fp16 the scaler's state (self.scaler) is saved alongside the
        # optimizer state.
        opt_state = {
            'opt': opt_state,
            'scaler': self.scaler.state_dict(),
        }

    opt_filename = f"opt{(self.total_step()):09d}.pt"
    with bf.BlobFile(bf.join(self.save_dir, opt_filename), "wb") as f:
        torch.save(opt_state, f)
| 445 | |
| 446 | |
| 447 | def parse_resume_step_from_filename(filename): |
no test coverage detected