| 313 | |
@no_grad
def _srl_post(
    self,
    result: TokenClassifierResult,
    hidden: Dict[str, torch.Tensor],
    store: Dict[str, Any],
    inputs: List[str] = None,
    tokenized: BatchEncoding = None,
) -> LTPOutput:
    """Decode semantic-role-labelling logits into per-sentence tag sequences.

    Builds a pairwise token mask, keeps the rows selected by it, decodes each
    kept row either greedily (``argmax``) or with ``result.crf`` when one is
    present, maps tag ids through ``self.srl_vocab``, and finally regroups the
    flat list of decoded rows by sentence.

    Args:
        result: Classifier output; ``crf``, ``logits`` and ``attention_mask``
            are read from it.
        hidden: Not used by this method.
        store: Not used by this method.
        inputs: Not used by this method.
        tokenized: Not used by this method.

    Returns:
        A list with one entry per sentence in the batch. Entry ``b`` holds
        the next ``length[b]`` decoded rows, where ``length[b]`` is the sum of
        sentence ``b``'s attention mask. Each row is a list of
        ``self.srl_vocab`` entries. (Annotated as ``LTPOutput``, but a plain
        list is what is returned here.)
    """
    crf = result.crf
    logits = result.logits
    attention_mask = result.attention_mask

    # Number of valid tokens per sentence; used at the end to regroup rows.
    length = torch.sum(attention_mask, dim=-1)

    # Pairwise mask: entry [b, i, j] is set only when token i and token j of
    # sentence b are both valid. Then collapse (batch, i) into a single axis.
    attention_mask = attention_mask.unsqueeze(-1).expand(-1, -1, attention_mask.size(1))
    attention_mask = attention_mask & torch.transpose(attention_mask, -1, -2)
    attention_mask = attention_mask.flatten(end_dim=1)

    # Row (b, i) is kept when token i and token 0 of sentence b are valid.
    # NOTE(review): assumes attention_mask is boolean so `index` acts as a row
    # mask (an integer mask would be treated as gather indices) - confirm.
    # NOTE(review): the per-sentence row count only equals length[b] when
    # token 0 is valid, i.e. right-padded input - confirm.
    index = attention_mask[:, 0]
    attention_mask = attention_mask[index]
    logits = logits.flatten(end_dim=1)[index]

    if crf is None:
        # Greedy decoding; drop masked-out positions explicitly.
        decoded = logits.argmax(dim=-1)
        decoded = decoded.cpu().numpy()
        attention_mask = attention_mask.cpu().numpy()
        decoded = [
            [self.srl_vocab[tag] for tag, mask in zip(tags, masks) if mask]
            for tags, masks in zip(decoded, attention_mask)
        ]
    else:
        # Presumably crf.decode already drops masked positions - verify.
        logits = torch.log_softmax(logits, dim=-1)
        decoded = crf.decode(logits, attention_mask)
        decoded = [[self.srl_vocab[tag] for tag in tags] for tags in decoded]

    # Regroup the flat list of decoded rows by sentence. A running offset
    # avoids re-copying the remaining tail of `decoded` on every iteration.
    res = []
    start = 0
    for count in length.cpu().tolist():
        end = start + count
        res.append(decoded[start:end])
        start = end

    return res
| 359 | |
| 360 | @no_grad |
| 361 | def _dep_post( |