@return: Conditional log probability of each option
(self, batch)
| 91 | ) |
| 92 | |
def cond_log_prob(self, batch) -> List[List[float]]:
    """
    Score every candidate option of every sample by its log probability.

    @param batch: mapping that must provide the keys "choices",
        "choice_target_ids" and "is_single_token"; it is also passed to
        ``self.process_data`` to build the model inputs.
    @return: Conditional log probability of each option -- one inner list
        per sample in the batch, one float per option.
    """
    tokens, position_ids, attention_mask = self.process_data(batch, self.device)
    all_choices = batch["choices"]
    all_targets = batch["choice_target_ids"]
    single_token_mode = batch["is_single_token"]

    self.model.eval()
    with torch.no_grad():
        # Only the first model output is used; any further outputs are ignored.
        raw_logits, *_ = self.model(
            tokens, position_ids, attention_mask, log_attention_weights=None
        )
        # Normalise over the last axis. Per the original shape note this is
        # [batch, seq, vocab], so each per-sample slice below is [seq, vocab].
        batch_log_probs = torch.nn.functional.log_softmax(raw_logits, dim=-1)

    if single_token_mode:
        # All options of a sample are looked up at one position, taken from
        # the sample's first target id; ``options`` selects the token columns.
        return [
            sample_lp[targets[0], options].tolist()
            for sample_lp, options, targets in zip(batch_log_probs, all_choices, all_targets)
        ]

    # Multi-token options: index positions and token ids pairwise (advanced
    # indexing; presumably equal-length sequences) and sum the selected
    # log probabilities into one score per option.
    return [
        [
            sample_lp[target, option].sum().tolist()
            for option, target in zip(options, targets)
        ]
        for sample_lp, options, targets in zip(batch_log_probs, all_choices, all_targets)
    ]
| 120 | |
| 121 | def generate_text(self, sample, strategy, return_all_beams=False) -> Union[ |
| 122 | List[List[int]], List[List[List[int]]]]: |
no test coverage detected