| 28 | |
| 29 | |
def train(model, optimizer, epoch, data_src, data_tar):
    """Run one training epoch over the source batches, paired with target batches.

    Every source batch is paired with one target batch; target batches are
    reused cyclically when the target loader yields fewer batches than the
    source loader. The optimised loss is the cross-entropy of the source
    predictions plus ``LAMBDA`` times ``mmd_loss`` of the two feature outputs
    returned by the model.

    A one-line epoch summary is written through ``tqdm.write`` and to
    ``log_train``, and ``[epoch, mean_loss, accuracy]`` is appended to
    ``RESULT_TRAIN`` as plain Python numbers.

    Args:
        model: Called as ``model(src_batch, tar_batch)``; must return
            ``(y_src, x_src_mmd, x_tar_mmd)``.
        optimizer: Optimizer stepped once per source batch.
        epoch: Epoch number, used only for logging and the result record.
        data_src: Iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes ``.dataset``.
        data_tar: Iterable of ``(data, target)`` batches; the target labels
            are not used during training.

    Returns:
        The same ``model`` object, after the epoch's updates.

    Raises:
        ValueError: If ``data_tar`` yields no batches.
    """
    # Materialise the target batches once so they can be indexed cyclically.
    # (The source loader is simply iterated; it does not need to be cached.)
    target_batches = list(data_tar)
    if not target_batches:
        raise ValueError('data_tar yielded no batches')
    n_target_batches = len(target_batches)

    criterion = nn.CrossEntropyLoss()
    total_loss_train = 0.0
    correct = 0

    # Training mode is loop-invariant, so it is set once up front.
    model.train()
    for batch_id, (data, target) in enumerate(data_src):
        # Wrap around when the target loader is shorter than the source one.
        x_tar, _ = target_batches[batch_id % n_target_batches]

        # Flatten each batch to rows of 28 * 28 values before the forward pass.
        data = data.view(-1, 28 * 28).to(DEVICE)
        target = target.to(DEVICE)
        x_tar = x_tar.view(-1, 28 * 28).to(DEVICE)

        y_src, x_src_mmd, x_tar_mmd = model(data, x_tar)
        loss_c = criterion(y_src, target)
        loss_mmd = mmd_loss(x_src_mmd, x_tar_mmd)
        loss = loss_c + LAMBDA * loss_mmd

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Bookkeeping uses detached values converted to Python numbers so no
        # tensors are carried across iterations or stored in RESULT_TRAIN.
        pred = y_src.detach().max(1)[1]  # index of the highest score per row
        correct += pred.eq(target.view_as(pred)).sum().item()
        total_loss_train += loss.item()

    total_loss_train /= len(data_src)
    acc = correct * 100. / len(data_src.dataset)
    res_e = 'Epoch: [{}/{}], training loss: {:.6f}, correct: [{}/{}], training accuracy: {:.4f}%'.format(
        epoch, N_EPOCH, total_loss_train, correct, len(data_src.dataset), acc
    )
    tqdm.write(res_e)
    log_train.write(res_e + '\n')
    RESULT_TRAIN.append([epoch, total_loss_train, acc])
    return model
| 67 | |
| 68 | |
| 69 | def test(model, data_tar, e): |