(args)
| 19 | |
| 20 | |
| 21 | def main(args) -> None: |
| 22 | # Setup accelerator: |
| 23 | accelerator = Accelerator(split_batches=True) |
| 24 | set_seed(231) |
| 25 | device = accelerator.device |
| 26 | cfg = OmegaConf.load(args.config) |
| 27 | |
| 28 | # Setup an experiment folder: |
| 29 | if accelerator.is_local_main_process: |
| 30 | exp_dir = cfg.train.exp_dir |
| 31 | os.makedirs(exp_dir, exist_ok=True) |
| 32 | ckpt_dir = os.path.join(exp_dir, "checkpoints") |
| 33 | os.makedirs(ckpt_dir, exist_ok=True) |
| 34 | print(f"Experiment directory created at {exp_dir}") |
| 35 | |
| 36 | # Create model: |
| 37 | swinir: SwinIR = instantiate_from_config(cfg.model.swinir) |
| 38 | if cfg.train.resume: |
| 39 | swinir.load_state_dict( |
| 40 | torch.load(cfg.train.resume, map_location="cpu"), strict=True |
| 41 | ) |
| 42 | if accelerator.is_local_main_process: |
| 43 | print(f"strictly load weight from checkpoint: {cfg.train.resume}") |
| 44 | else: |
| 45 | if accelerator.is_local_main_process: |
| 46 | print("initialize from scratch") |
| 47 | |
| 48 | # Setup optimizer: |
| 49 | opt = torch.optim.AdamW( |
| 50 | swinir.parameters(), lr=cfg.train.learning_rate, weight_decay=0 |
| 51 | ) |
| 52 | |
| 53 | # Setup data: |
| 54 | dataset = instantiate_from_config(cfg.dataset.train) |
| 55 | loader = DataLoader( |
| 56 | dataset=dataset, |
| 57 | batch_size=cfg.train.batch_size, |
| 58 | num_workers=cfg.train.num_workers, |
| 59 | shuffle=True, |
| 60 | drop_last=True, |
| 61 | ) |
| 62 | val_dataset = instantiate_from_config(cfg.dataset.val) |
| 63 | val_loader = DataLoader( |
| 64 | dataset=val_dataset, |
| 65 | batch_size=cfg.train.batch_size, |
| 66 | num_workers=cfg.train.num_workers, |
| 67 | shuffle=False, |
| 68 | drop_last=False, |
| 69 | ) |
| 70 | if accelerator.is_local_main_process: |
| 71 | print(f"Dataset contains {len(dataset):,} images from {dataset.file_list}") |
| 72 | |
| 73 | batch_transform = instantiate_from_config(cfg.batch_transform) |
| 74 | |
| 75 | # Prepare models for training: |
| 76 | swinir.train().to(device) |
| 77 | swinir, opt, loader, val_loader = accelerator.prepare( |
| 78 | swinir, opt, loader, val_loader |
no test coverage detected