| 37 | self.max_len = 0 |
| 38 | |
def bpe(self, token):
    """Segment ``token`` by greedily applying ranked byte-pair merges.

    The token starts as a tuple of its characters. On each pass the adjacent
    symbol pair with the lowest rank in ``self.bpe_ranks`` is chosen, and every
    left-to-right, non-overlapping occurrence of it is merged into a single
    symbol. Merging stops when no adjacent pair has a rank or the token has
    collapsed to a single symbol.

    Args:
        token: String to segment.

    Returns:
        The final symbols joined by single spaces. Results are memoised in
        ``self.cache``. Tokens with fewer than two characters (or for which
        ``get_pairs`` yields no pairs) are returned unchanged and are not
        cached.
    """
    if token in self.cache:
        return self.cache[token]
    word = tuple(token)
    # Fewer than two symbols means there is nothing to merge. Returning early
    # also keeps an empty token away from get_pairs, which may index word[0]
    # (NOTE(review): confirm get_pairs' behaviour on an empty tuple).
    if len(word) < 2:
        return token
    pairs = get_pairs(word)
    if not pairs:
        return token

    while True:
        # Lowest rank = highest-priority merge. Pairs absent from the table
        # rank as +inf, so they win only when nothing is mergeable.
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
        if bigram not in self.bpe_ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
            except ValueError:
                # No further occurrence of `first`: copy the tail unchanged.
                new_word.extend(word[i:])
                break
            new_word.extend(word[i:j])
            i = j
            # word[i] is `first` here. Merge it with its right neighbour only
            # if that neighbour is `second`.
            if i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        if len(word) == 1:
            break
        pairs = get_pairs(word)
    word = ' '.join(word)
    self.cache[token] = word
    return word
| 78 | |
def encode(self, text):
    """Convert ``text`` into a list of vocabulary ids.

    ``text`` is split with ``self.tokenize``, and each resulting token is
    looked up in ``self.encoder``. Tokens absent from the vocabulary map to
    id 1 (presumably the unknown-token id; confirm against the vocabulary).
    """
    fallback_id = 1
    tokens = self.tokenize(text)
    return list(map(lambda tok: self.encoder.get(tok, fallback_id), tokens))