(args, config)
| 56 | |
| 57 | |
| 58 | def main(args, config): |
| 59 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader( |
| 60 | config) |
| 61 | |
| 62 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") |
| 63 | model = build_model(config) |
| 64 | if not args.only_cpu: |
| 65 | model.cuda() |
| 66 | |
| 67 | if args.use_sync_bn: |
| 68 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) |
| 69 | |
| 70 | logger.info(str(model)) |
| 71 | |
| 72 | optimizer = build_optimizer(config, model) |
| 73 | |
| 74 | if not args.only_cpu: |
| 75 | model = torch.nn.parallel.DistributedDataParallel( |
| 76 | model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) |
| 77 | model_without_ddp = model.module |
| 78 | else: |
| 79 | model_without_ddp = model |
| 80 | |
| 81 | loss_scaler = NativeScalerWithGradNormCount(grad_scaler_enabled=config.AMP_ENABLE) |
| 82 | |
| 83 | n_parameters = sum(p.numel() |
| 84 | for p in model.parameters() if p.requires_grad) |
| 85 | logger.info(f"number of params: {n_parameters}") |
| 86 | if hasattr(model_without_ddp, 'flops'): |
| 87 | flops = model_without_ddp.flops() |
| 88 | logger.info(f"number of GFLOPs: {flops / 1e9}") |
| 89 | |
| 90 | lr_scheduler = build_scheduler(config, optimizer, len( |
| 91 | data_loader_train) // config.TRAIN.ACCUMULATION_STEPS) |
| 92 | |
| 93 | if config.DISTILL.ENABLED: |
| 94 | # MIXUP and CUTMIX are disabled when knowledge distillation is enabled |
| 95 | assert len( |
| 96 | config.DISTILL.TEACHER_LOGITS_PATH) > 0, "Please fill in DISTILL.TEACHER_LOGITS_PATH" |
| 97 | criterion = SoftTargetCrossEntropy() |
| 98 | else: |
| 99 | if config.AUG.MIXUP > 0.: |
| 100 | # smoothing is handled with mixup label transform |
| 101 | criterion = SoftTargetCrossEntropy() |
| 102 | elif config.MODEL.LABEL_SMOOTHING > 0.: |
| 103 | criterion = LabelSmoothingCrossEntropy( |
| 104 | smoothing=config.MODEL.LABEL_SMOOTHING) |
| 105 | else: |
| 106 | criterion = torch.nn.CrossEntropyLoss() |
| 107 | |
| 108 | max_accuracy = 0.0 |
| 109 | |
| 110 | if config.TRAIN.AUTO_RESUME: |
| 111 | resume_file = auto_resume_helper(config.OUTPUT) |
| 112 | if resume_file: |
| 113 | if config.MODEL.RESUME: |
| 114 | logger.warning( |
| 115 | f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") |
no test coverage detected