| 7 | |
| 8 | |
class BertLoss(nn.Module):
    """Combine a masked-LM loss with an optional sentence-order (SOP) loss.

    The LM term is a mask-weighted mean over this rank's elements, which is
    then summed across the ``ParallelMode.SEQUENCE`` process group.
    """

    def forward(self, lm_loss, sop_logits, loss_mask, sentence_order):
        """Return the combined training loss.

        Args:
            lm_loss: LM loss values; flattened and weighted element-wise by
                ``loss_mask`` (presumably unreduced per-token loss -- confirm
                with caller).
            sop_logits: Logits reshaped to ``(-1, 2)`` for the sentence-order
                task, or ``None`` to skip that term.
            loss_mask: Weights with the same number of elements as ``lm_loss``
                (presumably 0/1 marking the positions to score).
            sentence_order: SOP target labels; entries equal to ``-1`` are
                ignored. Unused when ``sop_logits`` is ``None``.

        Returns:
            A 0-dim float32 tensor: the all-reduced LM loss, plus the SOP loss
            multiplied by the sequence-parallel world size when ``sop_logits``
            is given.
        """
        loss_mask = loss_mask.float()
        # reshape (not view) so that non-contiguous inputs are accepted.
        masked_sum = torch.sum(lm_loss.float().reshape(-1) * loss_mask.reshape(-1))
        mask_total = loss_mask.sum()
        # If this rank's chunk selects no elements, mask_total is 0 and a plain
        # division would give 0/0 = NaN, which the all-reduce below would then
        # spread to every rank. Divide by 1 in that case so this rank adds 0;
        # any nonzero total is used exactly as before.
        denom = torch.where(mask_total != 0, mask_total, torch.ones_like(mask_total))
        lm_loss = masked_sum / denom

        # In-place sum (the all_reduce default op) of the per-rank means across
        # the sequence-parallel group.
        # NOTE(review): torch.distributed.all_reduce is not autograd-aware;
        # confirm gradients through this step behave as intended.
        torch.distributed.all_reduce(lm_loss, group=gpc.get_group(ParallelMode.SEQUENCE))

        if sop_logits is None:
            return lm_loss

        sop_loss = F.cross_entropy(
            sop_logits.reshape(-1, 2).float(),
            sentence_order.reshape(-1),
            ignore_index=-1,
        )
        # The LM term is a sum of one mean per rank, so the SOP term is
        # presumably scaled by the group size to keep the two comparable.
        return lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE)
no outgoing calls
no test coverage detected
searching dependent graphs…