(
train_dataloader,
valid_dataloader,
model,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
)
| 189 | |
| 190 | |
def train(
    train_dataloader,
    valid_dataloader,
    model,
    gpu_cache_miss_rate_fn,
    cpu_cache_miss_rate_fn,
    device,
):
    """Train ``model`` with early stopping and return the best weights seen.

    Runs up to ``args.epochs`` epochs. Each epoch calls ``train_helper`` on
    ``train_dataloader`` and then ``evaluate`` on ``valid_dataloader``. A
    snapshot of the weights is kept whenever the validation accuracy
    strictly improves on the best value so far, and training stops once
    more than ``args.early_stopping_patience`` epochs have passed since the
    last improvement.

    Hyper-parameters (``lr``, ``epochs``, ``early_stopping_patience``) are
    read from the module-level ``args`` object.

    Args:
        train_dataloader: Passed through to ``train_helper``.
        valid_dataloader: Passed through to ``evaluate``.
        model: Module to optimize; must expose ``parameters()`` and
            ``state_dict()``.
        gpu_cache_miss_rate_fn: Forwarded unchanged to ``train_helper`` and
            ``evaluate``.
        cpu_cache_miss_rate_fn: Forwarded unchanged to ``train_helper`` and
            ``evaluate``.
        device: Forwarded unchanged to ``train_helper`` and ``evaluate``.

    Returns:
        A deep copy of ``model.state_dict()`` taken at the epoch with the
        best validation accuracy. If no epoch ever improved on the initial
        best of 0 (for example ``args.epochs`` is 0, or validation accuracy
        stayed at 0), a deep copy of the model's current weights is returned
        instead of ``None``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    loss_fn = nn.CrossEntropyLoss()

    best_model = None
    best_model_acc = 0
    best_model_epoch = -1

    for epoch in range(args.epochs):
        train_loss, train_acc, duration = train_helper(
            train_dataloader,
            model,
            optimizer,
            loss_fn,
            gpu_cache_miss_rate_fn,
            cpu_cache_miss_rate_fn,
            device,
        )
        val_acc = evaluate(
            model,
            valid_dataloader,
            gpu_cache_miss_rate_fn,
            cpu_cache_miss_rate_fn,
            device,
        )
        if val_acc > best_model_acc:
            best_model_acc = val_acc
            # Deep copy: state_dict() hands back references to the live
            # parameters, which later optimizer steps would overwrite.
            best_model = deepcopy(model.state_dict())
            best_model_epoch = epoch
        # The metrics are formatted via .item(), so they are presumably
        # scalar tensors -- confirm against train_helper / evaluate.
        print(
            f"Epoch {epoch:02d}, Loss: {train_loss.item():.4f}, "
            f"Approx. Train: {train_acc.item():.4f}, "
            f"Approx. Val: {val_acc.item():.4f}, "
            f"Time: {duration}s"
        )
        # Early stopping: give up once the patience window has elapsed
        # without a new best validation accuracy.
        if best_model_epoch + args.early_stopping_patience < epoch:
            break

    if best_model is None:
        # No epoch was recorded as best (zero epochs, or validation accuracy
        # never exceeded 0). Return the current weights so callers always
        # receive a loadable state dict rather than None.
        best_model = deepcopy(model.state_dict())
    return best_model
| 236 | |
| 237 | |
| 238 | @torch.no_grad() |
no test coverage detected