(dataloader, model, model_path=None)
| 66 | |
| 67 | |
def test(dataloader, model, model_path=None):
    """Evaluate ``model`` on every batch of ``dataloader`` and average the stats.

    Args:
        dataloader: Iterable that supports ``len()``. Each item unpacks into
            ``(fbank, seq_lens, tokens)``, and each of those must provide
            ``.cuda()`` (presumably torch tensors; a CUDA device is required).
        model: The network, optionally wrapped so that the real module is
            reachable as ``model.module``. Calling it returns a loss for which
            ``.item()`` is valid. After each call, an ``acc`` attribute is read
            from the unwrapped module (presumably set during forward — confirm
            against the model implementation).
        model_path: Optional checkpoint path. When truthy, it is passed to
            ``torch_load`` together with ``model`` before evaluation starts.

    Returns:
        The result of ``dict_average`` applied to a mapping with key
        ``"loss_lst"`` (one loss per batch). The key ``"acc_lst"`` is also
        present when at least one batch reported a non-None accuracy.

    Note:
        ``model`` is left in eval mode. This function never calls
        ``model.train()``.
    """
    if model_path:
        torch_load(model_path, model)
    model.eval()

    stats = collections.defaultdict(list)
    for step, batch in enumerate(dataloader, start=1):
        logging.warning(f"Testing batch: {step}/{len(dataloader)}")

        fbank, seq_lens, tokens = batch
        fbank = fbank.cuda()
        seq_lens = seq_lens.cuda()
        tokens = tokens.cuda()

        # Gradients are not needed for evaluation.
        with torch.no_grad():
            loss = model(fbank, seq_lens, tokens)
        stats["loss_lst"].append(loss.item())

        # A wrapper keeps the real network under `.module`, and the `acc`
        # attribute is read from there. It is reset to None after being
        # collected, presumably so a value from one batch is not re-counted
        # on the next.
        owner = getattr(model, "module", model)
        if owner.acc is not None:
            stats["acc_lst"].append(owner.acc)
            owner.acc = None

    return dict_average(stats)
| 89 | |
| 90 | def train(dataloaders, model, optimizer, save_path): |
| 91 | train_loader, val_loader, test_loader = dataloaders |
no test coverage detected