MCPcopy Index your code
hub / github.com/pytorch/examples / train

Function train

distributed/FSDP/utils/train_utils.py:35–68  ·  view source on GitHub ↗
(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None)

Source from the content-addressed store, hash-verified

33 return metric_num
34
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    """Run one training epoch and return the average loss over all ranks.

    Each batch is moved to the device given by the ``LOCAL_RANK`` environment
    variable, passed through ``model``, back-propagated and applied with
    ``optimizer``. The per-rank loss totals are then summed across the
    process group with ``dist.all_reduce`` and divided by the global sample
    count.

    Args:
        args: Unused by this function; kept for interface compatibility.
        model: Module called with ``input_ids``, ``attention_mask`` and
            ``labels``; its output must expose a ``"loss"`` entry.
        rank: Global rank of this process. Only rank 0 shows the progress
            bar and prints the epoch summary.
        world_size: Unused by this function; kept for interface compatibility.
        train_loader: Iterable of dict-like batches holding the tensors
            ``"source_ids"``, ``"source_mask"`` and ``"target_ids"``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, forwarded to ``sampler.set_epoch`` and printed.
        sampler: Optional sampler whose ``set_epoch`` is called before the
            loop starts.

    Returns:
        A 0-dim tensor on the local device holding the summed loss divided by
        the summed sample count. It is NaN when no samples were processed on
        any rank (0 / 0).

    Raises:
        KeyError: If ``LOCAL_RANK`` is not set in the environment.
    """
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    # Slot 0 accumulates the sample-weighted loss, slot 1 the sample count.
    # Keeping both in one tensor lets a single all_reduce sum them together.
    fsdp_loss = torch.zeros(2).to(local_rank)

    # Compare against None explicitly: a sampler that defines __len__ would
    # otherwise be skipped whenever its length happens to be zero.
    if sampler is not None:
        sampler.set_epoch(epoch)

    inner_pbar = None
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )

    try:
        for batch in train_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            optimizer.zero_grad()
            output = model(
                input_ids=batch["source_ids"],
                attention_mask=batch["source_mask"],
                labels=batch["target_ids"],
            )
            loss = output["loss"]
            loss.backward()
            optimizer.step()

            # ``len(batch)`` would be the number of dict keys, not the number
            # of samples, so the count is taken from the input tensor instead.
            # NOTE(review): assumes dim 0 of "source_ids" is the batch
            # dimension and that output["loss"] is a per-batch mean, so
            # weighting by batch size gives a per-sample average - confirm
            # against the model in use.
            batch_size = batch["source_ids"].size(0)
            fsdp_loss[0] += loss.item() * batch_size
            fsdp_loss[1] += batch_size
            if inner_pbar is not None:
                inner_pbar.update(1)
    finally:
        # Release the progress bar even if a training step raises.
        if inner_pbar is not None:
            inner_pbar.close()

    # Sum loss totals and sample counts over every rank, then average.
    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_loss = fsdp_loss[0] / fsdp_loss[1]

    if rank == 0:
        print(
            f"Train Epoch: \t{epoch}, Loss: \t{train_loss:.4f}"
        )
    return train_loss
69
70
71def validation(model, rank, world_size, val_loader):

Callers 1

fsdp_mainFunction · 0.90

Calls 3

updateMethod · 0.80
all_reduceMethod · 0.80
trainMethod · 0.45

Tested by

no test coverage detected