MCPcopy
hub / github.com/XPixelGroup/DiffBIR / main

Function main

train_stage2.py:20–241  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

18
19
20def main(args) -> None:
21 # Setup accelerator:
22 accelerator = Accelerator(split_batches=True)
23 set_seed(231, device_specific=True)
24 device = accelerator.device
25 cfg = OmegaConf.load(args.config)
26
27 # Setup an experiment folder:
28 if accelerator.is_main_process:
29 exp_dir = cfg.train.exp_dir
30 os.makedirs(exp_dir, exist_ok=True)
31 ckpt_dir = os.path.join(exp_dir, "checkpoints")
32 os.makedirs(ckpt_dir, exist_ok=True)
33 print(f"Experiment directory created at {exp_dir}")
34
35 # Create model:
36 cldm: ControlLDM = instantiate_from_config(cfg.model.cldm)
37 sd = torch.load(cfg.train.sd_path, map_location="cpu")["state_dict"]
38 unused, missing = cldm.load_pretrained_sd(sd)
39 if accelerator.is_main_process:
40 print(
41 f"strictly load pretrained SD weight from {cfg.train.sd_path}\n"
42 f"unused weights: {unused}\n"
43 f"missing weights: {missing}"
44 )
45
46 if cfg.train.resume:
47 cldm.load_controlnet_from_ckpt(torch.load(cfg.train.resume, map_location="cpu"))
48 if accelerator.is_main_process:
49 print(
50 f"strictly load controlnet weight from checkpoint: {cfg.train.resume}"
51 )
52 else:
53 init_with_new_zero, init_with_scratch = cldm.load_controlnet_from_unet()
54 if accelerator.is_main_process:
55 print(
56 f"strictly load controlnet weight from pretrained SD\n"
57 f"weights initialized with newly added zeros: {init_with_new_zero}\n"
58 f"weights initialized from scratch: {init_with_scratch}"
59 )
60
61 swinir: SwinIR = instantiate_from_config(cfg.model.swinir)
62 sd = torch.load(cfg.train.swinir_path, map_location="cpu")
63 if "state_dict" in sd:
64 sd = sd["state_dict"]
65 sd = {
66 (k[len("module.") :] if k.startswith("module.") else k): v
67 for k, v in sd.items()
68 }
69 swinir.load_state_dict(sd, strict=True)
70 for p in swinir.parameters():
71 p.requires_grad = False
72 if accelerator.is_main_process:
73 print(f"load SwinIR from {cfg.train.swinir_path}")
74
75 diffusion: Diffusion = instantiate_from_config(cfg.model.diffusion)
76
77 # Setup optimizer:

Callers 1

train_stage2.py (File) · 0.70

Calls 15

sample (Method) · 0.95
instantiate_from_config (Function) · 0.90
SpacedSampler (Class) · 0.90
to (Function) · 0.90
log_txt_as_img (Function) · 0.90
load_pretrained_sd (Method) · 0.80
vae_encode (Method) · 0.80
prepare_condition (Method) · 0.80
q_sample (Method) · 0.80
p_losses (Method) · 0.80

Tested by

no test coverage detected