| 35 | |
| 36 | |
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch of ``model`` over ``train_loader``.

    Puts ``model`` in training mode, then for every ``(data, target)`` batch
    moves both tensors to ``device``, computes ``F.nll_loss`` on the model
    output, backpropagates and takes one ``optimizer`` step. A progress line
    is printed every ``args.log_interval`` batches.

    Args:
        args: Namespace-like object; this function reads
            ``args.log_interval`` (log every N batches) and ``args.dry_run``
            (stop after the first logged batch when truthy).
        model: Module whose output is fed to ``F.nll_loss``, so it is
            expected to return log-probabilities.
        device: Target passed to ``Tensor.to`` for inputs and targets.
        train_loader: Iterable of ``(data, target)`` batches exposing
            ``len(train_loader)`` and ``train_loader.dataset`` with a length
            (both are used only for the progress message).
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch number, used only in the progress message.

    Returns:
        None. ``model`` parameters are updated in place by ``optimizer``.
    """
    model.train()
    # Samples consumed before the current batch. Counting actual batch
    # sizes keeps the log correct when the final batch is smaller than the
    # rest (``batch_idx * len(data)`` under-reports in that case).
    seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                # Smoke-test mode: stop after the first logged batch.
                break
        seen += len(data)
| 52 | |
| 53 | |
| 54 | def test(model, device, test_loader): |