MCPcopy
hub / github.com/jindongwang/transferlearning / test

Function test

code/ASR/CMatch/train.py:68–88  ·  view source on GitHub ↗
(dataloader, model, model_path=None)

Source from the content-addressed store, hash-verified

66
67
def test(dataloader, model, model_path=None, device="cuda"):
    """Evaluate ``model`` on every batch of ``dataloader`` and return averaged stats.

    Args:
        dataloader: iterable of ``(fbank, seq_lens, tokens)`` batches that
            also supports ``len()`` (used for progress logging).
        model: the network, optionally wrapped so that the real module is
            reachable as ``model.module`` (e.g. a ``DataParallel`` wrapper).
            Calling it with ``(fbank, seq_lens, tokens)`` must return a loss
            tensor supporting ``.item()``. The (inner) module is expected to
            expose an ``acc`` attribute, which is read and then reset to
            ``None`` after each batch.
        model_path: optional checkpoint path; when truthy it is loaded into
            ``model`` via ``torch_load`` before evaluation.
        device: device the batch tensors are moved to. The default
            ``"cuda"`` matches the previous hard-coded ``.cuda()`` calls.

    Returns:
        Whatever ``dict_average`` returns for the collected ``"loss_lst"``
        (and, if any accuracies were reported, ``"acc_lst"``) lists.
        Presumably this is the per-key mean; confirm in ``dict_average``.
    """
    if model_path:
        torch_load(model_path, model)
    # Wrappers such as DataParallel keep the actual network under `.module`;
    # resolve it once instead of branching on every batch.
    inner = getattr(model, "module", model)
    # Remember the caller's mode so evaluation does not leave the model
    # stuck in eval mode (dropout/batch-norm off) for subsequent training.
    was_training = model.training
    model.eval()
    stats = collections.defaultdict(list)
    num_batches = len(dataloader)
    try:
        with torch.no_grad():
            for batch_idx, (fbank, seq_lens, tokens) in enumerate(dataloader):
                logging.warning(f"Testing batch: {batch_idx+1}/{num_batches}")
                fbank = fbank.to(device)
                seq_lens = seq_lens.to(device)
                tokens = tokens.to(device)
                loss = model(fbank, seq_lens, tokens)
                stats["loss_lst"].append(loss.item())
                # Accuracy is passed back through a side attribute; consume
                # it and clear it so a stale value is never counted twice.
                if inner.acc is not None:
                    stats["acc_lst"].append(inner.acc)
                    inner.acc = None
    finally:
        model.train(was_training)
    return dict_average(stats)
89
90def train(dataloaders, model, optimizer, save_path):
91 train_loader, val_loader, test_loader = dataloaders

Callers 1

train (function) · 0.70

Calls 2

torch_load (function) · 0.90
dict_average (function) · 0.90

Tested by

no test coverage detected