Function main

Cream/tools/train.py:41–235 (github.com/microsoft/Cream)

def main():
    args, cfg = parse_config_args('super net training')

    # resolve logging
    output_dir = os.path.join(cfg.SAVE_PATH,
                              "{}-{}".format(datetime.date.today().strftime('%m%d'),
                                             cfg.MODEL))

    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, "train.log"))
    else:
        logger = None

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info(
            'Training on Process %d with %d GPUs.',
            args.local_rank, cfg.NUM_GPU)

    # fix random seeds
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # generate supernet
    model, sta_num, resolution = gen_supernet(
        flops_minimum=cfg.SUPERNET.FLOPS_MINIMUM,
        flops_maximum=cfg.SUPERNET.FLOPS_MAXIMUM,
        num_classes=cfg.DATASET.NUM_CLASSES,
        drop_rate=cfg.NET.DROPOUT_RATE,
        global_pool=cfg.NET.GP,
        resunit=cfg.SUPERNET.RESUNIT,
        dil_conv=cfg.SUPERNET.DIL_CONV,
        slice=cfg.SUPERNET.SLICE,
        verbose=cfg.VERBOSE,
        logger=logger)

    # initialize the meta-matching network
    MetaMN = MetaMatchingNetwork(cfg)

    # number of choice blocks in the supernet
    choice_num = len(model.blocks[1][0])
    if args.local_rank == 0:
        logger.info('Supernet created, param count: %d',
                    sum(m.numel() for m in model.parameters()))
        logger.info('resolution: %d', resolution)
        logger.info('choice number: %d', choice_num)

    # initialize prioritized board
    prioritized_board = PrioritizedBoard(cfg, CHOICE_NUM=choice_num, sta_num=sta_num)

    # initialize FLOPs look-up table
    model_est = FlopsEst(model)

    # ... (rest of main, train.py lines 99–235, not shown here)

Callers (1)

train.py (File, 0.70)

Calls (15)

parse_config_args (Function, 0.90)
get_logger (Function, 0.90)
gen_supernet (Function, 0.90)
MetaMatchingNetwork (Class, 0.90)
PrioritizedBoard (Class, 0.90)
FlopsEst (Class, 0.90)
create_loader (Function, 0.90)
train_epoch (Function, 0.90)
validate (Function, 0.90)
format (Method, 0.80)

Tested by

No test coverage detected.