()
| 253 | |
| 254 | |
def main():
    """Build the training pipeline from command-line options and run it.

    Steps:
      1. Parse options with ``parse_option()``.
      2. Build the training data loader, the model and criterion, and the
         optimizer from those options.
      3. For each epoch ``1..opt.epochs``:
         - adjust the learning rate;
         - train one epoch;
         - print the elapsed time;
         - log two values to a TensorBoard logger writing to ``opt.tb_folder``:
           the value returned by ``train()``, as ``'loss'``, and the learning
           rate of the optimizer's first parameter group.
      4. Every ``opt.save_freq`` epochs, save a checkpoint named
         ``ckpt_epoch_<epoch>.pth`` in ``opt.save_folder``.
      5. After the loop, save the final state as ``last.pth`` in
         ``opt.save_folder``.
    """
    opt = parse_option()

    # build data loader
    train_loader = set_loader(opt)

    # build model and criterion
    model, criterion = set_model(opt)

    # build optimizer
    optimizer = set_optimizer(opt, model)

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # training routine; epochs are numbered from 1, and that same number is
    # used for the LR schedule, the logged step and the checkpoint file name
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch; perf_counter is monotonic, so the reported
        # duration is not skewed by system clock adjustments (time.time is)
        start = time.perf_counter()
        loss = train(train_loader, model, criterion, optimizer, epoch, opt)
        elapsed = time.perf_counter() - start
        print('epoch {}, total time {:.2f}'.format(epoch, elapsed))

        # tensorboard logger; the learning rate is read after train() returns,
        # i.e. whatever value param group 0 holds at the end of this epoch
        logger.log_value('loss', loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if epoch % opt.save_freq == 0:
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, opt, epoch, save_file)

    # save the last model under a stable name; when opt.epochs is a multiple
    # of opt.save_freq this duplicates the final periodic checkpoint
    save_file = os.path.join(opt.save_folder, 'last.pth')
    save_model(model, optimizer, opt, opt.epochs, save_file)
| 294 | |
| 295 | if __name__ == '__main__': |
no test coverage detected