github.com/microsoft/Cream

Function train_one_epoch

TinyCLIP/src/training/train.py:84–705
(model, data, epoch, optimizer, scaler, scheduler, scheduler_l0, args, tb_writer=None, start_iter=0, zs=None)

```python
import math
import time
from collections import defaultdict

import torch

# get_autocast, unwrap_model, ClipLoss, ClipSoftLoss and AverageMeter are
# TinyCLIP helpers defined outside this excerpt; NAN_LOSS_CNT is a
# module-level global also declared outside this excerpt.


def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, scheduler_l0, args, tb_writer=None, start_iter=0, zs=None):

    global NAN_LOSS_CNT

    device = torch.device(args.device)
    autocast = get_autocast(args.precision)

    # Separate mixed-precision settings for the image tower, text tower and logits.
    image_autocast = get_autocast(args.image_precision)
    text_autocast = get_autocast(args.text_precision)
    logit_autocast = get_autocast(args.logit_precision)

    model.set_autocast(
        image_autocast=image_autocast,
        text_autocast=text_autocast,
        logit_autocast=logit_autocast)

    teacher_autocast = torch.cuda.amp.autocast

    model_without_ddp = unwrap_model(model)

    # Distillation: the teacher is held on the unwrapped (non-DDP) model.
    distillation = args.distillation
    if distillation:
        teacher_model = model_without_ddp.teacher[0]

    model.train()
    loss_kwargs = dict(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod)

    if start_iter == 0:
        # set epoch in process safe manner via sampler or shared_epoch
        data['train'].set_epoch(epoch)
    dataloader = data['train'].dataloader

    dataloader.device = args.device
    if distillation:
        soft_loss_fn = ClipSoftLoss(**loss_kwargs)  # , ignore_diag=True)
    else:
        soft_loss_fn = None

    hard_loss_fn = ClipLoss(**loss_kwargs)

    dataloader, sampler = data['train'].dataloader, data['train'].sampler
    if args.distributed and sampler is not None and start_iter == 0:
        # [DO NOT REMOVE IT] it will call set_epoch even if sampler is not a DistributedSampler.
        sampler.set_epoch(epoch)

    num_batches_per_epoch = dataloader.num_batches
    # Number of decimal digits in num_samples.
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    loss_m = AverageMeter()
    metrics = defaultdict(AverageMeter)
    end = time.time()
    batch_size = dataloader.batch_size
    # ... (the function continues to line 705; the remainder is not shown in this excerpt)
```
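The listing calls `set_epoch(epoch)` on the sampler before iterating, and its inline comment warns against removing that call. With PyTorch's `torch.utils.data.distributed.DistributedSampler`, the shuffle is seeded from `seed + epoch`, so without the call every epoch replays the same order. The pure-Python sketch below imitates that behavior under stated assumptions: `EpochSeededSampler` is a hypothetical stand-in, not TinyCLIP's or PyTorch's class, and it omits the padding and `drop_last` handling a real distributed sampler performs.

```python
import random
from typing import Iterator, List


class EpochSeededSampler:
    """Hypothetical sampler: shuffles with seed + epoch, then shards by rank."""

    def __init__(self, dataset_size: int, rank: int, world_size: int, seed: int = 0):
        self.dataset_size = dataset_size
        self.rank = rank
        self.world_size = world_size
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Call once per epoch; otherwise the permutation never changes.
        self.epoch = epoch

    def indices(self) -> List[int]:
        order = list(range(self.dataset_size))
        # Every rank builds the same permutation because all share seed + epoch.
        random.Random(self.seed + self.epoch).shuffle(order)
        # Strided split: rank r takes positions r, r + world_size, r + 2 * world_size, ...
        return order[self.rank::self.world_size]

    def __iter__(self) -> Iterator[int]:
        return iter(self.indices())
```

Because each rank derives the same permutation and then takes a disjoint stride of it, the ranks together cover the dataset exactly once per epoch without any communication.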
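`ClipLoss` is constructed from `loss_kwargs` (`local_loss`, `gather_with_grad`, `cache_labels`, `rank`, `world_size`, `use_horovod`). Judging by their names, these options govern how features are gathered across processes and whether the ground-truth labels are cached between steps; the excerpt does not show the loss body. The objective underneath CLIP-style training is a symmetric cross-entropy over the image-text similarity matrix, where the correct match for row `i` is column `i`. The pure-Python sketch below computes that objective for a single process. It assumes the logits are already multiplied by the logit scale and ignores every distributed option, so it illustrates the technique rather than reproducing TinyCLIP's `ClipLoss`; `clip_contrastive_loss` is a hypothetical name.

```python
import math
from typing import List

Matrix = List[List[float]]


def cross_entropy_rows(logits: Matrix) -> float:
    """Mean cross-entropy where the correct class for row i is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def transpose(matrix: Matrix) -> Matrix:
    return [list(col) for col in zip(*matrix)]


def clip_contrastive_loss(logits_per_image: Matrix) -> float:
    """Symmetric CLIP loss: mean of image->text and text->image cross-entropy."""
    logits_per_text = transpose(logits_per_image)
    return 0.5 * (cross_entropy_rows(logits_per_image) + cross_entropy_rows(logits_per_text))
```

With all-zero logits every pairing is equally likely, so the loss equals `log(N)` for a batch of `N`; a matrix with a large diagonal drives it toward zero.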
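When `args.distillation` is set, the listing takes `teacher_model` from `model_without_ddp.teacher[0]` and builds a `ClipSoftLoss` next to the hard `ClipLoss`, but the excerpt does not show how the soft loss is defined. One common form of logit distillation for CLIP replaces the one-hot targets with the teacher's softmax distribution over the same similarity matrix. The sketch below assumes that form, in pure Python and for a single process. `soft_clip_loss` is a hypothetical name, and TinyCLIP's actual `ClipSoftLoss` may differ, for example in temperature handling or in the commented-out `ignore_diag` option.

```python
import math
from typing import List

Matrix = List[List[float]]


def log_softmax(row: List[float]) -> List[float]:
    m = max(row)  # stabilize before exponentiating
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def soft_cross_entropy(student: Matrix, teacher: Matrix) -> float:
    """Mean over rows of -sum_j p_teacher[j] * log p_student[j]."""
    total = 0.0
    for s_row, t_row in zip(student, teacher):
        p_teacher = [math.exp(v) for v in log_softmax(t_row)]
        log_p_student = log_softmax(s_row)
        total -= sum(p * lp for p, lp in zip(p_teacher, log_p_student))
    return total / len(student)


def transpose(matrix: Matrix) -> Matrix:
    return [list(col) for col in zip(*matrix)]


def soft_clip_loss(student_logits: Matrix, teacher_logits: Matrix) -> float:
    """Assumed distillation loss: soft cross-entropy in both directions."""
    image_to_text = soft_cross_entropy(student_logits, teacher_logits)
    text_to_image = soft_cross_entropy(transpose(student_logits), transpose(teacher_logits))
    return 0.5 * (image_to_text + text_to_image)
```

For a fixed teacher, this loss is smallest when the student's row distributions match the teacher's, which is what makes it usable as a distillation target.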
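`sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))` is the number of decimal digits in `num_samples`. The `+ 1` handles exact powers of ten: 1000 needs 4 digits, and log10(1001) is just above 3. The value is presumably used to pad sample counters in progress logs to a fixed width, though that logging code lies outside this excerpt. The sketch below checks the formula against `len(str(n))`, an integer-only alternative that avoids floating-point rounding in `math.log`. The `progress` helper and its format string are illustrative assumptions.

```python
import math


def digits_via_log(n: int) -> int:
    # Formula from the listing: digits needed to print n in base 10 (for n >= 1).
    return math.ceil(math.log(n + 1, 10))


def digits_exact(n: int) -> int:
    # Integer-only alternative with no floating-point rounding.
    return len(str(n))


def progress(done: int, total: int) -> str:
    # Right-align the counter so successive log lines keep the same width.
    width = digits_exact(total)
    return f"[{done:>{width}}/{total}]"
```

For example, `progress(7, 1000)` yields `"[   7/1000]"`, so the counter column stays aligned as it grows to four digits.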

Callers (1)

main (function) · 0.90

Calls (15)

update (method) · 0.95
ClipSoftLoss (class) · 0.90
ClipLoss (class) · 0.90
AverageMeter (class) · 0.90
build_optimizer (function) · 0.90
cosine_lr_start_nowarmup (function) · 0.90
plot (function) · 0.90
get_autocast (function) · 0.85
check_last_batch (function) · 0.85
infer_chunks (function) · 0.85
is_master (function) · 0.85
get_state_dict (function) · 0.85
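The listing creates `loss_m = AverageMeter()` and `metrics = defaultdict(AverageMeter)`, and both `AverageMeter` and an `update` method appear among the calls. The class body is not on this page, so the sketch below assumes the running-average meter common in PyTorch training scripts (fields `val`, `sum`, `count`, `avg` and an `update(val, n=1)` method); TinyCLIP's definition may differ. Wrapping it in `defaultdict` creates a fresh meter the first time a metric name is used.

```python
from collections import defaultdict


class AverageMeter:
    """Running average (assumed interface): update(val, n) adds n samples of value val."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted sum of all values
        self.count = 0   # total weight, e.g. number of samples
        self.avg = 0.0   # sum / count

    def update(self, val: float, n: int = 1) -> None:
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


metrics = defaultdict(AverageMeter)
metrics["loss"].update(2.0, n=4)   # batch of 4 with mean loss 2.0
metrics["loss"].update(1.0, n=12)  # batch of 12 with mean loss 1.0
# metrics["loss"].avg -> 1.25, a per-sample average rather than a per-batch one
```

Weighting each update by `n` means a small batch moves the average less than a large one.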

Tested by

no test coverage detected