MCPcopy
hub / github.com/microsoft/Cream / main

Function main

TinyViT/main.py:58–179  ·  view source on GitHub ↗
(args, config)

Source from the content-addressed store, hash-verified

56
57
58def main(args, config):
59 dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
60 config)
61
62 logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
63 model = build_model(config)
64 if not args.only_cpu:
65 model.cuda()
66
67 if args.use_sync_bn:
68 model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
69
70 logger.info(str(model))
71
72 optimizer = build_optimizer(config, model)
73
74 if not args.only_cpu:
75 model = torch.nn.parallel.DistributedDataParallel(
76 model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
77 model_without_ddp = model.module
78 else:
79 model_without_ddp = model
80
81 loss_scaler = NativeScalerWithGradNormCount(grad_scaler_enabled=config.AMP_ENABLE)
82
83 n_parameters = sum(p.numel()
84 for p in model.parameters() if p.requires_grad)
85 logger.info(f"number of params: {n_parameters}")
86 if hasattr(model_without_ddp, 'flops'):
87 flops = model_without_ddp.flops()
88 logger.info(f"number of GFLOPs: {flops / 1e9}")
89
90 lr_scheduler = build_scheduler(config, optimizer, len(
91 data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)
92
93 if config.DISTILL.ENABLED:
94 # we disable MIXUP and CUTMIX when knowledge distillation
95 assert len(
96 config.DISTILL.TEACHER_LOGITS_PATH) > 0, "Please fill in DISTILL.TEACHER_LOGITS_PATH"
97 criterion = SoftTargetCrossEntropy()
98 else:
99 if config.AUG.MIXUP > 0.:
100 # smoothing is handled with mixup label transform
101 criterion = SoftTargetCrossEntropy()
102 elif config.MODEL.LABEL_SMOOTHING > 0.:
103 criterion = LabelSmoothingCrossEntropy(
104 smoothing=config.MODEL.LABEL_SMOOTHING)
105 else:
106 criterion = torch.nn.CrossEntropyLoss()
107
108 max_accuracy = 0.0
109
110 if config.TRAIN.AUTO_RESUME:
111 resume_file = auto_resume_helper(config.OUTPUT)
112 if resume_file:
113 if config.MODEL.RESUME:
114 logger.warning(
115 f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")

Callers 1

main.py · File · 0.70

Calls 15

build_loader · Function · 0.90
build_model · Function · 0.90
build_optimizer · Function · 0.90
build_scheduler · Function · 0.90
auto_resume_helper · Function · 0.90
load_checkpoint · Function · 0.90
load_pretrained · Function · 0.90
save_checkpoint · Function · 0.90
is_main_process · Function · 0.90
format · Method · 0.80

Tested by

no test coverage detected