CTC prefix beam search inner implementation. Args: logits (torch.Tensor): (1, max_len, vocab_size); logits_lengths (torch.Tensor): (1,); keywords_tokenset (set): token set for filtering scores; score_beam_size (int): beam size for scores; path_beam_size (int): beam size for paths. Returns: List[List[int]]: n-best results.
(
self,
logits: torch.Tensor,
logits_lengths: torch.Tensor,
keywords_tokenset: set = None,
score_beam_size: int = 3,
path_beam_size: int = 20,
)
| 123 | [ self.keywords_idxset.add(i) for i in indexs ] |
| 124 | |
| 125 | def beam_search( |
| 126 | self, |
| 127 | logits: torch.Tensor, |
| 128 | logits_lengths: torch.Tensor, |
| 129 | keywords_tokenset: set = None, |
| 130 | score_beam_size: int = 3, |
| 131 | path_beam_size: int = 20, |
| 132 | ) -> Tuple[List[List[int]], torch.Tensor]: |
| 133 | """ CTC prefix beam search inner implementation |
| 134 | |
| 135 | Args: |
| 136 | logits (torch.Tensor): (1, max_len, vocab_size) |
| 137 | logits_lengths (torch.Tensor): (1, ) |
| 138 | keywords_tokenset (set): token set for filtering score |
| 139 | score_beam_size (int): beam size for score |
| 140 | path_beam_size (int): beam size for path |
| 141 | |
| 142 | Returns: |
| 143 | List[List[int]]: nbest results |
| 144 | """ |
| 145 | |
| 146 | maxlen = logits.size(0) |
| 147 | ctc_probs = logits |
| 148 | cur_hyps = [(tuple(), (1.0, 0.0, []))] |
| 149 | |
| 150 | # CTC beam search step by step |
| 151 | for t in range(0, maxlen): |
| 152 | probs = ctc_probs[t] # (vocab_size,) |
| 153 | # key: prefix, value (pb, pnb), default value(-inf, -inf) |
| 154 | next_hyps = defaultdict(lambda: (0.0, 0.0, [])) |
| 155 | |
| 156 | # 2.1 First beam prune: select topk best |
| 157 | top_k_probs, top_k_index = probs.topk( |
| 158 | score_beam_size) # (score_beam_size,) |
| 159 | |
| 160 | # filter prob score that is too small |
| 161 | filter_probs = [] |
| 162 | filter_index = [] |
| 163 | for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()): |
| 164 | if keywords_tokenset is not None: |
| 165 | if prob > 0.05 and idx in keywords_tokenset: |
| 166 | filter_probs.append(prob) |
| 167 | filter_index.append(idx) |
| 168 | else: |
| 169 | if prob > 0.05: |
| 170 | filter_probs.append(prob) |
| 171 | filter_index.append(idx) |
| 172 | |
| 173 | if len(filter_index) == 0: |
| 174 | continue |
| 175 | |
| 176 | for s in filter_index: |
| 177 | ps = probs[s].item() |
| 178 | # print(f'frame:{t}, token:{s}, score:{ps}') |
| 179 | |
| 180 | for prefix, (pb, pnb, cur_nodes) in cur_hyps: |
| 181 | last = prefix[-1] if len(prefix) > 0 else None |
| 182 | if s == 0: # blank |
no outgoing calls
no test coverage detected