(epoch)
| 90 | |
| 91 | |
def train(epoch):
    """Run one training epoch over ``train_loader`` and report progress.

    Uses module-level objects defined elsewhere in the file: ``model``,
    ``optimizer``, ``train_loader``, ``device``, ``loss_function`` and
    ``args`` (only ``args.log_interval`` is read here).

    Args:
        epoch: Epoch number; used only in the printed progress messages.

    Returns:
        The accumulated ``loss.item()`` values divided by the number of
        samples actually iterated this epoch (0.0 if the loader yielded no
        batches). NOTE(review): this is a per-sample average only if
        ``loss_function`` returns a loss summed over the batch — confirm.
    """
    model.train()
    train_loss = 0.0
    # Samples processed *before* the current batch. Counting actual batch
    # sizes keeps the log correct when the final batch is smaller, and makes
    # the epoch average correct for subset samplers / drop_last=True.
    num_samples = 0
    num_batches = len(train_loader)
    dataset_size = len(train_loader.dataset)
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        batch_size = len(data)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        batch_loss = loss.item()
        train_loss += batch_loss
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            # Progress is reported up to (not including) this batch, which
            # matches the percentage computed from batch_idx.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, num_samples, dataset_size,
                100. * batch_idx / num_batches,
                batch_loss / batch_size))
        num_samples += batch_size

    # Guard against an empty loader instead of raising ZeroDivisionError.
    avg_loss = train_loss / num_samples if num_samples else 0.0
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss
| 111 | |
| 112 | |
| 113 | def test(epoch): |
no test coverage detected