| 58 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") |
| 59 | |
def bpe(self, token):
    """Apply byte-pair merges to ``token`` and return the space-joined symbols.

    The token is first split into its individual characters. The pair with
    the lowest rank in ``self.bpe_ranks`` (among the pairs reported by
    ``get_pairs``) is then merged everywhere it occurs, and this repeats
    until no remaining pair has a rank or the word is a single symbol.
    Results are memoized in ``self.cache``.

    Args:
        token: The string to break into BPE symbols.

    Returns:
        The resulting symbols joined by single spaces. If ``get_pairs``
        reports no pairs for the initial characters, ``token`` is returned
        unchanged and is not written to the cache.
    """
    if token in self.cache:
        return self.cache[token]
    word = tuple(token)
    # NOTE(review): get_pairs is defined outside this block; presumably it
    # returns the set of adjacent symbol pairs in ``word`` -- confirm.
    pairs = get_pairs(word)

    if not pairs:
        return token

    while True:
        # Unranked pairs sort last (rank = +inf), so if the minimum itself is
        # unranked there is nothing left to merge.
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
        if bigram not in self.bpe_ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            # Jump to the next occurrence of ``first``; tuple.index raises
            # ValueError when there is none, in which case the rest of the
            # word is copied through untouched.
            try:
                j = word.index(first, i)
            except ValueError:
                new_word.extend(word[i:])
                break
            new_word.extend(word[i:j])
            i = j

            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                # ``first`` is immediately followed by ``second``: merge them.
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        if len(word) == 1:
            # A single symbol has no pairs left to merge.
            break
        pairs = get_pairs(word)
    word = " ".join(word)
    self.cache[token] = word
    return word
| 100 | |
| 101 | def encode(self, text): |
| 102 | bpe_tokens = [] |