| 186 | return output_tokens_list |
| 187 | |
def shrink_beams(tokens, mems, nb, score):
    """Keep only the highest-scoring beam unless the batch already has ``nb`` rows.

    The original author notes that beam search here is a failed attempt
    slated for removal.

    Args:
        tokens: batch of beam token rows; row ``i`` is beam ``i``. Presumably
            a ``torch.Tensor`` (``.shape`` and ``.unsqueeze`` are used) --
            TODO confirm against callers.
        mems: sequence of memory objects, each sliceable by beam along its
            first axis (presumably per-layer caches -- verify).
        nb: beam count at which no shrinking happens.
        score: list of per-beam scores (``list.index`` is called on it).

    Returns:
        ``(tokens, mems, score)`` unchanged when ``tokens.shape[0] == nb``.
        Otherwise: the single best beam as a batch of one, the matching
        one-row slice of every memory, and a reset score list ``[0]``.
        Ties resolve to the earliest beam.
    """
    # Already at the target beam count: pass everything through untouched.
    if tokens.shape[0] == nb:
        return tokens, mems, score

    # First index holding the maximum score (list.index returns the earliest).
    best = score.index(max(score))
    # Re-add a leading batch dimension of size 1 for the surviving beam.
    kept_tokens = tokens[best].unsqueeze(0)
    # Slicing (rather than indexing) preserves each memory's leading axis.
    kept_mems = [layer_mem[best:best + 1] for layer_mem in mems]
    return kept_tokens, kept_mems, [0]
| 199 | |
| 200 | def add_interlacing_beam_marks(seq, nb=12, period=3000): |
| 201 | assert isinstance(seq, list) or len(seq.shape) == 1 |