(config)
| 74 | |
| 75 | |
| 76 | def main(config): |
| 77 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True, |
| 78 | is_pretrain=False) |
| 79 | |
| 80 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") |
| 81 | model = build_model(config, is_pretrain=False) |
| 82 | model.cuda() |
| 83 | logger.info(str(model)) |
| 84 | |
| 85 | optimizer = build_optimizer(config, model, simmim=True, is_pretrain=False) |
| 86 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) |
| 87 | model_without_ddp = model.module |
| 88 | |
| 89 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| 90 | logger.info(f"number of params: {n_parameters}") |
| 91 | if hasattr(model_without_ddp, 'flops'): |
| 92 | flops = model_without_ddp.flops() |
| 93 | logger.info(f"number of GFLOPs: {flops / 1e9}") |
| 94 | |
| 95 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) |
| 96 | scaler = amp.GradScaler() |
| 97 | |
| 98 | if config.AUG.MIXUP > 0.: |
| 99 | # smoothing is handled with mixup label transform |
| 100 | criterion = SoftTargetCrossEntropy() |
| 101 | elif config.MODEL.LABEL_SMOOTHING > 0.: |
| 102 | criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) |
| 103 | else: |
| 104 | criterion = torch.nn.CrossEntropyLoss() |
| 105 | |
| 106 | max_accuracy = 0.0 |
| 107 | |
| 108 | if config.TRAIN.AUTO_RESUME: |
| 109 | resume_file = auto_resume_helper(config.OUTPUT, logger) |
| 110 | if resume_file: |
| 111 | if config.MODEL.RESUME: |
| 112 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") |
| 113 | config.defrost() |
| 114 | config.MODEL.RESUME = resume_file |
| 115 | config.freeze() |
| 116 | logger.info(f'auto resuming from {resume_file}') |
| 117 | else: |
| 118 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') |
| 119 | |
| 120 | if config.MODEL.RESUME: |
| 121 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger) |
| 122 | acc1, acc5, loss = validate(config, data_loader_val, model) |
| 123 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") |
| 124 | if config.EVAL_MODE: |
| 125 | return |
| 126 | |
| 127 | if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): |
| 128 | load_pretrained(config, model_without_ddp, logger) |
| 129 | acc1, acc5, loss = validate(config, data_loader_val, model) |
| 130 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") |
| 131 | |
| 132 | if config.THROUGHPUT_MODE: |
| 133 | throughput(data_loader_val, model, logger) |
no test coverage detected