MCP · Copy · Index your code
hub / github.com/patrickloeber/pytorchTutorial / train_model

Function train_model

15_transfer_learning.py:63–128  ·  view source on GitHub ↗
(model, criterion, optimizer, scheduler, num_epochs=25)

Source retrieved from the content-addressed store (hash-verified).

61imshow(out, title=[class_names[x] for x in classes])
62
63def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
64 since = time.time()
65
66 best_model_wts = copy.deepcopy(model.state_dict())
67 best_acc = 0.0
68
69 for epoch in range(num_epochs):
70 print('Epoch {}/{}'.format(epoch, num_epochs - 1))
71 print('-' * 10)
72
73 # Each epoch has a training and validation phase
74 for phase in ['train', 'val']:
75 if phase == 'train':
76 model.train() # Set model to training mode
77 else:
78 model.eval() # Set model to evaluate mode
79
80 running_loss = 0.0
81 running_corrects = 0
82
83 # Iterate over data.
84 for inputs, labels in dataloaders[phase]:
85 inputs = inputs.to(device)
86 labels = labels.to(device)
87
88 # forward
89 # track history if only in train
90 with torch.set_grad_enabled(phase == 'train'):
91 outputs = model(inputs)
92 _, preds = torch.max(outputs, 1)
93 loss = criterion(outputs, labels)
94
95 # backward + optimize only if in training phase
96 if phase == 'train':
97 optimizer.zero_grad()
98 loss.backward()
99 optimizer.step()
100
101 # statistics
102 running_loss += loss.item() * inputs.size(0)
103 running_corrects += torch.sum(preds == labels.data)
104
105 if phase == 'train':
106 scheduler.step()
107
108 epoch_loss = running_loss / dataset_sizes[phase]
109 epoch_acc = running_corrects.double() / dataset_sizes[phase]
110
111 print('{} Loss: {:.4f} Acc: {:.4f}'.format(
112 phase, epoch_loss, epoch_acc))
113
114 # deep copy the model
115 if phase == 'val' and epoch_acc > best_acc:
116 best_acc = epoch_acc
117 best_model_wts = copy.deepcopy(model.state_dict())
118
119 print()
120

Callers: 1

Calls

no outgoing calls

Tested by

no test coverage detected