MCPcopy
hub / github.com/HobbitLong/SupContrast / main

Function main

main_ce.py:280–329  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

278
279
def main():
    """Train and evaluate a model for ``opt.epochs`` epochs.

    Builds the data loaders, model, criterion and optimizer from the parsed
    options, then for every epoch: adjusts the learning rate, trains, logs
    the training metrics, validates, logs the validation metrics and tracks
    the highest validation accuracy seen. A checkpoint is written every
    ``opt.save_freq`` epochs, a final ``last.pth`` checkpoint is written
    after the loop, and the best validation accuracy is printed.
    """
    opt = parse_option()

    # data loaders, model + criterion, optimizer
    train_loader, val_loader = set_loader(opt)
    model, criterion = set_model(opt)
    optimizer = set_optimizer(opt, model)

    # metric logger writing into opt.tb_folder
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    best_acc = 0
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # one training pass, timed with wall-clock time
        started = time.time()
        train_loss, train_acc = train(
            train_loader, model, criterion, optimizer, epoch, opt)
        finished = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, finished - started))

        # training metrics
        logger.log_value('train_loss', train_loss, epoch)
        logger.log_value('train_acc', train_acc, epoch)
        current_lr = optimizer.param_groups[0]['lr']
        logger.log_value('learning_rate', current_lr, epoch)

        # validation pass and its metrics
        val_loss, val_acc = validate(val_loader, model, criterion, opt)
        logger.log_value('val_loss', val_loss, epoch)
        logger.log_value('val_acc', val_acc, epoch)

        # max() keeps the current best unless val_acc is strictly greater
        best_acc = max(best_acc, val_acc)

        # periodic checkpoint
        if epoch % opt.save_freq == 0:
            ckpt_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            ckpt_path = os.path.join(opt.save_folder, ckpt_name)
            save_model(model, optimizer, opt, epoch, ckpt_path)

    # final checkpoint, always written once training has finished
    last_path = os.path.join(opt.save_folder, 'last.pth')
    save_model(model, optimizer, opt, opt.epochs, last_path)

    print('best accuracy: {:.2f}'.format(best_acc))
330
331
332if __name__ == '__main__':

Callers 1

main_ce.pyFile · 0.70

Calls 8

set_optimizerFunction · 0.90
adjust_learning_rateFunction · 0.90
save_modelFunction · 0.90
parse_optionFunction · 0.70
set_loaderFunction · 0.70
set_modelFunction · 0.70
trainFunction · 0.70
validateFunction · 0.70

Tested by

no test coverage detected