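Internal: run keyword-restricted beam search and return the first keyword hit. Args: logits: Passed unchanged to `self.beam_search`. logits_lengths: Lengths of logits, also passed to `self.beam_search`. Returns: `(True, keyword, score)` for the first hypothesis containing a keyword's token ids, where score is the square root of the product of the matched nodes' `'prob'` values; otherwise `(False, None, None)`.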
(
self,
logits: torch.Tensor,
logits_lengths: torch.Tensor,
)
| 254 | |
| 255 | |
def _decode_inside(
    self,
    logits: torch.Tensor,
    logits_lengths: torch.Tensor,
):
    """Run keyword-restricted beam search and return the first keyword hit.

    Hypotheses are scanned in the order ``self.beam_search`` returns them.
    Within each hypothesis, keywords are tried in ``self.keywords_token``
    iteration order; the first keyword whose ``'token_id'`` sequence is
    located by ``self.is_sublist`` (any offset other than -1) is reported
    and the search stops immediately.

    Args:
        logits: Passed unchanged to ``self.beam_search``.
        logits_lengths: Lengths of ``logits``; passed unchanged to
            ``self.beam_search``.

    Returns:
        ``(True, keyword, score)`` on a hit, where ``score`` is the square
        root of the product of the ``'prob'`` entries of the hypothesis
        nodes covering the matched tokens; ``(False, None, None)`` when no
        hypothesis contains any keyword.
    """
    hyps = self.beam_search(logits, logits_lengths, self.keywords_idxset)

    for one_hyp in hyps:
        # Index 0 holds the token ids and index 2 the per-token nodes;
        # index 1 (a path score) is not needed here.
        prefix_ids = one_hyp[0]
        prefix_nodes = one_hyp[2]
        # Internal invariant: every token id has a matching node.
        assert len(prefix_ids) == len(prefix_nodes)

        for word, info in self.keywords_token.items():
            lab = info['token_id']
            offset = self.is_sublist(prefix_ids, lab)
            if offset == -1:
                continue
            hit_score = 1.0
            for idx in range(offset, offset + len(lab)):
                hit_score *= prefix_nodes[idx]['prob']
            # NOTE(review): sqrt equals the geometric mean only for 2-token
            # keywords; confirm this scaling is intended for other lengths.
            return True, word, math.sqrt(hit_score)

    return False, None, None
| 293 | |
| 294 | |
| 295 | def decode(self, x: torch.Tensor): |
no test coverage detected