(model, rank, world_size, val_loader)
| 69 | |
| 70 | |
def validation(model, rank, world_size, val_loader):
    """Evaluate ``model`` on ``val_loader`` and return the mean validation loss.

    Each process evaluates the batches its ``val_loader`` yields. The per-process
    loss totals and sample counts are then summed across all ranks with
    ``dist.all_reduce``, so every rank ends up with the same global value.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=..., labels=...)``.
            Its output must expose a ``"loss"`` entry. Note the model is left
            in ``eval()`` mode on return.
        rank: Rank of this process. Only rank 0 shows a progress bar and
            prints the result.
        world_size: Not used by this function; kept for interface compatibility.
        val_loader: Iterable of dict batches containing at least the keys
            ``"source_ids"``, ``"source_mask"`` and ``"target_ids"``. Every
            value is moved to device ``LOCAL_RANK``.

    Returns:
        A 0-d tensor on device ``LOCAL_RANK``: the batch-size-weighted mean
        of the per-batch losses over all ranks.
    """
    model.eval()
    local_rank = int(os.environ['LOCAL_RANK'])
    # [0] accumulates (batch loss * batch size), [1] the number of samples.
    fsdp_loss = torch.zeros(2).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(
                input_ids=batch["source_ids"],
                attention_mask=batch["source_mask"],
                labels=batch["target_ids"],
            )
            # ``len(batch)`` would count the dict's keys, not samples; use the
            # leading dimension of the input ids instead.
            batch_size = batch["source_ids"].size(0)
            # Weight by batch size so a short final batch is not over-weighted
            # (assumes output["loss"] is a per-batch mean -- TODO confirm for
            # the model in use). Accumulating on-device in float32 avoids a
            # host sync per step that ``.item()`` would force.
            fsdp_loss[0] += output["loss"].detach().float() * batch_size
            fsdp_loss[1] += batch_size

            if rank == 0:
                inner_pbar.update(1)

    # Sum loss totals and sample counts over all ranks; result lands on every rank.
    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss
| 97 | |
| 98 | |
| 99 | def setup_model(model_name): |
no test coverage detected