Repository: github.com/pytorch/examples

Function train

vae/main.py:92–110
(epoch)

Source from the content-addressed store, hash-verified

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))
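
Both the progress line and the epoch summary in train divide a loss by a sample count. That only yields a per-sample loss if loss_function returns a value summed over the batch, which is an assumption here because loss_function is not shown on this page. A minimal sketch of the bookkeeping alone, with no PyTorch and with made-up numbers (batch_sizes, batch_losses, dataset_size, num_batches and log_interval are hypothetical stand-ins, not names from main.py):

```python
# Stand-in numbers, not real training output. Each loss is ASSUMED to be
# summed over the samples in its batch (as if loss_function used a sum).
batch_sizes = [4, 4, 2]              # last batch is smaller: 10 samples total
batch_losses = [400.0, 360.0, 170.0]

dataset_size = sum(batch_sizes)      # stands in for len(train_loader.dataset)
num_batches = len(batch_sizes)       # stands in for len(train_loader)
log_interval = 1                     # stands in for args.log_interval
epoch = 1

train_loss = 0.0
progress_lines = []
for batch_idx, (n, loss) in enumerate(zip(batch_sizes, batch_losses)):
    # Same accumulation as train(): add the summed batch loss.
    train_loss += loss
    if batch_idx % log_interval == 0:
        # Same format string and arguments as the progress print in train(),
        # with n in place of len(data) and loss in place of loss.item().
        progress_lines.append(
            'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * n, dataset_size,
                100. * batch_idx / num_batches,
                loss / n))

# Same normalization as the epoch summary: total summed loss per sample.
average_loss = train_loss / dataset_size

for line in progress_lines:
    print(line)
print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, average_loss))
```

Under these assumptions the summary is a true per-sample average (930 / 10). The sample counter in the progress line is batch_idx * len(data), so it uses the size of the current batch: in this sketch the last line reports 4/10 rather than 8/10 because the final batch holds only 2 samples.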

Callers 1

main.py · File · 0.70

Calls 2

loss_function · Function · 0.85
train · Method · 0.45

Tested by

no test coverage detected