Repository: github.com/microsoft/AI-System

Function train

Defined in Labs/BasicLabs/Lab1/mnist_basic.py, lines 67–79.
(args, model, device, train_loader, optimizer, epoch)
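Of the six parameters, `args` contributes only one field in the listing on this page: `log_interval`, which controls how often a progress line is printed. The following is a minimal sketch that assumes an `argparse.Namespace` stands in for the parsed command-line arguments and that an epoch has 938 batches (MNIST's 60,000 training images at a hypothetical batch size of 64, with the final partial batch kept, as `DataLoader` does by default). It shows which batch indices would trigger a log line:

```python
import argparse
import math

# Hypothetical stand-in for the parsed CLI arguments; train() reads only log_interval.
args = argparse.Namespace(log_interval=10)

num_samples = 60000   # size of the MNIST training set
batch_size = 64       # assumed batch size, for illustration only
# DataLoader keeps the last partial batch by default (drop_last=False), hence ceil.
num_batches = math.ceil(num_samples / batch_size)

# Same condition as in train(): log when batch_idx is a multiple of log_interval.
logged = [i for i in range(num_batches) if i % args.log_interval == 0]

print(num_batches)             # 938
print(len(logged))             # 94
print(logged[:3], logged[-1])  # [0, 10, 20] 930
```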


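```python
import torch.nn.functional as F  # needed for F.nll_loss (conventional alias)


def train(args, model, device, train_loader, optimizer, epoch):
    # Switch to training mode (enables dropout / batch-norm updates, if the model has them).
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Move the current batch to the target device (CPU or GPU).
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()              # clear gradients left from the previous step
        output = model(data)               # forward pass; nll_loss expects log-probabilities
        loss = F.nll_loss(output, target)  # negative log-likelihood, averaged over the batch
        loss.backward()                    # backpropagate: populate .grad on each parameter
        optimizer.step()                   # update the parameters using their gradients
        # Print progress every args.log_interval batches.
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
```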
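`F.nll_loss` expects `output` to hold log-probabilities (typically produced by a `log_softmax` at the end of the model's `forward`). With its default `reduction='mean'`, it returns the average over the batch of the negated log-probability assigned to each sample's target class. The pure-Python sketch below reproduces that default computation on a hypothetical batch of two samples and three classes. It does not model the `weight` or `ignore_index` options, and it is an illustration, not PyTorch's implementation:

```python
import math


def log_softmax(row):
    # Numerically stable log-softmax over one row of raw scores.
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]


def nll_loss_mean(log_probs, targets):
    # Mean of -log p(target) across the batch (reduction='mean', no class weights).
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


# Hypothetical batch: 2 samples, 3 classes of raw scores, and their target labels.
logits = [[2.0, 0.5, -1.0],
          [0.1, 0.2, 3.0]]
targets = [0, 2]

log_probs = [log_softmax(row) for row in logits]
loss = nll_loss_mean(log_probs, targets)
print(round(loss, 4))
```

As a sanity check, uniform scores over three classes give a loss of ln 3 for any target, since every class then has probability 1/3.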

Callers (1): main (function) · 0.70
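The `epoch` shown in each progress line is supplied by the caller (`main`, per the call graph). The sketch below reuses the format string from `train` with plain numbers to show what one printed line looks like. The epoch, batch index, and loss value are hypothetical, and the dataset and loader sizes assume MNIST's 60,000 training images in batches of 64:

```python
# Format string copied from train(); the inputs are hypothetical stand-ins for
# epoch (from the caller), batch_idx, len(data), len(train_loader.dataset),
# len(train_loader), and loss.item().
epoch = 1
batch_idx = 10
batch_len = 64        # len(data) for a full batch
dataset_len = 60000   # len(train_loader.dataset)
num_batches = 938     # len(train_loader)
loss_value = 0.5

line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
    epoch, batch_idx * batch_len, dataset_len,
    100. * batch_idx / num_batches, loss_value)
print(repr(line))  # 'Train Epoch: 1 [640/60000 (1%)]\tLoss: 0.500000'
```

The sample count `batch_idx * len(data)` is exact only while every batch has the same size. A final partial batch would make it undercount, although with these numbers the last index (937) is not a multiple of 10 and is never logged.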

Calls (1): backward (method) · 0.45
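Each iteration of `train` follows PyTorch's usual order: clear old gradients with `optimizer.zero_grad()`, run the forward pass and compute the loss, call `loss.backward()` to populate each parameter's `.grad`, then apply `optimizer.step()`. The clearing step matters because `backward()` accumulates into `.grad` rather than overwriting it. The sketch below is a pure-Python analogue, not PyTorch code. It assumes a hypothetical one-parameter model `y = w * x` with squared-error loss and plain SGD, folds the forward pass into `backward`, and uses a `grad` field that accumulates in the same way:

```python
class Param:
    """Stand-in for a tensor parameter whose .grad accumulates, as in autograd."""
    def __init__(self, value):
        self.value = value
        self.grad = 0.0


def backward(w, x, y):
    # Gradient of (w*x - y)**2 with respect to w, ADDED to w.grad
    # (accumulated, not overwritten), mimicking loss.backward().
    w.grad += 2.0 * (w.value * x - y) * x


def zero_grad(w):
    w.grad = 0.0  # analogue of optimizer.zero_grad()


def sgd_step(w, lr):
    w.value -= lr * w.grad  # analogue of optimizer.step() for plain SGD


# Accumulation: two backward calls without zeroing double the gradient.
w = Param(1.0)
backward(w, 1.0, 2.0)
backward(w, 1.0, 2.0)
print(w.grad)  # -4.0

# The per-batch order used in train(): zero_grad -> forward/loss/backward -> step.
data = [(1.0, 2.0), (2.0, 4.0)]  # hypothetical samples from y = 2x
w = Param(0.0)
for epoch in range(50):
    for x, y in data:
        zero_grad(w)
        backward(w, x, y)
        sgd_step(w, lr=0.05)
print(round(w.value, 4))  # 2.0
```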

Tested by: no test coverage detected