(config)
| 68 | |
| 69 | |
def main(config):
    """Build the pre-training pipeline from ``config`` and run the epoch loop.

    Steps, in order:

    1. Build the training data loader, the model (moved to CUDA), the
       optimizer, the LR scheduler and an AMP ``GradScaler``.
    2. Wrap the model in ``DistributedDataParallel`` on ``config.LOCAL_RANK``.
    3. Optionally redirect ``config.MODEL.RESUME`` to a checkpoint found by
       ``auto_resume_helper`` and, if a resume path is set, call
       ``load_checkpoint``.
    4. Train from ``config.TRAIN.START_EPOCH`` up to (excluding)
       ``config.TRAIN.EPOCHS``, checkpointing on rank 0, then log the total
       wall-clock time.

    Args:
        config: Configuration object. This function reads ``MODEL.TYPE``,
            ``MODEL.NAME``, ``MODEL.RESUME``, ``LOCAL_RANK``, ``OUTPUT``,
            ``SAVE_FREQ``, ``TRAIN.AUTO_RESUME``, ``TRAIN.START_EPOCH`` and
            ``TRAIN.EPOCHS``, and calls ``defrost()`` / ``freeze()`` on it
            when auto-resume overrides the resume path.

    Returns:
        None.
    """
    train_loader = build_loader(config, simmim=True, is_pretrain=True)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    net = build_model(config, is_pretrain=True)
    net.cuda()
    logger.info(str(net))

    # The optimizer is created from the bare model, before the DDP wrapper
    # replaces the local reference.
    optimizer = build_optimizer(config, net, simmim=True, is_pretrain=True)
    net = torch.nn.parallel.DistributedDataParallel(
        net,
        device_ids=[config.LOCAL_RANK],
        broadcast_buffers=False,
    )
    # Unwrapped module: used for the FLOPs query and for checkpoint load/save.
    bare_net = net.module

    trainable_sizes = [p.numel() for p in net.parameters() if p.requires_grad]
    n_parameters = sum(trainable_sizes)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(bare_net, 'flops'):
        flops = bare_net.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    steps_per_epoch = len(train_loader)
    lr_scheduler = build_scheduler(config, optimizer, steps_per_epoch)
    scaler = amp.GradScaler()

    if config.TRAIN.AUTO_RESUME:
        found_ckpt = auto_resume_helper(config.OUTPUT, logger)
        if not found_ckpt:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
        else:
            if config.MODEL.RESUME:
                # An explicitly configured resume path is overridden below.
                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {found_ckpt}")
            # The config is unfrozen only for the duration of this one write.
            config.defrost()
            config.MODEL.RESUME = found_ckpt
            config.freeze()
            logger.info(f'auto resuming from {found_ckpt}')

    if config.MODEL.RESUME:
        load_checkpoint(config, bare_net, optimizer, lr_scheduler, scaler, logger)

    logger.info("Start training")
    started_at = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        train_loader.sampler.set_epoch(epoch)
        train_one_epoch(config, net, train_loader, optimizer, epoch, lr_scheduler, scaler)

        # Only rank 0 writes checkpoints.
        if dist.get_rank() != 0:
            continue
        # Save every SAVE_FREQ epochs and always after the final epoch.
        if epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1):
            # NOTE(review): the literal 0. is the fourth positional argument of
            # save_checkpoint; its meaning is not visible here -- confirm.
            save_checkpoint(config, epoch, bare_net, 0., optimizer, lr_scheduler, scaler, logger)

    elapsed_seconds = int(time.time() - started_at)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))
| 118 | |
| 119 | |
| 120 | def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, scaler): |
no test coverage detected