MCPcopy Index your code
hub / github.com/pytorch/examples / validation

Function validation

distributed/FSDP/utils/train_utils.py:71–96  ·  view source on GitHub ↗
(model, rank, world_size, val_loader)

Source from the content-addressed store, hash-verified

69
70
def validation(model, rank, world_size, val_loader):
    """Evaluate ``model`` on ``val_loader`` and return the reduced loss.

    Each rank iterates its own batches under ``torch.no_grad()`` and keeps
    two running totals in a 2-element tensor: the summed per-batch loss and
    a per-batch count. Both totals are summed across ranks with
    ``dist.all_reduce`` before the division.

    The model is switched to eval mode and is NOT switched back; callers
    that continue training must call ``model.train()`` themselves.

    Args:
        model: Callable invoked with ``input_ids``, ``attention_mask`` and
            ``labels`` keyword arguments; its output must support
            ``output["loss"].item()``.
        rank: Global rank. Only rank 0 shows the progress bar and prints.
        world_size: Not used by this function; kept so the existing call
            signature stays unchanged.
        val_loader: Sized iterable of dict-like batches that contain
            ``"source_ids"``, ``"source_mask"`` and ``"target_ids"``; every
            value must provide ``.to(device)``. Batches are moved to the
            device in place.

    Returns:
        A 0-dim tensor on the local device: reduced loss total divided by
        the reduced count.

    Raises:
        KeyError: If the ``LOCAL_RANK`` environment variable is not set.
    """
    model.eval()
    local_rank = int(os.environ['LOCAL_RANK'])

    # Slot 0: running sum of batch losses. Slot 1: running count (denominator).
    fsdp_loss = torch.zeros(2).to(local_rank)

    # Only rank 0 owns a progress bar; every other rank keeps None.
    inner_pbar = None
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )

    with torch.no_grad():
        for batch in val_loader:
            # Move every entry of the batch onto this rank's device.
            for key in batch:
                batch[key] = batch[key].to(local_rank)

            output = model(
                input_ids=batch["source_ids"],
                attention_mask=batch["source_mask"],
                labels=batch["target_ids"],
            )
            fsdp_loss[0] += output["loss"].item()  # sum up batch loss

            # NOTE(review): ``batch`` is a mapping, so ``len(batch)`` is the
            # number of keys in it, not the number of samples. The reported
            # loss is therefore scaled by that key count. Kept as-is so the
            # values stay comparable with earlier runs -- confirm whether
            # the batch size (or a plain batch count) was intended.
            fsdp_loss[1] += len(batch)

            if inner_pbar is not None:
                inner_pbar.update(1)

    # Combine the totals of all ranks before dividing.
    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]

    if inner_pbar is not None:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss
97
98
99def setup_model(model_name):

Callers 1

fsdp_mainFunction · 0.90

Calls 2

updateMethod · 0.80
all_reduceMethod · 0.80

Tested by

no test coverage detected