(args)
| 18 | |
| 19 | |
| 20 | def main(args) -> None: |
| 21 | # Setup accelerator: |
| 22 | accelerator = Accelerator(split_batches=True) |
| 23 | set_seed(231, device_specific=True) |
| 24 | device = accelerator.device |
| 25 | cfg = OmegaConf.load(args.config) |
| 26 | |
| 27 | # Setup an experiment folder: |
| 28 | if accelerator.is_main_process: |
| 29 | exp_dir = cfg.train.exp_dir |
| 30 | os.makedirs(exp_dir, exist_ok=True) |
| 31 | ckpt_dir = os.path.join(exp_dir, "checkpoints") |
| 32 | os.makedirs(ckpt_dir, exist_ok=True) |
| 33 | print(f"Experiment directory created at {exp_dir}") |
| 34 | |
| 35 | # Create model: |
| 36 | cldm: ControlLDM = instantiate_from_config(cfg.model.cldm) |
| 37 | sd = torch.load(cfg.train.sd_path, map_location="cpu")["state_dict"] |
| 38 | unused, missing = cldm.load_pretrained_sd(sd) |
| 39 | if accelerator.is_main_process: |
| 40 | print( |
| 41 | f"strictly load pretrained SD weight from {cfg.train.sd_path}\n" |
| 42 | f"unused weights: {unused}\n" |
| 43 | f"missing weights: {missing}" |
| 44 | ) |
| 45 | |
| 46 | if cfg.train.resume: |
| 47 | cldm.load_controlnet_from_ckpt(torch.load(cfg.train.resume, map_location="cpu")) |
| 48 | if accelerator.is_main_process: |
| 49 | print( |
| 50 | f"strictly load controlnet weight from checkpoint: {cfg.train.resume}" |
| 51 | ) |
| 52 | else: |
| 53 | init_with_new_zero, init_with_scratch = cldm.load_controlnet_from_unet() |
| 54 | if accelerator.is_main_process: |
| 55 | print( |
| 56 | f"strictly load controlnet weight from pretrained SD\n" |
| 57 | f"weights initialized with newly added zeros: {init_with_new_zero}\n" |
| 58 | f"weights initialized from scratch: {init_with_scratch}" |
| 59 | ) |
| 60 | |
| 61 | swinir: SwinIR = instantiate_from_config(cfg.model.swinir) |
| 62 | sd = torch.load(cfg.train.swinir_path, map_location="cpu") |
| 63 | if "state_dict" in sd: |
| 64 | sd = sd["state_dict"] |
| 65 | sd = { |
| 66 | (k[len("module.") :] if k.startswith("module.") else k): v |
| 67 | for k, v in sd.items() |
| 68 | } |
| 69 | swinir.load_state_dict(sd, strict=True) |
| 70 | for p in swinir.parameters(): |
| 71 | p.requires_grad = False |
| 72 | if accelerator.is_main_process: |
| 73 | print(f"load SwinIR from {cfg.train.swinir_path}") |
| 74 | |
| 75 | diffusion: Diffusion = instantiate_from_config(cfg.model.diffusion) |
| 76 | |
| 77 | # Setup optimizer: |
no test coverage detected