MCPcopy
hub / github.com/HIT-SCIR/ltp / _srl_post

Method _srl_post

python/interface/ltp/nerual.py:315–358  ·  view source on GitHub ↗
(
        self,
        result: TokenClassifierResult,
        hidden: Dict[str, torch.Tensor],
        store: Dict[str, Any],
        inputs: List[str] = None,
        tokenized: BatchEncoding = None,
    )

Source from the content-addressed store, hash-verified

313
@no_grad
def _srl_post(
    self,
    result: TokenClassifierResult,
    hidden: Dict[str, torch.Tensor],
    store: Dict[str, Any],
    inputs: List[str] = None,
    tokenized: BatchEncoding = None,
) -> LTPOutput:
    """Decode SRL classifier output into per-sentence rows of tags.

    The 2-D ``(batch, seq_len)`` attention mask is expanded into a pairwise
    ``(row token, column token)`` mask. One tag row is decoded for every row
    token that survives the row filter below (presumably one row per
    candidate predicate — confirm against the model), and the decoded rows
    are then regrouped by sentence.

    Args:
        result: Classifier output. Only ``crf``, ``logits`` and
            ``attention_mask`` are read. When ``crf`` is ``None`` tags are
            chosen greedily with ``argmax``; otherwise ``crf.decode`` is run
            on log-softmax scores.
        hidden: Not used by this method.
        store: Not used by this method.
        inputs: Not used by this method.
        tokenized: Not used by this method.

    Returns:
        A list with one element per sentence. Element ``b`` holds
        ``length[b]`` rows (``length`` = number of valid tokens in sentence
        ``b``), and each row is a list of tags looked up in
        ``self.srl_vocab``.

    NOTE(review): the ``LTPOutput`` return annotation does not match the
    plain nested list returned here — confirm the intended type.
    """
    crf = result.crf
    logits = result.logits
    attention_mask = result.attention_mask

    # Number of valid tokens per sentence; also the number of rows each
    # sentence is expected to contribute after the row filter below.
    length = torch.sum(attention_mask, dim=-1)

    # (batch, seq_len) -> (batch, seq_len, seq_len): entry [b, i, j] is True
    # only when tokens i and j of sentence b are both valid.
    attention_mask = attention_mask.unsqueeze(-1).expand(-1, -1, attention_mask.size(1))
    attention_mask = attention_mask & torch.transpose(attention_mask, -1, -2)
    # Merge the batch and row axes: (batch * seq_len, seq_len).
    attention_mask = attention_mask.flatten(end_dim=1)

    # Column 0 of row (b, i) is True iff token i and token 0 of sentence b are
    # both valid, so this keeps the rows of valid row tokens.
    # NOTE(review): assumes attention_mask is a bool tensor — an integer 0/1
    # mask would turn `[index]` into integer (gather) indexing. TODO confirm.
    index = attention_mask[:, 0]
    attention_mask = attention_mask[index]
    # NOTE(review): assumes the first two dims of logits are (batch, seq_len)
    # so that they line up with the flattened mask rows — confirm.
    logits = logits.flatten(end_dim=1)[index]

    if crf is None:
        # Greedy decoding; masked columns are dropped explicitly.
        tag_ids = logits.argmax(dim=-1).cpu().numpy()
        column_mask = attention_mask.cpu().numpy()
        decoded = [
            [self.srl_vocab[tag] for tag, keep in zip(tags, keep_row) if keep]
            for tags, keep_row in zip(tag_ids, column_mask)
        ]
    else:
        # The CRF is decoded on log-probabilities. Its output sequences are
        # mapped as-is, so masked columns are presumably dropped by
        # crf.decode itself — confirm.
        logits = torch.log_softmax(logits, dim=-1)
        decoded = crf.decode(logits, attention_mask)
        decoded = [[self.srl_vocab[tag] for tag in tags] for tags in decoded]

    # Regroup rows by sentence. Rows were flattened sentence by sentence, so
    # sentence b owns the next length[b] rows. This assumes token 0 is valid
    # in every non-empty sentence (right padding); otherwise the row filter
    # above keeps fewer rows than length[b] and the groups misalign.
    # An advancing offset is used instead of re-slicing the remaining tail on
    # every iteration, which re-copied the list (quadratic in batch size).
    res = []
    start = 0
    for count in length.cpu().tolist():
        end = start + count
        res.append(decoded[start:end])
        start = end

    return res
359
360 @no_grad
361 def _dep_post(

Callers

nothing calls this directly

Calls 2

cpu · Method · 0.80
decode · Method · 0.80

Tested by

no test coverage detected