MCPcopy
hub / github.com/InternLM/InternLM / update

Method update

internlm/model/metrics.py:53–138  ·  view source on GitHub ↗
(self, logits, labels, type_ids=None)

Source from the content-addressed store, hash-verified

51 return self.update(logits, labels, type_ids=self.type_ids)
52
53 def update(self, logits, labels, type_ids=None):
54 if gpc.config.model.use_flash_attn:
55 micro_bsz = labels.size(0)
56 else:
57 micro_bsz = 1
58 if type_ids is not None:
59 type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1)
60 self.batch_shift += 1
61 self.loss_with_type_id.update(logits, labels, type_ids)
62
63 with torch.no_grad():
64 if isinstance(logits, (list, tuple)):
65 logits = logits[0]
66
67 logits = logits.detach().clone()
68 labels = labels.detach().clone()
69
70 if self.tokenizer: # need to calculate bits per bytes
71 sequences = self.tokenizer.decode_ids(labels.tolist())
72 self.total_bytes += sum(map(lambda x: len(x.encode("utf-8")), sequences))
73
74 shift_logits = logits.view(-1, logits.size(-1))
75 shift_labels = labels.view(-1)
76 # There is a shift according to the current rank, because the logits are split
77 pred_shift = self.tp_local_rank * logits.shape[-1]
78
79 logits_max = torch.max(shift_logits, dim=-1)[0]
80 torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=self.tp_pg)
81 # Determine whether the maximum value of the current local tensor is the global maximum value
82 logits_global = logits_max == torch.max(shift_logits, dim=-1)[0]
83
84 corrects = torch.logical_and(
85 (shift_labels == (shift_logits.argmax(dim=-1) + pred_shift)), logits_global
86 ).long()
87 mask = shift_labels.ne(-100).long()
88 if hasattr(self, "total_type_count"):
89 ds_acc = scatter(corrects, type_ids, dim=0, reduce="sum")
90 token_num_type = scatter(mask, type_ids, dim=0, reduce="sum")
91 if len(ds_acc) < self.total_type_count:
92 ds_acc = torch.cat([ds_acc, ds_acc.new_zeros(self.total_type_count - len(ds_acc))])
93 token_num_type = torch.cat(
94 [token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
95 )
96 self.ds_tokens += token_num_type
97 sync_tensor = ds_acc
98 torch.distributed.all_reduce(sync_tensor, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
99 self.ds_right += sync_tensor.view(-1)
100
101 acc = corrects.sum()
102 torch.distributed.all_reduce(acc, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
103 self.right += acc # Masked_fill is not needed here because -100 is not available anyway
104 self.total += mask.sum()
105
106 # Subtract the maximum value.
107 shift_logits = shift_logits.sub(logits_max.unsqueeze(dim=-1))
108
109 # Get the partition's vocab indices
110 partition_vocab_size = shift_logits.size()[-1]

Callers 9

__call__ Method · 0.95
get_vocab Method · 0.45
generate_interactive Function · 0.45
_get_client Method · 0.45
get_metric Method · 0.45
get_metric Method · 0.45

Calls 1

log Method · 0.80

Tested by

no test coverage detected