| 45 | self.add_state("pred_total", default=tensor(0), dist_reduce_fx="sum") |
| 46 | |
def update(self, result: TokenClassifierResult, labels: Tensor, **kwargs) -> None:
    """Accumulate span counts for one batch of token-classification output.

    Gold tag ids and predicted tag ids are mapped through ``self.labels``.
    Positions whose ``attention_mask`` entry is falsy are dropped from the
    gold labels and from the argmax predictions. The sequences are then
    joined with ``self.concat`` and turned into spans by ``get_entities``.
    Exact span matches, gold spans and predicted spans are added to the
    ``correct``, ``gold_total`` and ``pred_total`` states.

    Args:
        result: model output; only ``crf``, ``logits`` and
            ``attention_mask`` are read. When ``crf`` is None, predictions
            are the per-token argmax of ``logits``; otherwise they come from
            ``crf.decode``.
        labels: gold tag ids, indexable into ``self.labels``. Presumably
            (batch, seq_len) and aligned with ``attention_mask`` — TODO
            confirm against the caller.
        **kwargs: accepted for interface compatibility and ignored.
    """
    crf = result.crf
    logits = result.logits
    attention_mask = result.attention_mask

    # Bring the mask to host memory once. Evaluating ``bool()`` on individual
    # elements of a device tensor forces a device->host sync per token, and
    # sharing one host copy keeps gold and argmax filtering identical.
    host_mask = attention_mask.cpu().numpy()

    def to_label_seqs(tag_rows):
        # Map tag ids to label strings, keeping only unmasked positions.
        return [
            [self.labels[tag] for tag, keep in zip(tags, masks) if keep]
            for tags, masks in zip(tag_rows, host_mask)
        ]

    gold = to_label_seqs(labels.cpu().numpy())

    if crf is None:
        pred = to_label_seqs(logits.argmax(dim=-1).cpu().numpy())
    else:
        # The CRF receives log-probabilities and the original tensor mask;
        # its output is used as-is, so masking is left to crf.decode
        # (presumably it returns only unmasked positions — TODO confirm).
        emissions = torch.log_softmax(logits, dim=-1)
        pred = [
            [self.labels[tag] for tag in tags]
            for tags in crf.decode(emissions, attention_mask)
        ]

    gold_entities = get_entities(self.concat(gold))
    pred_entities = get_entities(self.concat(pred))

    # A predicted span counts as correct only if it exactly equals a gold span.
    self.correct += len(set(gold_entities) & set(pred_entities))
    self.gold_total += len(gold_entities)
    self.pred_total += len(pred_entities)
| 77 | |
def compute(self) -> Any:
    """Return the span-level F1 score from the accumulated counts.

    ``2 * correct / (pred_total + gold_total)`` equals the harmonic mean of
    precision (``correct / pred_total``) and recall (``correct / gold_total``).
    A ``1e-6`` term in the denominator keeps the division finite when no
    spans have been seen, yielding 0 in that case.
    """
    span_sum = self.pred_total + self.gold_total + 1e-6
    doubled_hits = 2 * self.correct
    return doubled_hits / span_sum