(self, logits, labels, type_ids=None)
| 51 | return self.update(logits, labels, type_ids=self.type_ids) |
| 52 | |
| 53 | def update(self, logits, labels, type_ids=None): |
| 54 | if gpc.config.model.use_flash_attn: |
| 55 | micro_bsz = labels.size(0) |
| 56 | else: |
| 57 | micro_bsz = 1 |
| 58 | if type_ids is not None: |
| 59 | type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1) |
| 60 | self.batch_shift += 1 |
| 61 | self.loss_with_type_id.update(logits, labels, type_ids) |
| 62 | |
| 63 | with torch.no_grad(): |
| 64 | if isinstance(logits, (list, tuple)): |
| 65 | logits = logits[0] |
| 66 | |
| 67 | logits = logits.detach().clone() |
| 68 | labels = labels.detach().clone() |
| 69 | |
| 70 | if self.tokenizer: # need to calculate bits per bytes |
| 71 | sequences = self.tokenizer.decode_ids(labels.tolist()) |
| 72 | self.total_bytes += sum(map(lambda x: len(x.encode("utf-8")), sequences)) |
| 73 | |
| 74 | shift_logits = logits.view(-1, logits.size(-1)) |
| 75 | shift_labels = labels.view(-1) |
| 76 | # There is a shift according to the current rank, because the logits are split |
| 77 | pred_shift = self.tp_local_rank * logits.shape[-1] |
| 78 | |
| 79 | logits_max = torch.max(shift_logits, dim=-1)[0] |
| 80 | torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=self.tp_pg) |
| 81 | # Determine whether the maximum value of the current local tensor is the global maximum value |
| 82 | logits_global = logits_max == torch.max(shift_logits, dim=-1)[0] |
| 83 | |
| 84 | corrects = torch.logical_and( |
| 85 | (shift_labels == (shift_logits.argmax(dim=-1) + pred_shift)), logits_global |
| 86 | ).long() |
| 87 | mask = shift_labels.ne(-100).long() |
| 88 | if hasattr(self, "total_type_count"): |
| 89 | ds_acc = scatter(corrects, type_ids, dim=0, reduce="sum") |
| 90 | token_num_type = scatter(mask, type_ids, dim=0, reduce="sum") |
| 91 | if len(ds_acc) < self.total_type_count: |
| 92 | ds_acc = torch.cat([ds_acc, ds_acc.new_zeros(self.total_type_count - len(ds_acc))]) |
| 93 | token_num_type = torch.cat( |
| 94 | [token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))] |
| 95 | ) |
| 96 | self.ds_tokens += token_num_type |
| 97 | sync_tensor = ds_acc |
| 98 | torch.distributed.all_reduce(sync_tensor, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg) |
| 99 | self.ds_right += sync_tensor.view(-1) |
| 100 | |
| 101 | acc = corrects.sum() |
| 102 | torch.distributed.all_reduce(acc, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg) |
| 103 | self.right += acc # Masked_fill is not needed here because -100 is not available anyway |
| 104 | self.total += mask.sum() |
| 105 | |
| 106 | # Subtract the maximum value. |
| 107 | shift_logits = shift_logits.sub(logits_max.unsqueeze(dim=-1)) |
| 108 | |
| 109 | # Get the partition's vocab indices |
| 110 | partition_vocab_size = shift_logits.size()[-1] |
no test coverage detected