MCPcopy
hub / github.com/microsoft/Cream / retrain_warmup

Function retrain_warmup

CDARTS/benchmark201/core/search_function.py:133–196  ·  view source on GitHub ↗
(valid_loader, model, optimizer, epoch, writer, logger, super_flag, retrain_epochs, config)

Source from the content-addressed store, hash-verified

131
132
133 def retrain_warmup(valid_loader, model, optimizer, epoch, writer, logger, super_flag, retrain_epochs, config):
134
135 device = torch.device("cuda")
136 criterion = nn.CrossEntropyLoss().to(device)
137 top1 = utils.AverageMeter()
138 top5 = utils.AverageMeter()
139 losses = utils.AverageMeter()
140
141 step_num = len(valid_loader)
142 step_num = int(step_num * config.sample_ratio)
143
144 cur_step = epoch*step_num
145 cur_lr = optimizer.param_groups[0]['lr']
146 if config.local_rank == 0:
147 logger.info("Warmup Epoch {} LR {:.3f}".format(epoch+1, cur_lr))
148 writer.add_scalar('warmup/lr', cur_lr, cur_step)
149
150 model.train()
151
152 for step, (val_X, val_y) in enumerate(valid_loader):
153 if step > step_num:
154 break
155
156 val_X, val_y = val_X.to(device, non_blocking=True), val_y.to(device, non_blocking=True)
157 N = val_X.size(0)
158
159 optimizer.zero_grad()
160 logits_main, _ = model(val_X, super_flag=super_flag)
161 loss = criterion(logits_main, val_y)
162 loss.backward()
163
164 nn.utils.clip_grad_norm_(model.module.parameters(), config.w_grad_clip)
165 optimizer.step()
166
167 prec1, prec5 = utils.accuracy(logits_main, val_y, topk=(1, 5))
168 if config.distributed:
169 reduced_loss = utils.reduce_tensor(loss.data, config.world_size)
170 prec1 = utils.reduce_tensor(prec1, config.world_size)
171 prec5 = utils.reduce_tensor(prec5, config.world_size)
172
173 else:
174 reduced_loss = loss.data
175
176 losses.update(reduced_loss.item(), N)
177 top1.update(prec1.item(), N)
178 top5.update(prec5.item(), N)
179
180 torch.cuda.synchronize()
181 if config.local_rank == 0 and (step % config.print_freq == 0 or step == step_num):
182 logger.info(
183 "Warmup: Epoch {:2d}/{} Step {:03d}/{:03d} Loss {losses.avg:.3f} "
184 "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
185 epoch+1, retrain_epochs, step,
186 step_num, losses=losses, top1=top1, top5=top5))
187
188 if config.local_rank == 0:
189 writer.add_scalar('retrain/loss', reduced_loss.item(), cur_step)
190 writer.add_scalar('retrain/top1', prec1.item(), cur_step)

Callers 1

main (Function) · 0.90

Calls 8

update (Method) · 0.95
to (Method) · 0.80
format (Method) · 0.80
zero_grad (Method) · 0.80
train (Method) · 0.45
size (Method) · 0.45
backward (Method) · 0.45
step (Method) · 0.45

Tested by

no test coverage detected