github.com/microsoft/Cream / main

Function main

TinyViT/save_logits.py:50–99
(config)


def main(config):
    # Build datasets, data loaders, and the optional mixup function.
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()

    logger.info(str(model))

    # Wrap the model for distributed execution; keep a handle to the bare module
    # so the checkpoint can be loaded into it directly.
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    model_without_ddp = model.module

    # Count trainable parameters only.
    n_parameters = sum(p.numel()
                       for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")

    # No training happens in this script, so there is no optimizer or scheduler.
    optimizer = None
    lr_scheduler = None

    # A checkpoint to resume from is required.
    assert config.MODEL.RESUME
    loss_scaler = NativeScalerWithGradNormCount()
    load_checkpoint(config, model_without_ddp, optimizer,
                    lr_scheduler, loss_scaler, logger)
    # `args` is not a parameter of main; it must be defined at module level.
    if not args.skip_eval and not args.check_saved_logits:
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: top-1 acc: {acc1:.1f}%, top-5 acc: {acc5:.1f}%")

    if args.check_saved_logits:
        logger.info("Start checking logits")
    else:
        logger.info("Start saving logits")

    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        # Tell both the dataset and the sampler which epoch is being processed.
        dataset_train.set_epoch(epoch)
        data_loader_train.sampler.set_epoch(epoch)

        if args.check_saved_logits:
            check_logits_one_epoch(
                config, model, data_loader_train, epoch, mixup_fn=mixup_fn)
        else:
            save_logits_one_epoch(
                config, model, data_loader_train, epoch, mixup_fn=mixup_fn)

    # Log the elapsed wall-clock time as H:MM:SS (same message in both modes).
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Saving logits time {}'.format(total_time_str))
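
The epoch loop at the end of `main` runs exactly one per-epoch routine, chosen by a flag, and then logs the elapsed time. The following is a minimal, standard-library-only sketch of that control flow. The names `run_epochs`, `format_elapsed`, `save_fn`, and `check_fn` are hypothetical stand-ins and are not part of the repository.

```python
import datetime
import time


def format_elapsed(seconds):
    """Format a duration as H:MM:SS, truncating fractional seconds."""
    return str(datetime.timedelta(seconds=int(seconds)))


def run_epochs(start_epoch, num_epochs, check_saved_logits, save_fn, check_fn):
    """Call check_fn or save_fn once per epoch and return the elapsed time.

    Hypothetical helper: save_fn and check_fn stand in for the per-epoch
    save and check routines and take only the epoch index here.
    """
    start_time = time.time()
    for epoch in range(start_epoch, num_epochs):
        # Exactly one of the two routines runs, depending on the flag.
        if check_saved_logits:
            check_fn(epoch)
        else:
            save_fn(epoch)
    return format_elapsed(time.time() - start_time)


saved, checked = [], []
run_epochs(2, 5, False, saved.append, checked.append)
print(saved)                    # [2, 3, 4]
print(format_elapsed(3725.9))   # 1:02:05
```

Starting the range at `start_epoch` rather than 0 means a run that resumes partway skips the epochs before that index.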

Callers 1

save_logits.py · File · 0.70

Calls 9

build_loader · Function · 0.90
build_model · Function · 0.90
load_checkpoint · Function · 0.90
check_logits_one_epoch · Function · 0.85
save_logits_one_epoch · Function · 0.85
format · Method · 0.80
validate · Function · 0.70
set_epoch · Method · 0.45
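
The `set_epoch` calls are listed with low confidence because the listing does not show which classes they resolve to. A common reason for such a method is to make the shuffle order a function of the epoch index, so that a later pass over the same epoch sees the same order. PyTorch's `DistributedSampler.set_epoch` is documented to work this way. The sketch below illustrates the idea with a hypothetical `EpochSeededSampler`; it is an assumption about the pattern, not the repository's implementation.

```python
import random


class EpochSeededSampler:
    """Hypothetical sampler whose order depends only on (seed, epoch)."""

    def __init__(self, num_samples, seed=0):
        self.num_samples = num_samples
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # Record the epoch; the next iteration derives its RNG seed from it.
        self.epoch = epoch

    def __iter__(self):
        # A fresh RNG per iteration makes the order reproducible.
        rng = random.Random(self.seed + self.epoch)
        indices = list(range(self.num_samples))
        rng.shuffle(indices)
        return iter(indices)


sampler = EpochSeededSampler(8)
sampler.set_epoch(3)
first_pass = list(sampler)
second_pass = list(sampler)
print(first_pass == second_pass)  # True
```

Under this assumption, a checking pass that sets the same epoch index would iterate the data in the same order as the saving pass did.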

Tested by

no test coverage detected