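AccPerplex module for calculating model's accuracy and perplexity metrics. Args: device: The GPU device. tp_pg: The tensor parallel process group. dp_pg: The data parallel process group. tokenizer: For calculating BPB. dataset_types (List[str]): Various data types that will be used in the current training process, such as ['en', 'cn', 'code']. The order of the list should be consistent with the type_id specified in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids().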
| 10 | |
| 11 | |
| 12 | class AccPerplex: |
| 13 | """ |
| 14 | AccPerplex module for calculating model's accuracy and perplexity metrics. |
| 15 | |
| 16 | Args: |
| 17 | device: The GPU device. |
| 18 | tp_pg: The tensor parallel process group. |
| 19 | dp_pg: The data parallel process group. |
| 20 | tokenizer: For calculating BPB. |
| 21 | dataset_types (List[str]): Various data types that will be used in the current training process, |
| 22 | such as ['en', 'cn', 'code']. The order of the List should be consistent with the type_id specified |
| 23 | in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids(). |
| 24 | """ |
| 25 | |
def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None):
    """Create the running counters for the accuracy / perplexity metrics.

    Args:
        device: Device every accumulator tensor is moved to.
        tp_pg: Tensor parallel process group; this process's rank inside
            it is stored as ``tp_local_rank``.
        dp_pg: Data parallel process group; stored and also passed on to
            ``LossWithTypeId``.
        tokenizer: Optional tokenizer, stored unchanged.
        dataset_types: Optional list of dataset type names. When given,
            ``ds_right`` and ``ds_tokens`` get one ``long`` slot per type.
            When ``None``, ``dataset_types``, ``total_type_count``,
            ``ds_right`` and ``ds_tokens`` are not created at all.
    """

    def _zero():
        # Fresh one-element zero tensor placed on the target device.
        return torch.Tensor([0]).to(device=device)

    self.device = device
    self.right = _zero()
    self.total = _zero()
    self.total_log_probs = _zero()

    self.tp_pg, self.dp_pg = tp_pg, dp_pg
    self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
    self.tokenizer = tokenizer
    self.total_bytes = _zero().view(1)

    # Cursor into the current type_ids tensor, plus the tensor itself.
    self.batch_shift = 0
    self.type_ids = None

    if dataset_types is not None:
        num_types = len(dataset_types)
        self.dataset_types = dataset_types
        self.total_type_count = num_types
        # One integer counter per dataset type.
        self.ds_right = torch.zeros(num_types, dtype=torch.long, device=device)
        self.ds_tokens = torch.zeros(num_types, dtype=torch.long, device=device)

    # NOTE(review): built whether or not dataset_types was supplied (the
    # value, possibly None, is forwarded as-is) - confirm against upstream.
    self.loss_with_type_id = LossWithTypeId(device, dp_pg, dataset_types)
| 45 | |
def set_current_type_ids(self, type_ids: torch.Tensor):
    """Register the dataset type ids that subsequent updates will consume.

    Resets ``batch_shift`` to 0 and stores ``type_ids`` on ``self.device``.

    Args:
        type_ids: Tensor of dataset type ids for the upcoming data.
    """
    self.batch_shift = 0
    # Place the ids on the device this metric was configured with, the same
    # way __init__ places its accumulators. A bare `.cuda()` would pick the
    # *current* CUDA device instead and fails when CUDA is unavailable.
    self.type_ids = type_ids.to(device=self.device)
| 49 | |
def __call__(self, logits, labels):
    """Make the instance callable by forwarding to ``update``.

    Args:
        logits: Passed through unchanged to ``update``.
        labels: Passed through unchanged to ``update``.

    Returns:
        Whatever ``update`` returns.
    """
    # Use the ids most recently stored on the instance.
    stored_type_ids = self.type_ids
    return self.update(logits, labels, type_ids=stored_type_ids)
| 52 | |
| 53 | def update(self, logits, labels, type_ids=None): |
| 54 | if gpc.config.model.use_flash_attn: |
| 55 | micro_bsz = labels.size(0) |
| 56 | else: |
| 57 | micro_bsz = 1 |
| 58 | if type_ids is not None: |
| 59 | type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1) |
| 60 | self.batch_shift += 1 |
| 61 | self.loss_with_type_id.update(logits, labels, type_ids) |
| 62 | |
| 63 | with torch.no_grad(): |
| 64 | if isinstance(logits, (list, tuple)): |
| 65 | logits = logits[0] |
| 66 | |
| 67 | logits = logits.detach().clone() |
| 68 | labels = labels.detach().clone() |
| 69 |
no outgoing calls
no test coverage detected