()
| 43 | raise Exception("Not support dataser!") |
| 44 | |
def main():
    """Evaluate a searched architecture checkpoint on the validation set.

    All settings come from the module-level ``config``. The function:
      1. seeds numpy / torch / CUDA RNGs;
      2. if ``config.distributed``, pins a GPU and initialises an NCCL process
         group, then sets ``config.world_size`` and ``config.total_batch_size``;
      3. builds the data loaders via ``get_augment_datasets`` (only the
         validation loader is used);
      4. rebuilds ``ModelTest`` from the per-layer genotypes stored as JSON in
         ``config.cell_file``;
      5. loads weights from ``config.resume_path`` (non-strict) and runs one
         ``validate`` pass, printing Prec@1 / Prec@5 when ``local_rank == 0``.
    """
    logger.info("Logger is set - training start")

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    torch.backends.cudnn.deterministic = True
    # NOTE(review): benchmark=True lets cuDNN autotune its algorithm choice per
    # run, which can undermine the reproducibility the seeding above aims for;
    # set it to False if bit-exact reruns are required.
    torch.backends.cudnn.benchmark = True

    if config.distributed:
        config.gpu = config.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(config.gpu)
        # distributed init
        # NOTE(review): `rank` receives the *local* rank, which equals the global
        # rank only on a single node -- confirm multi-node runs are not expected.
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=config.dist_url,
            world_size=config.world_size,
            rank=config.local_rank,
        )
        config.world_size = torch.distributed.get_world_size()
        config.total_batch_size = config.world_size * config.batch_size
    else:
        config.total_batch_size = config.batch_size

    # Only the validation loader is needed for evaluation.
    loaders, _samplers = get_augment_datasets(config)
    _, valid_loader = loaders

    # The cell file is a JSON object mapping layer index -> genotype string.
    # JSON object keys are always strings, hence the int() conversion below.
    # A context manager guarantees the handle is closed even if parsing fails.
    with open(config.cell_file, 'r') as cell_fp:
        r_dict = json.load(cell_fp)
    if config.local_rank == 0:
        logger.info(r_dict)
    genotypes_dict = {
        int(layer_idx): gt.from_str(genotype)
        for layer_idx, genotype in r_dict.items()
    }

    model_main = ModelTest(
        genotypes_dict,
        config.model_type,
        config.res_stem,
        init_channel=config.init_channels,
        stem_multiplier=config.stem_multiplier,
        n_nodes=4,
        num_classes=config.n_classes,
    )
    # torch.load unpickles the file: resume_path must point at a trusted checkpoint.
    resume_state = torch.load(config.resume_path, map_location='cpu')
    # strict=False tolerates key mismatches, but silently evaluating a partially
    # initialised model is easy to miss, so report anything that did not line up.
    # (Older torch versions may return None here, hence the getattr guards.)
    incompatible = model_main.load_state_dict(resume_state, strict=False)
    if config.local_rank == 0:
        missing = getattr(incompatible, 'missing_keys', None)
        unexpected = getattr(incompatible, 'unexpected_keys', None)
        if missing:
            logger.info("Keys missing from checkpoint: {}".format(missing))
        if unexpected:
            logger.info("Unexpected keys in checkpoint: {}".format(unexpected))
    model_main = model_main.cuda()

    if config.distributed:
        # NOTE(review): `delay_allreduce` is an apex DDP option; presumably DDP is
        # apex.parallel.DistributedDataParallel -- confirm against the imports.
        model_main = DDP(model_main, delay_allreduce=True)

    top1, top5 = validate(valid_loader, model_main, 0, 0, writer, logger, config)
    if config.local_rank == 0:
        print("Final best Prec@1 = {:.4%}, Prec@5 = {:.4%}".format(top1, top5))
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
no test coverage detected