MCPcopy
hub / github.com/microsoft/Swin-Transformer / main

Function main

main.py:90–171  ·  view source on GitHub ↗
(config)

Source from the content-addressed store, hash-verified

88
89
90def main(config):
91 dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
92
93 logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
94 model = build_model(config)
95 logger.info(str(model))
96
97 n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
98 logger.info(f"number of params: {n_parameters}")
99 if hasattr(model, 'flops'):
100 flops = model.flops()
101 logger.info(f"number of GFLOPs: {flops / 1e9}")
102
103 model.cuda()
104 model_without_ddp = model
105
106 optimizer = build_optimizer(config, model)
107 model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
108 loss_scaler = NativeScalerWithGradNormCount()
109
110 if config.TRAIN.ACCUMULATION_STEPS > 1:
111 lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)
112 else:
113 lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
114
115 if config.AUG.MIXUP > 0.:
116 # smoothing is handled with mixup label transform
117 criterion = SoftTargetCrossEntropy()
118 elif config.MODEL.LABEL_SMOOTHING > 0.:
119 criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
120 else:
121 criterion = torch.nn.CrossEntropyLoss()
122
123 max_accuracy = 0.0
124
125 if config.TRAIN.AUTO_RESUME:
126 resume_file = auto_resume_helper(config.OUTPUT)
127 if resume_file:
128 if config.MODEL.RESUME:
129 logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
130 config.defrost()
131 config.MODEL.RESUME = resume_file
132 config.freeze()
133 logger.info(f'auto resuming from {resume_file}')
134 else:
135 logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
136
137 if config.MODEL.RESUME:
138 max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger)
139 acc1, acc5, loss = validate(config, data_loader_val, model)
140 logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
141 if config.EVAL_MODE:
142 return
143
144 if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
145 load_pretrained(config, model_without_ddp, logger)
146 acc1, acc5, loss = validate(config, data_loader_val, model)
147 logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")

Callers 1

main.py (File) · 0.70

Calls 14

build_loader (Function) · 0.90
build_model (Function) · 0.90
build_optimizer (Function) · 0.90
build_scheduler (Function) · 0.90
auto_resume_helper (Function) · 0.90
load_checkpoint (Function) · 0.90
load_pretrained (Function) · 0.90
save_checkpoint (Function) · 0.90
set_epoch (Method) · 0.80
validate (Function) · 0.70
throughput (Function) · 0.70

Tested by

no test coverage detected