(model, seq, args)
| 212 | |
| 213 | |
def inverse_prompt_score(model, seq, args):
    """Score each sequence by the summed log-likelihood of the tokens after ``[ROI1]``.

    The model is run once over ``seq``. Every output position from the
    ``[ROI1]`` marker up to the second-to-last position is used to predict
    the next token. The log-probabilities of the actual next tokens are
    summed per row.

    Args:
        model: Callable invoked as
            ``model(tokens, position_ids, attention_mask, None, None, is_sparse=...)``.
            Its first return value is the logits tensor (indexed here as
            ``(batch, position, vocab)``).
        seq: 2-D token-id tensor. Row 0 must hold the ``[ROI1]`` token at
            index ``2 + 1024 + 1``.
        args: Namespace forwarded to ``get_batch``; ``args.is_sparse`` is
            passed to ``model``.

    Returns:
        Tensor with one summed log-probability per row of ``seq``.

    Raises:
        AssertionError: if ``seq`` is not 2-D, or if row 0 does not carry
            ``[ROI1]`` at the expected position.
    """
    tokenizer = get_tokenizer()
    device = seq.device
    assert len(seq.shape) == 2

    # Index of the [ROI1] marker. Presumably 2 leading special tokens +
    # 1024 image tokens + 1 separator precede it — TODO confirm against the
    # layout produced by the caller.
    botext = 2 + 1024 + 1
    assert tokenizer['[ROI1]'] == seq[0][botext]

    tokens, attention_mask, position_ids = get_batch(seq, device, args)
    logits, *_ = model(
        tokens, position_ids, attention_mask, None, None,
        is_sparse=args.is_sparse,
    )
    # Give zero probability to vocabulary ids below img_tokenizer.num_tokens
    # (presumably the image-token range — verify), so only the remaining
    # ids compete in the softmax.
    logits[..., :tokenizer.img_tokenizer.num_tokens] = float('-inf')
    # log_softmax is numerically stable. torch.log(softmax(...)) underflows
    # to -inf for any id whose probability rounds to 0, corrupting the sum.
    log_probs = F.log_softmax(logits, dim=-1)

    # The output at position t predicts token t + 1. Align predictions over
    # [botext, L-1) with targets over [botext+1, L).
    pred = log_probs[:, botext:-1, :]
    target = tokens[:, botext + 1:].unsqueeze(-1)
    scores = torch.gather(pred, dim=2, index=target).squeeze(-1).sum(dim=-1)
    return scores
| 231 |
no test coverage detected