MCPcopy
hub / github.com/microsoft/Cream / main

Function main

CDARTS/CDARTS/test.py:45–93  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

43 raise Exception("Not support dataser!")
44
def main():
    """Evaluate a fixed CDARTS architecture once on the validation set.

    All settings come from the module-level ``config``. The function:

    1. seeds numpy/torch RNGs;
    2. optionally initialises NCCL distributed mode and derives
       ``config.total_batch_size``;
    3. builds the loaders with ``get_augment_datasets`` (only the validation
       loader is used);
    4. rebuilds one genotype per layer from the JSON file ``config.cell_file``;
    5. constructs ``ModelTest`` and loads weights from ``config.resume_path``;
    6. runs ``validate`` once.

    Rank 0 then prints the final top-1 / top-5 precision.
    """
    logger.info("Logger is set - training start")

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    torch.backends.cudnn.deterministic = True
    # cudnn.benchmark=True lets cuDNN time and pick algorithms at runtime,
    # which undermines the reproducibility requested by deterministic=True.
    torch.backends.cudnn.benchmark = False

    if config.distributed:
        config.gpu = config.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(config.gpu)
        # distributed init
        # NOTE(review): `rank` receives the *local* rank, which equals the
        # global rank only on a single node — confirm before multi-node runs.
        torch.distributed.init_process_group(
            backend='nccl', init_method=config.dist_url,
            world_size=config.world_size, rank=config.local_rank)

        config.world_size = torch.distributed.get_world_size()

        config.total_batch_size = config.world_size * config.batch_size
    else:
        config.total_batch_size = config.batch_size

    # Evaluation only needs the validation loader; train loader and samplers
    # are not used here.
    loaders, _ = get_augment_datasets(config)
    _, valid_loader = loaders

    # The context manager closes the file even if the JSON is malformed.
    with open(config.cell_file, 'r') as cell_file:
        r_dict = json.load(cell_file)
    if config.local_rank == 0:
        logger.info(r_dict)
    # JSON object keys are always strings; convert them back to int layer indices.
    genotypes_dict = {int(layer_idx): gt.from_str(genotype)
                      for layer_idx, genotype in r_dict.items()}

    model_main = ModelTest(genotypes_dict, config.model_type, config.res_stem,
                           init_channel=config.init_channels,
                           stem_multiplier=config.stem_multiplier,
                           n_nodes=4, num_classes=config.n_classes)
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    resume_state = torch.load(config.resume_path, map_location='cpu')
    # strict=False tolerates key mismatches, but it would also silently load
    # nothing if, e.g., every checkpoint key carried a "module." prefix from a
    # DDP-wrapped model, so report whatever was skipped.
    load_result = model_main.load_state_dict(resume_state, strict=False)
    if load_result is not None and config.local_rank == 0:
        missing = getattr(load_result, 'missing_keys', [])
        unexpected = getattr(load_result, 'unexpected_keys', [])
        if missing or unexpected:
            logger.info("Checkpoint key mismatch: missing=%s, unexpected=%s",
                        missing, unexpected)
    model_main = model_main.cuda()

    if config.distributed:
        model_main = DDP(model_main, delay_allreduce=True)

    top1, top5 = validate(valid_loader, model_main, 0, 0, writer, logger, config)
    if config.local_rank == 0:
        # NOTE(review): "{:.4%}" multiplies by 100, so this assumes validate
        # returns fractions in [0, 1] rather than percentages — confirm.
        print("Final best Prec@1 = {:.4%}, Prec@5 = {:.4%}".format(top1, top5))
94
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Callers 1

test.py · File · 0.70

Calls 7

get_augment_datasets · Function · 0.90
ModelTest · Class · 0.90
validate · Function · 0.90
format · Method · 0.80
print · Function · 0.50
read · Method · 0.45
load_state_dict · Method · 0.45

Tested by

no test coverage detected