hub / github.com/hpcaitech/ColossalAI / BertLoss

Class BertLoss

examples/tutorial/sequence_parallel/loss_func/bert_loss.py:9–28


# Excerpt: torch, nn, F, gpc and ParallelMode are imported earlier in the file, outside this span.
class BertLoss(nn.Module):
    def forward(self, lm_loss, sop_logits, loss_mask, sentence_order):
        # Masked mean of the per-token LM loss held by this rank.
        lm_loss_ = lm_loss.float()
        loss_mask = loss_mask.float()
        loss_mask_sum = loss_mask.sum()
        lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1))

        lm_loss /= loss_mask_sum

        # Combine the per-rank LM losses across the sequence-parallel group
        # (all_reduce sums by default).
        torch.distributed.all_reduce(lm_loss, group=gpc.get_group(ParallelMode.SEQUENCE))

        if sop_logits is not None:
            # Sentence-order prediction: two-class cross entropy; targets of -1 are ignored.
            sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1)
            sop_loss = sop_loss.float()
            loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE)
        else:
            sop_loss = None
            loss = lm_loss

        return loss
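The arithmetic in `forward` can be traced without GPUs or a process group. The following is a minimal pure-Python sketch, not the ColossalAI implementation: `masked_mean` and `simulated_bert_loss` are hypothetical helpers, the all-reduce is simulated by summing a list of per-rank values, and the SOP loss is passed in as an already computed number.

```python
def masked_mean(losses, mask):
    """Sum of the losses at unmasked positions divided by the number of such positions."""
    total = sum(l * m for l, m in zip(losses, mask))
    return total / sum(mask)


def simulated_bert_loss(per_rank_losses, per_rank_masks, sop_loss=None):
    # Each (hypothetical) rank computes a masked mean over its own slice.
    local = [masked_mean(l, m) for l, m in zip(per_rank_losses, per_rank_masks)]
    # Stand-in for all_reduce with the default SUM op: every rank sees the total.
    lm_loss = sum(local)
    if sop_loss is None:
        return lm_loss
    # The SOP term is scaled by the number of ranks, as in the excerpt.
    world_size = len(per_rank_losses)
    return lm_loss + sop_loss * world_size


# Two hypothetical ranks, each holding three token positions.
losses = [[2.0, 4.0, 6.0], [1.0, 3.0, 5.0]]
masks = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]

lm_only = simulated_bert_loss(losses, masks)                 # 4.0 + 4.0
with_sop = simulated_bert_loss(losses, masks, sop_loss=0.5)  # 8.0 + 0.5 * 2
```

In this sketch, dividing `with_sop` by the number of ranks gives 4.0 + 0.5, so scaling the SOP term by the world size keeps both terms in the proportion an average over ranks would produce. Whether the training loop performs such a division is not shown in the excerpt.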

Callers 1

main · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected
