()
| 39 | |
| 40 | |
| 41 | def main(): |
| 42 | args, cfg = parse_config_args('super net training') |
| 43 | |
| 44 | # resolve logging |
| 45 | output_dir = os.path.join(cfg.SAVE_PATH, |
| 46 | "{}-{}".format(datetime.date.today().strftime('%m%d'), |
| 47 | cfg.MODEL)) |
| 48 | |
| 49 | if args.local_rank == 0: |
| 50 | logger = get_logger(os.path.join(output_dir, "train.log")) |
| 51 | else: |
| 52 | logger = None |
| 53 | |
| 54 | # initialize distributed parameters |
| 55 | torch.cuda.set_device(args.local_rank) |
| 56 | torch.distributed.init_process_group(backend='nccl', init_method='env://') |
| 57 | if args.local_rank == 0: |
| 58 | logger.info( |
| 59 | 'Training on Process %d with %d GPUs.', |
| 60 | args.local_rank, cfg.NUM_GPU) |
| 61 | |
| 62 | # fix random seeds |
| 63 | torch.manual_seed(cfg.SEED) |
| 64 | torch.cuda.manual_seed_all(cfg.SEED) |
| 65 | np.random.seed(cfg.SEED) |
| 66 | torch.backends.cudnn.deterministic = True |
| 67 | torch.backends.cudnn.benchmark = False |
| 68 | |
| 69 | # generate supernet |
| 70 | model, sta_num, resolution = gen_supernet( |
| 71 | flops_minimum=cfg.SUPERNET.FLOPS_MINIMUM, |
| 72 | flops_maximum=cfg.SUPERNET.FLOPS_MAXIMUM, |
| 73 | num_classes=cfg.DATASET.NUM_CLASSES, |
| 74 | drop_rate=cfg.NET.DROPOUT_RATE, |
| 75 | global_pool=cfg.NET.GP, |
| 76 | resunit=cfg.SUPERNET.RESUNIT, |
| 77 | dil_conv=cfg.SUPERNET.DIL_CONV, |
| 78 | slice=cfg.SUPERNET.SLICE, |
| 79 | verbose=cfg.VERBOSE, |
| 80 | logger=logger) |
| 81 | |
| 82 | # initialize meta matching networks |
| 83 | MetaMN = MetaMatchingNetwork(cfg) |
| 84 | |
| 85 | # number of choice blocks in supernet |
| 86 | choice_num = len(model.blocks[1][0]) |
| 87 | if args.local_rank == 0: |
| 88 | logger.info('Supernet created, param count: %d', ( |
| 89 | sum([m.numel() for m in model.parameters()]))) |
| 90 | logger.info('resolution: %d', (resolution)) |
| 91 | logger.info('choice number: %d', (choice_num)) |
| 92 | |
| 93 | #initialize prioritized board |
| 94 | prioritized_board = PrioritizedBoard(cfg, CHOICE_NUM=choice_num, sta_num=sta_num) |
| 95 | |
| 96 | # initialize flops look-up table |
| 97 | model_est = FlopsEst(model) |
| 98 |
no test coverage detected