MCPcopy
hub / github.com/microsoft/Swin-Transformer / main

Function main

main_moe.py:86–181  ·  view source on GitHub ↗
(config)

Source from the content-addressed store, hash-verified

84
85
86def main(config):
87 dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
88
89 logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
90 model = build_model(config)
91 logger.info(str(model))
92
93 # For Tutel MoE
94 for name, param in model.named_parameters():
95 if param.requires_grad == True and hasattr(param, 'skip_allreduce') and param.skip_allreduce is True:
96 model.add_param_to_skip_allreduce(name)
97 param.register_hook(partial(hook_scale_grad, dist.get_world_size()))
98 logger.info(f"[rank{dist.get_rank()}] [{name}] skip all_reduce and div {dist.get_world_size()} for grad")
99
100 n_parameters_single = sum(p.numel() * model.sharded_count if hasattr(p, 'skip_allreduce')
101 else p.numel() for p in model.parameters() if p.requires_grad)
102 logger.info(f"number of params single: {n_parameters_single}")
103 n_parameters_whole = sum(p.numel() * model.sharded_count * model.global_experts if hasattr(p, 'skip_allreduce')
104 else p.numel() for p in model.parameters() if p.requires_grad)
105 logger.info(f"number of params whole: {n_parameters_whole}")
106 if hasattr(model, 'flops'):
107 flops = model.flops()
108 logger.info(f"number of GFLOPs: {flops / 1e9}")
109
110 model.cuda(config.LOCAL_RANK)
111 model_without_ddp = model
112
113 optimizer = build_optimizer(config, model)
114 model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
115 loss_scaler = NativeScalerWithGradNormCount()
116
117 if config.TRAIN.ACCUMULATION_STEPS > 1:
118 lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)
119 else:
120 lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
121
122 if config.AUG.MIXUP > 0.:
123 # smoothing is handled with mixup label transform
124 criterion = SoftTargetCrossEntropy()
125 elif config.MODEL.LABEL_SMOOTHING > 0.:
126 criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
127 else:
128 criterion = torch.nn.CrossEntropyLoss()
129
130 max_accuracy = 0.0
131
132 if config.TRAIN.AUTO_RESUME:
133 resume_file = auto_resume_helper(config.OUTPUT, config.TRAIN.MOE.SAVE_MASTER)
134 if resume_file:
135 if config.MODEL.RESUME:
136 logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
137 config.defrost()
138 config.MODEL.RESUME = resume_file
139 config.freeze()
140 logger.info(f'auto resuming from {resume_file}')
141 else:
142 logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
143

Callers 1

main_moe.py · File · 0.70

Calls 15

build_loader · Function · 0.90
build_model · Function · 0.90
build_optimizer · Function · 0.90
build_scheduler · Function · 0.90
auto_resume_helper · Function · 0.90
load_checkpoint · Function · 0.90
load_pretrained · Function · 0.90
save_checkpoint · Function · 0.90
set_epoch · Method · 0.80
validate · Function · 0.70

Tested by

no test coverage detected