(
model,
dataloader,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
)
| 278 | |
@torch.no_grad()
def evaluate(
    model,
    dataloader,
    gpu_cache_miss_rate_fn,
    cpu_cache_miss_rate_fn,
    device,
):
    """Compute the accuracy of ``model`` over every minibatch in ``dataloader``.

    The whole pass runs with gradients disabled, and ``model`` is put into
    eval mode first. Each minibatch is scored by ``evaluate_step``, which
    returns a ``(num_correct, num_samples)`` pair; the pairs are summed.
    On every 25th step (starting with step 0), the progress bar's postfix is
    refreshed with the minibatch's node count and the current GPU/CPU
    cache-miss rates.

    Args:
        model: Model to evaluate. ``model.eval()`` is called on it, and it
            is NOT switched back to training mode afterwards.
        dataloader: Iterable of minibatches. Minibatches are expected to
            expose ``node_ids()`` (queried on logging steps).
        gpu_cache_miss_rate_fn: Zero-argument callable; its result is shown
            as ``gpu_cache_miss`` in the progress bar.
        cpu_cache_miss_rate_fn: Zero-argument callable; its result is shown
            as ``cpu_cache_miss`` in the progress bar.
        device: Device on which the float64 correct-count accumulator is
            allocated.

    Returns:
        A one-element float64 tensor on ``device``, equal to total correct
        divided by total samples.
        NOTE(review): with an empty ``dataloader`` this is 0/0, presumably
        yielding NaN rather than raising. Confirm callers never pass an
        empty loader.
    """
    model.eval()
    correct_sum = torch.zeros(1, dtype=torch.float64, device=device)
    sample_sum = 0
    progress = tqdm(dataloader, desc="Evaluating")
    for step, minibatch in enumerate(progress):
        batch_correct, batch_samples = evaluate_step(minibatch, model)
        # In-place add keeps the accumulator's float64 dtype and device.
        correct_sum += batch_correct
        sample_sum += batch_samples
        if step % 25:
            continue
        # The dict literal fixes the evaluation order: node count first,
        # then the GPU miss rate, then the CPU miss rate.
        postfix = {
            "num_nodes": minibatch.node_ids().size(0),
            "gpu_cache_miss": gpu_cache_miss_rate_fn(),
            "cpu_cache_miss": cpu_cache_miss_rate_fn(),
        }
        progress.set_postfix(postfix)
    return correct_sum / sample_sum
| 305 | |
| 306 | |
| 307 | def parse_args(): |
No test coverage detected.