MCPcopy
hub / github.com/zai-org/GLM-130B / ModelForEvaluation

Class ModelForEvaluation

evaluation/model.py:78–199  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

76
77
78class ModelForEvaluation(torch.nn.Module):
79 def __init__(self, model):
80 super().__init__()
81
82 self.model = model
83 self.device = next(self.model.parameters()).device
84
85 @staticmethod
86 def process_data(batch, device):
87 return (
88 batch["tokens"].to(device=device).long(),
89 batch["position_ids"].to(device=device).long(),
90 batch["attention_mask"].to(device=device).bool().unsqueeze(1),
91 )
92
93 def cond_log_prob(self, batch) -> List[List[float]]:
94 """
95 @return: Conditional log probability of each option
96 """
97 tokens, position_ids, attention_mask = self.process_data(batch, self.device)
98 choices_batch, choice_target_ids_batch = batch["choices"], batch["choice_target_ids"]
99 is_single_token = batch["is_single_token"]
100
101 self.model.eval()
102 with torch.no_grad():
103 logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
104 logits_batch = torch.nn.functional.log_softmax(logits, dim=-1)
105
106 # output: [b, sq, vocab]
107 log_probs = []
108
109 if is_single_token: # Single token
110 for logits, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
111 log_probs.append(logits[choice_target_ids[0], choices].tolist())
112 else: # Multi token
113 for output, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
114 log_probs_single = []
115 for choice, choice_target_id in zip(choices, choice_target_ids):
116 tmp = output[choice_target_id, choice]
117 log_probs_single.append(tmp.sum().tolist())
118 log_probs.append(log_probs_single)
119 return log_probs
120
121 def generate_text(self, sample, strategy, return_all_beams=False) -> Union[
122 List[List[int]], List[List[List[int]]]]:
123 """
124 @return: A list of text model generated, sorted by score in descending order
125 """
126
127 seqs = sample["tokens"].to(device=self.device).long()
128 context_lengths = sample["context_length"].long()
129
130 def get_masks_and_position_ids(seq):
131 batch_size = seq.shape[0]
132 max_gen_length = sample['target_position_ids'].shape[-1]
133 tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
134 position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
135 position_ids = position_ids.to(device=self.device).long()

Callers 1

mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected