(config)
| 48 | |
| 49 | |
def main(config):
    """Restore a checkpointed model and save (or verify) its logits on the training set.

    Builds the data loaders and model from ``config``, wraps the model in
    DistributedDataParallel on device ``config.LOCAL_RANK``, and loads weights
    via ``load_checkpoint``. It then optionally reports validation accuracy.
    Finally, for every epoch in ``range(config.TRAIN.START_EPOCH,
    config.TRAIN.EPOCHS)`` it calls either ``check_logits_one_epoch`` or
    ``save_logits_one_epoch`` on the training loader.

    NOTE(review): behavior also depends on a module-level ``args`` namespace
    (``args.skip_eval``, ``args.check_saved_logits``) defined outside this
    function.

    Args:
        config: project config node. Attributes read here: MODEL.TYPE,
            MODEL.NAME, MODEL.RESUME, LOCAL_RANK, TRAIN.START_EPOCH,
            TRAIN.EPOCHS.

    Raises:
        ValueError: if ``config.MODEL.RESUME`` is falsy. This entry point only
            runs from a resumed checkpoint.
    """
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()

    logger.info(str(model))

    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    # Keep a handle to the unwrapped module so the checkpoint is loaded into it.
    model_without_ddp = model.module

    n_parameters = sum(p.numel()
                       for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let an empty checkpoint path reach
    # load_checkpoint.
    if not config.MODEL.RESUME:
        raise ValueError(
            "config.MODEL.RESUME must be set to a checkpoint path to save/check logits")

    loss_scaler = NativeScalerWithGradNormCount()
    # Optimizer and LR scheduler are passed as None. Presumably only model
    # weights (and loss-scaler state) are restored; verify that
    # load_checkpoint tolerates None for both.
    load_checkpoint(config, model_without_ddp, None,
                    None, loss_scaler, logger)

    if not args.skip_eval and not args.check_saved_logits:
        acc1, acc5, _ = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: top-1 acc: {acc1:.1f}%, top-5 acc: {acc5:.1f}%")

    if args.check_saved_logits:
        logger.info("Start checking logits")
    else:
        logger.info("Start saving logits")

    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        # Both the dataset and the sampler are told the epoch. The sampler is
        # presumably a DistributedSampler, which reshuffles per epoch.
        dataset_train.set_epoch(epoch)
        data_loader_train.sampler.set_epoch(epoch)

        if args.check_saved_logits:
            check_logits_one_epoch(
                config, model, data_loader_train, epoch, mixup_fn=mixup_fn)
        else:
            save_logits_one_epoch(
                config, model, data_loader_train, epoch, mixup_fn=mixup_fn)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Saving logits time {}'.format(total_time_str))
| 100 | |
| 101 | |
| 102 | @torch.no_grad() |
no test coverage detected