Information about the beams at a specific time-step.
| 20 | self.labeling = () # beam-labeling |
| 21 | |
class BeamState:
    """Information about the beams at a specific time-step.

    ``entries`` maps a beam labeling to its beam entry. The methods below
    read ``labeling``, ``prTotal`` and ``prText`` from each entry, and
    ``norm`` writes ``prText``.
    """

    def __init__(self):
        self.entries = {}

    def _sortedBeams(self):
        """Return the beam entries ordered by ``prTotal * prText``, best first."""
        return sorted(self.entries.values(), reverse=True,
                      key=lambda x: x.prTotal * x.prText)

    def norm(self):
        """Length-normalise the LM score of every beam in place.

        Each entry's ``prText`` is replaced by its n-th root, where n is the
        length of the entry's labeling. An empty labeling uses n = 1, so its
        score is left unchanged.
        """
        for entry in self.entries.values():
            labelingLen = len(entry.labeling)
            entry.prText = entry.prText ** (1.0 / (labelingLen if labelingLen else 1.0))

    def sort(self):
        """Return beam labelings, sorted by ``prTotal * prText`` (best first)."""
        return [x.labeling for x in self._sortedBeams()]

    def wordsearch(self, classes, ignore_idx, beamWidth, dict_list):
        """Return the best-scoring candidate text that appears in ``dict_list``.

        Up to ``beamWidth`` beams are examined in score order. Each labeling
        is turned into text by removing repeated characters and blanks: a
        label is dropped if it is in ``ignore_idx`` or equals the label
        immediately before it.

        Args:
            classes: indexable that maps a label index to its character.
            ignore_idx: container of label indices to drop (e.g. the blank).
            beamWidth: maximum number of top-scoring beams to examine.
            dict_list: container of accepted words (membership-tested).

        Returns:
            The first examined candidate whose text is in ``dict_list``. If
            none is, the text of the best-scoring beam. If there are no beams
            to examine, ``''``.
        """
        # Initialised up front so an empty beam set returns '' instead of
        # raising UnboundLocalError.
        best_text = ''
        for j, candidate in enumerate(self._sortedBeams()[:beamWidth]):
            idx_list = candidate.labeling
            # Removing repeated characters and blank. Repeats are compared
            # against the raw previous label, so a blank between two equal
            # labels keeps both.
            text = ''.join(
                classes[label]
                for i, label in enumerate(idx_list)
                if label not in ignore_idx and not (i > 0 and idx_list[i - 1] == label)
            )

            if j == 0:
                best_text = text  # fallback: text of the top-scoring beam
            # NOTE(review): these debug prints fire for every examined
            # candidate; consider switching to logging.
            if text in dict_list:
                print('found text: ', text)
                return text
            print('not in dict: ', text)
        return best_text
| 58 | |
| 59 | def applyLM(parentBeam, childBeam, classes, lm): |
| 60 | "calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars" |
no outgoing calls
no test coverage detected
searching dependent graphs…