(config)
| 88 | |
| 89 | |
| 90 | def main(config): |
| 91 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) |
| 92 | |
| 93 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") |
| 94 | model = build_model(config) |
| 95 | logger.info(str(model)) |
| 96 | |
| 97 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| 98 | logger.info(f"number of params: {n_parameters}") |
| 99 | if hasattr(model, 'flops'): |
| 100 | flops = model.flops() |
| 101 | logger.info(f"number of GFLOPs: {flops / 1e9}") |
| 102 | |
| 103 | model.cuda() |
| 104 | model_without_ddp = model |
| 105 | |
| 106 | optimizer = build_optimizer(config, model) |
| 107 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) |
| 108 | loss_scaler = NativeScalerWithGradNormCount() |
| 109 | |
| 110 | if config.TRAIN.ACCUMULATION_STEPS > 1: |
| 111 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS) |
| 112 | else: |
| 113 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) |
| 114 | |
| 115 | if config.AUG.MIXUP > 0.: |
| 116 | # smoothing is handled with mixup label transform |
| 117 | criterion = SoftTargetCrossEntropy() |
| 118 | elif config.MODEL.LABEL_SMOOTHING > 0.: |
| 119 | criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) |
| 120 | else: |
| 121 | criterion = torch.nn.CrossEntropyLoss() |
| 122 | |
| 123 | max_accuracy = 0.0 |
| 124 | |
| 125 | if config.TRAIN.AUTO_RESUME: |
| 126 | resume_file = auto_resume_helper(config.OUTPUT) |
| 127 | if resume_file: |
| 128 | if config.MODEL.RESUME: |
| 129 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") |
| 130 | config.defrost() |
| 131 | config.MODEL.RESUME = resume_file |
| 132 | config.freeze() |
| 133 | logger.info(f'auto resuming from {resume_file}') |
| 134 | else: |
| 135 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') |
| 136 | |
| 137 | if config.MODEL.RESUME: |
| 138 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger) |
| 139 | acc1, acc5, loss = validate(config, data_loader_val, model) |
| 140 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") |
| 141 | if config.EVAL_MODE: |
| 142 | return |
| 143 | |
| 144 | if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): |
| 145 | load_pretrained(config, model_without_ddp, logger) |
| 146 | acc1, acc5, loss = validate(config, data_loader_val, model) |
| 147 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") |
no test coverage detected