MCPcopy
hub / github.com/microsoft/Swin-Transformer / main

Function main

main_simmim_pt.py:70–117  ·  view source on GitHub ↗
(config)

Source from the content-addressed store, hash-verified

68
69
def _apply_auto_resume(config):
    """Point ``config.MODEL.RESUME`` at the checkpoint found by ``auto_resume_helper``.

    Searches ``config.OUTPUT``. When a checkpoint is found it replaces any resume
    path already set (a warning is logged if one gets overwritten); the assignment
    is bracketed by ``config.defrost()`` / ``config.freeze()``. When nothing is
    found the config is left untouched and an info message is logged.
    """
    candidate = auto_resume_helper(config.OUTPUT, logger)
    if not candidate:
        logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
        return

    previous = config.MODEL.RESUME
    if previous:
        logger.warning(f"auto-resume changing resume file from {previous} to {candidate}")
    config.defrost()
    config.MODEL.RESUME = candidate
    config.freeze()
    logger.info(f'auto resuming from {candidate}')


def main(config):
    """Run SimMIM pre-training from start to finish.

    Builds the training data loader, model, optimizer, LR scheduler and AMP grad
    scaler, optionally resumes from a checkpoint, then trains for epochs
    ``config.TRAIN.START_EPOCH`` .. ``config.TRAIN.EPOCHS - 1``. Rank 0 saves a
    checkpoint every ``config.SAVE_FREQ`` epochs and after the final epoch. The
    total wall-clock training time is logged at the end.

    Args:
        config: experiment configuration node (supports ``defrost()``/``freeze()``);
            ``config.LOCAL_RANK`` is the device id handed to DistributedDataParallel.
    """
    train_loader = build_loader(config, simmim=True, is_pretrain=True)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    net = build_model(config, is_pretrain=True)
    net.cuda()
    logger.info(str(net))

    # The optimizer is built from the bare module, before DDP wrapping.
    # NOTE(review): presumably so parameter names seen by build_optimizer carry no
    # "module." prefix — confirm before reordering.
    optimizer = build_optimizer(config, net, simmim=True, is_pretrain=True)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        net, device_ids=[config.LOCAL_RANK], broadcast_buffers=False
    )
    bare_model = ddp_model.module

    trainable = sum(param.numel() for param in ddp_model.parameters() if param.requires_grad)
    logger.info(f"number of params: {trainable}")
    if hasattr(bare_model, 'flops'):
        logger.info(f"number of GFLOPs: {bare_model.flops() / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(train_loader))
    scaler = amp.GradScaler()

    if config.TRAIN.AUTO_RESUME:
        _apply_auto_resume(config)
    if config.MODEL.RESUME:
        # NOTE(review): load_checkpoint presumably advances config.TRAIN.START_EPOCH,
        # which the training loop below reads — confirm.
        load_checkpoint(config, bare_model, optimizer, lr_scheduler, scaler, logger)

    logger.info("Start training")
    t0 = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        # Assumes a DistributedSampler-like sampler whose shuffling depends on the epoch.
        train_loader.sampler.set_epoch(epoch)
        train_one_epoch(config, ddp_model, train_loader, optimizer, epoch, lr_scheduler, scaler)

        # Only rank 0 writes checkpoints; the rank test runs before the modulo,
        # so the SAVE_FREQ arithmetic is evaluated on rank 0 only.
        if dist.get_rank() != 0:
            continue
        if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.EPOCHS - 1:
            # 0. fills save_checkpoint's fourth positional slot
            # (NOTE(review): presumably an accuracy value unused in pre-training — confirm).
            save_checkpoint(config, epoch, bare_model, 0., optimizer, lr_scheduler, scaler, logger)

    elapsed = datetime.timedelta(seconds=int(time.time() - t0))
    logger.info('Training time {}'.format(str(elapsed)))
118
119
120def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, scaler):

Callers 1

main_simmim_pt.pyFile · 0.70

Calls 10

build_loaderFunction · 0.90
build_modelFunction · 0.90
build_optimizerFunction · 0.90
build_schedulerFunction · 0.90
auto_resume_helperFunction · 0.90
load_checkpointFunction · 0.90
save_checkpointFunction · 0.90
set_epochMethod · 0.80
train_one_epochFunction · 0.70
flopsMethod · 0.45

Tested by

no test coverage detected