MCPcopy
hub / github.com/XPixelGroup/DiffBIR / main

Function main

train_stage1.py:21–254  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

19
20
21def main(args) -> None:
22 # Setup accelerator:
23 accelerator = Accelerator(split_batches=True)
24 set_seed(231)
25 device = accelerator.device
26 cfg = OmegaConf.load(args.config)
27
28 # Setup an experiment folder:
29 if accelerator.is_local_main_process:
30 exp_dir = cfg.train.exp_dir
31 os.makedirs(exp_dir, exist_ok=True)
32 ckpt_dir = os.path.join(exp_dir, "checkpoints")
33 os.makedirs(ckpt_dir, exist_ok=True)
34 print(f"Experiment directory created at {exp_dir}")
35
36 # Create model:
37 swinir: SwinIR = instantiate_from_config(cfg.model.swinir)
38 if cfg.train.resume:
39 swinir.load_state_dict(
40 torch.load(cfg.train.resume, map_location="cpu"), strict=True
41 )
42 if accelerator.is_local_main_process:
43 print(f"strictly load weight from checkpoint: {cfg.train.resume}")
44 else:
45 if accelerator.is_local_main_process:
46 print("initialize from scratch")
47
48 # Setup optimizer:
49 opt = torch.optim.AdamW(
50 swinir.parameters(), lr=cfg.train.learning_rate, weight_decay=0
51 )
52
53 # Setup data:
54 dataset = instantiate_from_config(cfg.dataset.train)
55 loader = DataLoader(
56 dataset=dataset,
57 batch_size=cfg.train.batch_size,
58 num_workers=cfg.train.num_workers,
59 shuffle=True,
60 drop_last=True,
61 )
62 val_dataset = instantiate_from_config(cfg.dataset.val)
63 val_loader = DataLoader(
64 dataset=val_dataset,
65 batch_size=cfg.train.batch_size,
66 num_workers=cfg.train.num_workers,
67 shuffle=False,
68 drop_last=False,
69 )
70 if accelerator.is_local_main_process:
71 print(f"Dataset contains {len(dataset):,} images from {dataset.file_list}")
72
73 batch_transform = instantiate_from_config(cfg.batch_transform)
74
75 # Prepare models for training:
76 swinir.train().to(device)
77 swinir, opt, loader, val_loader = accelerator.prepare(
78 swinir, opt, loader, val_loader

Callers 1

train_stage1.py (File) · 0.70

Calls 5

instantiate_from_config (Function) · 0.90
to (Function) · 0.90
calculate_psnr_pt (Function) · 0.90
backward (Method) · 0.80
save (Method) · 0.45

Tested by

no test coverage detected