| 65 | |
| 66 | |
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Puts ``model`` in training mode. For every batch it moves the data and
    targets to ``device``, computes the negative log-likelihood loss,
    back-propagates and takes an optimizer step. A progress line is printed
    every ``args.log_interval`` batches.

    Args:
        args: Namespace-like object. Only ``args.log_interval`` (the number of
            batches between progress lines) is read here.
        model: Callable module with a ``.train()`` method. Its output is fed
            to ``F.nll_loss``, so it is expected to produce log-probabilities
            (e.g. via ``log_softmax``).
        device: Target device that each batch's tensors are moved to.
        train_loader: Iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes a sized ``.dataset``.
        optimizer: Object with ``zero_grad()`` / ``step()``; presumably built
            over ``model.parameters()`` -- confirm at the call site.
        epoch: Epoch number; used only in the log message.
    """
    model.train()
    # Number of samples consumed *before* the current batch. Summing the
    # actual batch sizes keeps the progress figure correct when batches are
    # uneven (e.g. a smaller final batch); ``batch_idx * len(data)`` is not.
    samples_seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, samples_seen, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        samples_seen += len(data)
| 80 | |
| 81 | |
| 82 | def test(model, device, test_loader): |