MCPcopy
hub / github.com/InternLM/InternLM / AccPerplex

Class AccPerplex

internlm/model/metrics.py:12–182  ·  view source on GitHub ↗

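AccPerplex module for calculating a model's accuracy and perplexity metrics. Args: device: The GPU device. tp_pg: The tensor parallel process group. dp_pg: The data parallel process group. tokenizer: For calculating BPB. dataset_types (List[str]): Various data types that will be used in the current training process, such as ['en', 'cn', 'code']; the order of the list should be consistent with the type_id specified in the dataset.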

Source from the content-addressed store, hash-verified

10
11
12class AccPerplex:
13 """
14 AccPerplex module for calculating model's accuracy and perplexity metrics.
15
16 Args:
17 device: The GPU device.
18 tp_pg: The tensor parallel process group.
19 dp_pg: The data parallel process group.
20 tokenizer: For calculating BPB.
21 dataset_types (List[str]): Various data types that will be used in the current training process,
22 such as ['en', 'cn', 'code']. The order of the List should be consistent with the type_id specified
23 in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids().
24 """
25
def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None):
    """Set up the running counters and process-group handles.

    Args:
        device: Device on which every counter tensor is allocated.
        tp_pg: Tensor parallel process group; also used to look up this
            process's rank within the group.
        dp_pg: Data parallel process group; forwarded to ``LossWithTypeId``.
        tokenizer: Optional tokenizer, stored unchanged.
        dataset_types: Optional list of dataset type names. The per-type
            counters are only created when it is provided.
    """

    def _scalar_counter():
        # One-element float tensor holding 0, placed on the target device.
        return torch.Tensor([0]).to(device=device)

    # Handles supplied by the caller.
    self.device = device
    self.tp_pg = tp_pg
    self.dp_pg = dp_pg
    self.tokenizer = tokenizer
    self.tp_local_rank = torch.distributed.get_rank(tp_pg)

    # Global accumulators.
    self.right = _scalar_counter()
    self.total = _scalar_counter()
    self.total_log_probs = _scalar_counter()
    self.total_bytes = _scalar_counter().view(1)

    # Per-batch bookkeeping, refreshed by set_current_type_ids().
    self.batch_shift = 0
    self.type_ids = None

    # Per-dataset-type counters exist only when the type list is known.
    if dataset_types is not None:
        type_count = len(dataset_types)
        self.dataset_types = dataset_types
        self.total_type_count = type_count
        self.ds_right = torch.zeros(type_count, dtype=torch.long, device=device)
        self.ds_tokens = torch.zeros(type_count, dtype=torch.long, device=device)

    self.loss_with_type_id = LossWithTypeId(device, dp_pg, dataset_types)
45
46 def set_current_type_ids(self, type_ids: torch.Tensor):
47 self.batch_shift = 0
48 self.type_ids = type_ids.cuda()
49
50 def __call__(self, logits, labels):
51 return self.update(logits, labels, type_ids=self.type_ids)
52
53 def update(self, logits, labels, type_ids=None):
54 if gpc.config.model.use_flash_attn:
55 micro_bsz = labels.size(0)
56 else:
57 micro_bsz = 1
58 if type_ids is not None:
59 type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1)
60 self.batch_shift += 1
61 self.loss_with_type_id.update(logits, labels, type_ids)
62
63 with torch.no_grad():
64 if isinstance(logits, (list, tuple)):
65 logits = logits[0]
66
67 logits = logits.detach().clone()
68 labels = labels.detach().clone()
69

Callers 2

main — Function · 0.90
evaluate_on_val_dls — Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected