MCPcopy
hub / github.com/microsoft/Swin-Transformer / main

Function main

main_simmim_ft.py:76–152  ·  view source on GitHub ↗
(config)

Source from the content-addressed store, hash-verified

74
75
76def main(config):
77 dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True,
78 is_pretrain=False)
79
80 logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
81 model = build_model(config, is_pretrain=False)
82 model.cuda()
83 logger.info(str(model))
84
85 optimizer = build_optimizer(config, model, simmim=True, is_pretrain=False)
86 model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
87 model_without_ddp = model.module
88
89 n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
90 logger.info(f"number of params: {n_parameters}")
91 if hasattr(model_without_ddp, 'flops'):
92 flops = model_without_ddp.flops()
93 logger.info(f"number of GFLOPs: {flops / 1e9}")
94
95 lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
96 scaler = amp.GradScaler()
97
98 if config.AUG.MIXUP > 0.:
99 # smoothing is handled with mixup label transform
100 criterion = SoftTargetCrossEntropy()
101 elif config.MODEL.LABEL_SMOOTHING > 0.:
102 criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
103 else:
104 criterion = torch.nn.CrossEntropyLoss()
105
106 max_accuracy = 0.0
107
108 if config.TRAIN.AUTO_RESUME:
109 resume_file = auto_resume_helper(config.OUTPUT, logger)
110 if resume_file:
111 if config.MODEL.RESUME:
112 logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
113 config.defrost()
114 config.MODEL.RESUME = resume_file
115 config.freeze()
116 logger.info(f'auto resuming from {resume_file}')
117 else:
118 logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
119
120 if config.MODEL.RESUME:
121 max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger)
122 acc1, acc5, loss = validate(config, data_loader_val, model)
123 logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
124 if config.EVAL_MODE:
125 return
126
127 if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
128 load_pretrained(config, model_without_ddp, logger)
129 acc1, acc5, loss = validate(config, data_loader_val, model)
130 logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
131
132 if config.THROUGHPUT_MODE:
133 throughput(data_loader_val, model, logger)

Callers 1

main_simmim_ft.py · File · 0.70

Calls 13

build_loader · Function · 0.90
build_model · Function · 0.90
build_optimizer · Function · 0.90
build_scheduler · Function · 0.90
auto_resume_helper · Function · 0.90
load_checkpoint · Function · 0.90
load_pretrained · Function · 0.90
save_checkpoint · Function · 0.90
set_epoch · Method · 0.80
validate · Function · 0.70
throughput · Function · 0.70
train_one_epoch · Function · 0.70

Tested by

No test coverage detected.