()
| 227 | |
| 228 | |
def main():
    """Train and validate for ``opt.epochs`` epochs, then print the best validation accuracy.

    Configuration comes from ``parse_option()``. The data loaders, the
    model/classifier/criterion triple and the optimizer are built by the
    ``set_*`` helpers. Each epoch then runs one training pass followed by
    one validation pass. The highest validation accuracy observed across
    all epochs (tracked from an initial value of 0) is printed once the
    loop finishes.

    NOTE(review): the optimizer is constructed from ``classifier`` only
    (``model`` is not passed to ``set_optimizer``), so presumably ``model``
    is meant to stay frozen in this routine — confirm against
    ``set_optimizer`` / ``train``.
    """
    best_acc = 0
    opt = parse_option()

    # build data loader
    train_loader, val_loader = set_loader(opt)

    # build model and criterion
    model, classifier, criterion = set_model(opt)

    # build optimizer (from the classifier alone; see the docstring note)
    optimizer = set_optimizer(opt, classifier)

    # training routine: epochs are numbered 1..opt.epochs inclusive
    for epoch in range(1, opt.epochs + 1):
        # presumably applies the LR schedule for this epoch before training
        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch; the returned loss is not used here
        time1 = time.time()
        _, acc = train(train_loader, model, classifier, criterion,
                       optimizer, epoch, opt)
        time2 = time.time()
        print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
            epoch, time2 - time1, acc))

        # eval for one epoch; only the accuracy feeds best-result tracking
        _, val_acc = validate(val_loader, model, classifier, criterion, opt)
        if val_acc > best_acc:
            best_acc = val_acc

    print('best accuracy: {:.2f}'.format(best_acc))
| 260 | |
| 261 | |
| 262 | if __name__ == '__main__': |
no test coverage detected