MCPcopy Index your code
hub / github.com/modelscope/FunASR / beam_search

Method beam_search

funasr/utils/kws_utils.py:125–229  ·  view source on GitHub ↗

CTC prefix beam search inner implementation. Args: logits (torch.Tensor): (1, max_len, vocab_size); logits_lengths (torch.Tensor): (1, ); keywords_tokenset (set): token set for filtering score; score_beam_size (int): beam size for score.

(
        self,
        logits: torch.Tensor,
        logits_lengths: torch.Tensor,
        keywords_tokenset: set = None,
        score_beam_size: int = 3,
        path_beam_size: int = 20,
    )

Source from the content-addressed store, hash-verified

123 [ self.keywords_idxset.add(i) for i in indexs ]
124
125 def beam_search(
126 self,
127 logits: torch.Tensor,
128 logits_lengths: torch.Tensor,
129 keywords_tokenset: set = None,
130 score_beam_size: int = 3,
131 path_beam_size: int = 20,
132 ) -> Tuple[List[List[int]], torch.Tensor]:
133 """ CTC prefix beam search inner implementation
134
135 Args:
136 logits (torch.Tensor): (1, max_len, vocab_size)
137 logits_lengths (torch.Tensor): (1, )
138 keywords_tokenset (set): token set for filtering score
139 score_beam_size (int): beam size for score
140 path_beam_size (int): beam size for path
141
142 Returns:
143 List[List[int]]: nbest results
144 """
145
146 maxlen = logits.size(0)
147 ctc_probs = logits
148 cur_hyps = [(tuple(), (1.0, 0.0, []))]
149
150 # CTC beam search step by step
151 for t in range(0, maxlen):
152 probs = ctc_probs[t] # (vocab_size,)
153 # key: prefix, value (pb, pnb), default value(-inf, -inf)
154 next_hyps = defaultdict(lambda: (0.0, 0.0, []))
155
156 # 2.1 First beam prune: select topk best
157 top_k_probs, top_k_index = probs.topk(
158 score_beam_size) # (score_beam_size,)
159
160 # filter prob score that is too small
161 filter_probs = []
162 filter_index = []
163 for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
164 if keywords_tokenset is not None:
165 if prob > 0.05 and idx in keywords_tokenset:
166 filter_probs.append(prob)
167 filter_index.append(idx)
168 else:
169 if prob > 0.05:
170 filter_probs.append(prob)
171 filter_index.append(idx)
172
173 if len(filter_index) == 0:
174 continue
175
176 for s in filter_index:
177 ps = probs[s].item()
178 # print(f'frame:{t}, token:{s}, score:{ps}')
179
180 for prefix, (pb, pnb, cur_nodes) in cur_hyps:
181 last = prefix[-1] if len(prefix) > 0 else None
182 if s == 0: # blank

Callers 13

_decode_inside · Method · 0.95
inference · Method · 0.80
inference · Method · 0.80
inference · Method · 0.80
inference · Method · 0.80
inference · Method · 0.80
inference · Method · 0.80
generate_chunk · Method · 0.80
inference · Method · 0.80
inference · Method · 0.80
generate_chunk · Method · 0.80
inference · Method · 0.80

Calls

no outgoing calls

Tested by

no test coverage detected