MCP · copy
hub / github.com/jindongwang/transferlearning / train

Function train

code/deep/DaNN/main.py:30–66  ·  view source on GitHub ↗
(model, optimizer, epoch, data_src, data_tar)

Source from the content-addressed store, hash-verified

28
29
def train(model, optimizer, epoch, data_src, data_tar):
    """Train ``model`` for one epoch of DaNN (MMD-based) domain adaptation.

    Each labelled source batch is paired with one target batch. The loss is
    the source cross-entropy plus ``LAMBDA`` times ``mmd_loss`` between the
    two feature batches returned by the model. Target batches are replayed
    cyclically (in the same order) when the target loader is shorter than
    the source loader.

    Args:
        model: network called as ``model(x_src, x_tar)`` and returning
            ``(y_src, x_src_mmd, x_tar_mmd)``.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: current epoch number (used only for logging).
        data_src: iterable of ``(data, target)`` source batches supporting
            ``len()`` and exposing ``.dataset`` (presumably a ``DataLoader``).
        data_tar: iterable of ``(x, y)`` target batches; ``y`` is ignored.

    Returns:
        The same ``model`` object, updated in place.

    Side effects:
        Writes a summary line through ``tqdm.write`` and ``log_train`` and
        appends ``[epoch, mean_loss, accuracy_percent]`` (plain Python
        numbers) to ``RESULT_TRAIN``.

    Raises:
        ValueError: if ``data_tar`` yields no batches while source batches
            remain.
    """
    from itertools import cycle

    criterion = nn.CrossEntropyLoss()
    total_loss_train = 0.0
    correct = 0

    # cycle() caches target batches lazily and replays them in the same
    # order, matching the old wrap-around index without an eager copy.
    tar_batches = cycle(data_tar)

    # Loop-invariant: set training mode once rather than on every batch.
    model.train()
    for data, target in data_src:
        batch_tar = next(tar_batches, None)
        if batch_tar is None:
            raise ValueError('data_tar yielded no batches')
        x_tar, _ = batch_tar  # target labels are not used by the MMD loss

        # Flatten images to 28*28 = 784 features (28x28 inputs are assumed,
        # as hard-coded here).
        data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
        x_tar = x_tar.view(-1, 28 * 28).to(DEVICE)

        y_src, x_src_mmd, x_tar_mmd = model(data, x_tar)

        loss_c = criterion(y_src, target)
        loss_mmd = mmd_loss(x_src_mmd, x_tar_mmd)
        loss = loss_c + LAMBDA * loss_mmd

        # .item() converts the counts to Python ints, so the log shows
        # "59000" rather than "tensor(59000)".
        pred = y_src.argmax(dim=1)
        correct += pred.eq(target.view_as(pred)).sum().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss_train += loss.item()

    total_loss_train /= len(data_src)
    acc = correct * 100. / len(data_src.dataset)
    res_e = 'Epoch: [{}/{}], training loss: {:.6f}, correct: [{}/{}], training accuracy: {:.4f}%'.format(
        epoch, N_EPOCH, total_loss_train, correct, len(data_src.dataset), acc
    )
    tqdm.write(res_e)
    log_train.write(res_e + '\n')
    RESULT_TRAIN.append([epoch, total_loss_train, acc])
    return model
67
68
69def test(model, data_tar, e):

Callers 1

main.py · File · 0.70

Calls 6

sum · Method · 0.80
step · Method · 0.80
mmd_loss · Function · 0.70
train · Method · 0.45
backward · Method · 0.45
write · Method · 0.45

Tested by

no test coverage detected