MCPcopy Index your code
hub / github.com/jindongwang/transferlearning / train

Function train

code/ASR/Adapter/train.py:272–315  ·  view source on GitHub ↗
(dataloaders, model, optimizer, save_path)

Source from the content-addressed store, hash-verified

270 return dict_average(stats)
271
def train(dataloaders, model, optimizer, save_path):
    """Train ``model`` epoch by epoch with validation-based checkpointing and early stopping.

    Configuration is read from the module-level ``args`` namespace
    (``start_epoch``, ``epochs``, ``sim_adapter``, ``outdir``, ``patience``).

    For every epoch in ``[args.start_epoch, args.epochs]`` this:

    1. runs one training pass via ``train_epoch``;
    2. evaluates on the validation loader and, when the validation loss is a
       new minimum, saves the model to ``save_path`` and resets the
       early-stopping counter;
    3. evaluates on the test loader and logs train/dev/test losses;
    4. saves a snapshot (model + optimizer) to
       ``{args.outdir}/snapshot.ep.{epoch}``;
    5. appends the epoch's metrics to a list and rewrites the whole list as
       JSON to ``{args.outdir}/log``;
    6. stops early when ``args.patience > 0`` and the validation loss has not
       improved for ``args.patience`` consecutive epochs.

    Args:
        dataloaders: ``(train_loader, val_loader, test_loader)`` tuple.
        model: model being trained; passed to ``train_epoch``, ``test`` and
            ``torch_save``.
        optimizer: optimizer passed to ``train_epoch`` and stored in the
            per-epoch snapshots.
        save_path: destination of the best (lowest validation loss) checkpoint.

    Returns:
        None.
    """
    train_loader, val_loader, test_loader = dataloaders
    best_loss = float("inf")
    early_stop = 0  # epochs elapsed since the validation loss last improved
    log_json = []
    log_path = f"{args.outdir}/log"

    for epoch in range(args.start_epoch, args.epochs + 1):
        early_stop += 1
        epoch_stats = collections.OrderedDict(epoch=epoch)

        train_stats = train_epoch(train_loader, model, optimizer, epoch)
        valid_stats = test(
            f"val_{epoch}", val_loader, model, visualize_sim_adapter=args.sim_adapter
        )
        # Model selection uses only the validation loss; the test split is
        # evaluated for reporting but never drives checkpointing.
        if best_loss > valid_stats["loss"]:
            best_loss = valid_stats["loss"]
            torch_save(model, save_path)
            early_stop = 0

        test_stats = test(f"test_{epoch}", test_loader, model)
        logging.warning(
            f"Epoch: {epoch}, Iteration: {epoch * len(train_loader)}, "
            + f"train loss: {train_stats['loss']:.4f}, dev loss: {valid_stats['loss']:.3f}, test loss: {test_stats['loss']:.3f}"
        )

        torch_save(model, f"{args.outdir}/snapshot.ep.{epoch}", optimizer=optimizer)

        # Union of the keys from all three splits, so a metric reported only
        # by validation is still logged under "validation/main/...".
        for key in sorted(set(train_stats) | set(valid_stats) | set(test_stats)):
            if key.endswith("_lst"):
                # "_lst" keys are excluded from the JSON log (presumably
                # list-valued per-sample data -- TODO confirm in `test`).
                continue
            if key in train_stats:
                epoch_stats[f"main/{key}"] = train_stats[key]
            if key in valid_stats:
                epoch_stats[f"validation/main/{key}"] = valid_stats[key]
            if key in test_stats:
                epoch_stats[f"test/main/{key}"] = test_stats[key]

        log_json.append(epoch_stats)
        # The whole log is rewritten every epoch, so the file on disk always
        # holds every completed epoch. The encoding is explicit because
        # ensure_ascii=False may emit non-ASCII text that a platform default
        # encoding could fail to encode.
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(
                log_json,
                f,
                indent=4,
                ensure_ascii=False,
                separators=(",", ": "),
            )
        logging.warning(f"Log saved at {log_path}")

        if args.patience > 0 and early_stop >= args.patience:
            # NOTE(review): save_path is passed as the 4th positional argument;
            # presumably `test` reloads the best checkpoint from it before
            # evaluating -- confirm against `test`'s signature.
            test_stats = test("test_best", test_loader, model, save_path)
            logging.warning(f"=====Early stop! Final best test loss: {test_stats['loss']}")
            break
    # NOTE(review): if the loop finishes without early stopping (including
    # when args.patience <= 0), no "test_best" evaluation is performed.
316
317if __name__ == "__main__":
318 # 执行该命令运行4 GPU训练:CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=2 train.py

Callers 1

train.py · File · 0.70

Calls 3

torch_save · Function · 0.90
train_epoch · Function · 0.70
test · Function · 0.70

Tested by

no test coverage detected