| 76 | |
| 77 | |
| 78 | class ModelForEvaluation(torch.nn.Module): |
| 79 | def __init__(self, model): |
| 80 | super().__init__() |
| 81 | |
| 82 | self.model = model |
| 83 | self.device = next(self.model.parameters()).device |
| 84 | |
| 85 | @staticmethod |
| 86 | def process_data(batch, device): |
| 87 | return ( |
| 88 | batch["tokens"].to(device=device).long(), |
| 89 | batch["position_ids"].to(device=device).long(), |
| 90 | batch["attention_mask"].to(device=device).bool().unsqueeze(1), |
| 91 | ) |
| 92 | |
| 93 | def cond_log_prob(self, batch) -> List[List[float]]: |
| 94 | """ |
| 95 | @return: Conditional log probability of each option |
| 96 | """ |
| 97 | tokens, position_ids, attention_mask = self.process_data(batch, self.device) |
| 98 | choices_batch, choice_target_ids_batch = batch["choices"], batch["choice_target_ids"] |
| 99 | is_single_token = batch["is_single_token"] |
| 100 | |
| 101 | self.model.eval() |
| 102 | with torch.no_grad(): |
| 103 | logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None) |
| 104 | logits_batch = torch.nn.functional.log_softmax(logits, dim=-1) |
| 105 | |
| 106 | # output: [b, sq, vocab] |
| 107 | log_probs = [] |
| 108 | |
| 109 | if is_single_token: # Single token |
| 110 | for logits, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch): |
| 111 | log_probs.append(logits[choice_target_ids[0], choices].tolist()) |
| 112 | else: # Multi token |
| 113 | for output, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch): |
| 114 | log_probs_single = [] |
| 115 | for choice, choice_target_id in zip(choices, choice_target_ids): |
| 116 | tmp = output[choice_target_id, choice] |
| 117 | log_probs_single.append(tmp.sum().tolist()) |
| 118 | log_probs.append(log_probs_single) |
| 119 | return log_probs |
| 120 | |
| 121 | def generate_text(self, sample, strategy, return_all_beams=False) -> Union[ |
| 122 | List[List[int]], List[List[List[int]]]]: |
| 123 | """ |
| 124 | @return: A list of text model generated, sorted by score in descending order |
| 125 | """ |
| 126 | |
| 127 | seqs = sample["tokens"].to(device=self.device).long() |
| 128 | context_lengths = sample["context_length"].long() |
| 129 | |
| 130 | def get_masks_and_position_ids(seq): |
| 131 | batch_size = seq.shape[0] |
| 132 | max_gen_length = sample['target_position_ids'].shape[-1] |
| 133 | tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1) |
| 134 | position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1) |
| 135 | position_ids = position_ids.to(device=self.device).long() |