hub / github.com/microsoft/AI-System / train

Function train

Labs/BasicLabs/Lab5/pytorch_mnist_basic.py:37–51
(args, model, device, train_loader, optimizer, epoch)

Source from the content-addressed store, hash-verified

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break
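
The logging and early-exit behavior of train can be explored without PyTorch. The sketch below is illustrative only: log_line and logged_batches are hypothetical helpers, not part of the repository. They reuse the format string and the log_interval / dry_run checks from the source above, and the numbers (938 batches, batch size 64, 60,000 samples, loss 0.25) are assumed values chosen for the example. Because batch 0 always satisfies `batch_idx % args.log_interval == 0`, a dry run stops after the first batch.

```python
from types import SimpleNamespace


def log_line(epoch, batch_idx, batch_len, dataset_len, num_batches, loss):
    """Format one progress line with the same format string train() uses."""
    return 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * batch_len, dataset_len,
        100. * batch_idx / num_batches, loss)


def logged_batches(num_batches, args):
    """Return the batch indices at which train() would print a log line.

    Mirrors only the control flow of the loop; no model or data is involved.
    """
    logged = []
    for batch_idx in range(num_batches):
        if batch_idx % args.log_interval == 0:
            logged.append(batch_idx)
            if args.dry_run:
                break  # dry run: stop after the first logged batch
    return logged


# Assumed settings: 938 batches is 60,000 samples at batch size 64, rounded up.
args = SimpleNamespace(log_interval=300, dry_run=False)
print(logged_batches(938, args))
print(log_line(epoch=1, batch_idx=300, batch_len=64,
               dataset_len=60000, num_batches=938, loss=0.25))
```

With these assumed values the first print shows [0, 300, 600, 900], and with dry_run=True it would show [0]. The second prints "Train Epoch: 1 [19200/60000 (32%)]", a tab, then "Loss: 0.250000". Note that the sample count is batch_idx * len(data), so it counts samples seen before the current batch, while the percentage is computed from batch counts rather than sample counts.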

Callers 1

main · Function · 0.70

Calls 1

backward · Method · 0.45

Tested by

no test coverage detected