(dataloaders, model, optimizer, save_path)
| 270 | return dict_average(stats) |
| 271 | |
def train(dataloaders, model, optimizer, save_path):
    """Run the epoch loop: train, evaluate, checkpoint, log, and early-stop.

    For every epoch from ``args.start_epoch`` to ``args.epochs`` (inclusive)
    this trains one epoch, evaluates on the validation and test loaders,
    saves the model to ``save_path`` whenever the validation loss improves,
    writes a per-epoch snapshot, and rewrites the cumulative JSON log at
    ``{args.outdir}/log``.

    Args:
        dataloaders: A 3-item sequence unpacked as
            ``(train_loader, val_loader, test_loader)``.
        model: Model passed to ``train_epoch``, ``test`` and ``torch_save``.
        optimizer: Optimizer passed to ``train_epoch`` and stored in the
            per-epoch snapshot.
        save_path: Destination for the best-validation-loss model; also
            forwarded to ``test`` for the final "test_best" evaluation.

    Returns:
        None. Results are reported through logging and the files written
        under ``args.outdir``.

    Note:
        Reads the module-level ``args`` (``start_epoch``, ``epochs``,
        ``sim_adapter``, ``outdir``, ``patience``).
    """
    train_loader, val_loader, test_loader = dataloaders
    best_loss = float("inf")
    early_stop = 0  # epochs elapsed since the validation loss last improved
    log_json = []
    for epoch in range(args.start_epoch, args.epochs + 1):
        early_stop += 1
        epoch_stats = collections.OrderedDict(epoch=epoch)
        train_stats = train_epoch(train_loader, model, optimizer, epoch)
        valid_stats = test(
            f"val_{epoch}",
            val_loader,
            model,
            visualize_sim_adapter=args.sim_adapter,
        )
        if best_loss > valid_stats["loss"]:
            # New best validation loss: keep this model and reset patience.
            best_loss = valid_stats["loss"]
            torch_save(model, save_path)
            early_stop = 0

        test_stats = test(f"test_{epoch}", test_loader, model)
        logging.warning(
            f"Epoch: {epoch}, Iteration: {epoch * len(train_loader)}, "
            + f"train loss: {train_stats['loss']:.4f}, dev loss: {valid_stats['loss']:.3f}, test loss: {test_stats['loss']:.3f}"
        )

        torch_save(model, f"{args.outdir}/snapshot.ep.{epoch}", optimizer=optimizer)

        # Include validation keys in the union so that a metric reported only
        # by the validation pass is not silently dropped from the log.
        all_keys = set(train_stats) | set(valid_stats) | set(test_stats)
        for key in sorted(all_keys):
            if key.endswith("_lst"):
                # "*_lst" entries are deliberately kept out of the JSON log.
                continue
            if key in train_stats:
                epoch_stats[f"main/{key}"] = train_stats[key]
            if key in valid_stats:
                epoch_stats[f"validation/main/{key}"] = valid_stats[key]
            if key in test_stats:
                epoch_stats[f"test/main/{key}"] = test_stats[key]

        log_json.append(epoch_stats)
        # The whole history is rewritten each epoch. ensure_ascii=False emits
        # raw non-ASCII characters, so the file encoding is pinned to UTF-8
        # instead of depending on the platform default.
        with open(f"{args.outdir}/log", "w", encoding="utf-8") as f:
            json.dump(
                log_json,
                f,
                indent=4,
                ensure_ascii=False,
                separators=(",", ": "),
            )
        logging.warning(f"Log saved at {args.outdir}/log")

        if args.patience > 0 and early_stop >= args.patience:
            # NOTE(review): `model` holds the last-epoch weights here;
            # presumably `test` reloads the best checkpoint from `save_path`
            # (its 4th positional argument) -- confirm against `test`.
            test_stats = test("test_best", test_loader, model, save_path)
            logging.warning(f"=====Early stop! Final best test loss: {test_stats['loss']}")
            break
| 316 | |
| 317 | if __name__ == "__main__": |
| 318 | # Run this command for 4-GPU training: CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=2 train.py |
no test coverage detected