()
| 278 | |
| 279 | |
def main():
    """Run the full train/validate loop and report the best validation accuracy.

    Steps, in order:
      1. Parse options and build the data loaders, model, criterion and
         optimizer through the file's ``parse_option`` / ``set_*`` helpers.
      2. For each epoch in ``1..opt.epochs``: adjust the learning rate, train
         for one epoch (timed and printed), validate, and send the
         train/val loss and accuracy plus the current learning rate to the
         tensorboard logger.
      3. Write ``ckpt_epoch_<epoch>.pth`` into ``opt.save_folder`` whenever
         the epoch number is a multiple of ``opt.save_freq``.
      4. After the last epoch, write ``last.pth`` and print the highest
         validation accuracy observed.
    """
    opt = parse_option()

    # data, model/criterion and optimizer all come from sibling helpers
    train_loader, val_loader = set_loader(opt)
    model, criterion = set_model(opt)
    optimizer = set_optimizer(opt, model)

    # tensorboard logger writing into opt.tb_folder
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    best_acc = 0
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # one training epoch, wall-clock timed
        epoch_start = time.time()
        train_loss, train_acc = train(
            train_loader, model, criterion, optimizer, epoch, opt)
        epoch_end = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, epoch_end - epoch_start))

        # training metrics; the learning rate is read from the first param group
        train_metrics = (
            ('train_loss', train_loss),
            ('train_acc', train_acc),
            ('learning_rate', optimizer.param_groups[0]['lr']),
        )
        for tag, value in train_metrics:
            logger.log_value(tag, value, epoch)

        # validation metrics
        val_loss, val_acc = validate(val_loader, model, criterion, opt)
        for tag, value in (('val_loss', val_loss), ('val_acc', val_acc)):
            logger.log_value(tag, value, epoch)

        # remember the best validation accuracy seen so far
        if val_acc > best_acc:
            best_acc = val_acc

        # periodic checkpoint every opt.save_freq epochs
        if epoch % opt.save_freq == 0:
            ckpt_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            ckpt_path = os.path.join(opt.save_folder, ckpt_name)
            save_model(model, optimizer, opt, epoch, ckpt_path)

    # final checkpoint, saved once after the training loop finishes
    last_path = os.path.join(opt.save_folder, 'last.pth')
    save_model(model, optimizer, opt, opt.epochs, last_path)

    print('best accuracy: {:.2f}'.format(best_acc))
| 330 | |
| 331 | |
| 332 | if __name__ == '__main__': |
no test coverage detected