| 93 | full_state_update: bool = False |
| 94 | |
def update(self, result: TokenClassifierResult, labels: Tensor, **kwargs) -> None:
    """Accumulate entity-level match counts for one batch.

    The 2-D token mask is expanded into a pairwise mask, rows whose first
    column is masked out are dropped, and gold / predicted tag ids are mapped
    through ``self.labels``. Entities are then extracted with ``get_entities``
    from the ``self.concat``-ed sequences, and the counts are added to
    ``self.correct``, ``self.gold_total`` and ``self.pred_total``.

    Args:
        result: Model output; its ``crf``, ``logits`` and ``attention_mask``
            attributes are read. ``attention_mask`` must be 2-D. ``logits``
            is presumably (batch, seq, seq, num_labels) -- TODO confirm.
        labels: Gold tag ids, flattened over their first two dimensions;
            presumably (batch, seq, seq) -- TODO confirm.
        **kwargs: Accepted for interface compatibility and ignored.
    """
    crf = result.crf
    logits = result.logits
    attention_mask = result.attention_mask

    # Pairwise mask: cell (b, i, j) is set only when tokens i and j are both
    # set in the original mask. Flattening merges the (b, i) dimensions.
    seq_len = attention_mask.size(1)
    pair_mask = attention_mask.unsqueeze(-1).expand(-1, -1, seq_len)
    pair_mask = pair_mask & pair_mask.transpose(-1, -2)
    pair_mask = pair_mask.flatten(end_dim=1)

    # Keep only rows whose first column is set. The selector is cast to bool
    # so an integer-typed mask is applied as a row mask rather than being
    # interpreted as integer row indices (a no-op for bool masks).
    keep_rows = pair_mask[:, 0].bool()
    pair_mask = pair_mask[keep_rows]
    logits = logits.flatten(end_dim=1)[keep_rows]
    labels = labels.flatten(end_dim=1)[keep_rows]

    # Copy the mask to the host once; iterating the tensor element by element
    # would convert (and, on an accelerator, synchronise) per token.
    host_mask = pair_mask.cpu().numpy()

    def to_tags(tag_ids: Tensor) -> list:
        # Map every unmasked tag id of every row to its label.
        return [
            [self.labels[tag] for tag, keep in zip(row, row_mask) if keep]
            for row, row_mask in zip(tag_ids.cpu().numpy(), host_mask)
        ]

    gold = to_tags(labels)

    if crf is None:
        predicted = to_tags(logits.argmax(dim=-1))
    else:
        # The CRF decodes from log-probabilities and applies the mask itself,
        # so its output rows are mapped to labels without further filtering.
        log_probs = torch.log_softmax(logits, dim=-1)
        predicted = [
            [self.labels[tag] for tag in tags]
            for tags in crf.decode(log_probs, pair_mask)
        ]

    gold_entities = get_entities(self.concat(gold))
    pred_entities = get_entities(self.concat(predicted))

    self.correct += len(set(gold_entities) & set(pred_entities))
    self.gold_total += len(gold_entities)
    self.pred_total += len(pred_entities)