MCPcopy
hub / github.com/microsoft/Cream / main

Function main

Cream/tools/test.py:36–154  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

34
35
36def main():
37 args, cfg = parse_config_args('child net testing')
38
39 # resolve logging
40 output_dir = os.path.join(cfg.SAVE_PATH,
41 "{}-{}".format(datetime.date.today().strftime('%m%d'),
42 cfg.MODEL))
43
44 if args.local_rank == 0:
45 logger = get_logger(os.path.join(output_dir, 'test.log'))
46 writer = SummaryWriter(os.path.join(output_dir, 'runs'))
47 else:
48 writer, logger = None, None
49
50 # retrain model selection
51 if cfg.NET.SELECTION == 481:
52 arch_list = [
53 [0], [
54 3, 4, 3, 1], [
55 3, 2, 3, 0], [
56 3, 3, 3, 1, 1], [
57 3, 3, 3, 3], [
58 3, 3, 3, 3], [0]]
59 cfg.DATASET.IMAGE_SIZE = 224
60 elif cfg.NET.SELECTION == 43:
61 arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
62 cfg.DATASET.IMAGE_SIZE = 96
63 elif cfg.NET.SELECTION == 14:
64 arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
65 cfg.DATASET.IMAGE_SIZE = 64
66 elif cfg.NET.SELECTION == 114:
67 arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
68 cfg.DATASET.IMAGE_SIZE = 160
69 elif cfg.NET.SELECTION == 287:
70 arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
71 cfg.DATASET.IMAGE_SIZE = 224
72 elif cfg.NET.SELECTION == 604:
73 arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
74 [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
75 cfg.DATASET.IMAGE_SIZE = 224
76 else:
77 raise ValueError("Model Test Selection is not Supported!")
78
79 # define childnet architecture from arch_list
80 stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
81 choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25',
82 'ir_r1_k5_s2_e4_c40_se0.25',
83 'ir_r1_k3_s2_e6_c80_se0.25',
84 'ir_r1_k3_s1_e6_c96_se0.25',
85 'ir_r1_k5_s2_e6_c192_se0.25']
86 arch_def = [[stem[0]]] + [[choice_block_pool[idx]
87 for repeat_times in range(len(arch_list[idx + 1]))]
88 for idx in range(len(choice_block_pool))] + [[stem[1]]]
89
90 # generate childnet
91 model = gen_childnet(
92 arch_list,
93 arch_def,

Callers 1

test.pyFile · 0.70

Calls 8

parse_config_argsFunction · 0.90
get_loggerFunction · 0.90
gen_childnetFunction · 0.90
get_model_flops_paramsFunction · 0.90
ModelEmaClass · 0.90
create_loaderFunction · 0.90
validateFunction · 0.90
formatMethod · 0.80

Tested by

no test coverage detected