(self, corpus_fp, vocab=None, encoding=None)
| 123 | } |
| 124 | |
def train(self, corpus_fp, vocab=None, encoding=None):
    """
    Fit one NLTK Lidstone model per n-gram order on a text corpus.

    The corpus is read line by line. Each line is tokenized, optionally
    filtered through `vocab`, and expanded into padded n-grams for every
    order from 1 to ``self.N``.

    Parameters
    ----------
    corpus_fp : str
        Path to the corpus file; passed directly to ``open``.
    vocab : object or None
        If not None, ``vocab.filter(words, H["unk"])`` is applied to the
        tokens of each line before n-grams are built. Presumably a
        project ``Vocabulary`` instance -- confirm against the caller.
        Default is None.
    encoding : str or None
        Text encoding passed to ``open``. Default is None.

    Notes
    -----
    Sets three attributes on the instance:

    - ``self.counts``: dict mapping order `n` to an ``nltk.FreqDist`` over
      all padded n-grams of that order.
    - ``self._models``: dict mapping order `n` to an ``nltk.lm.Lidstone``
      model (``gamma=self.K``) fit on the per-line n-grams.
    - ``self.n_words``: total number of tokens over all non-empty lines.
    """
    N = self.N
    H = self.hyperparameters
    filter_punc, filter_stop = H["filter_punctuation"], H["filter_stopwords"]
    orders = range(1, N + 1)

    # Padding configuration shared by every n-gram extraction below.
    pad_kwargs = {
        "pad_left": True,
        "pad_right": True,
        "left_pad_symbol": "<bol>",
        "right_pad_symbol": "<eol>",
    }

    # sentence_grams[n]: one list of n-grams per corpus line (input to `fit`)
    # flat_grams[n]: every n-gram of order n in a single list (input to FreqDist)
    sentence_grams = {n: [] for n in orders}
    flat_grams = {n: [] for n in orders}

    n_words = 0
    tokens = set()

    with open(corpus_fp, "r", encoding=encoding) as text:
        for line in text:
            words = tokenize_words(line, filter_punc, filter_stop)

            if vocab is not None:
                words = vocab.filter(words, H["unk"])

            if len(words) == 0:
                continue

            n_words += len(words)
            tokens.update(words)

            # calculate n, n-1, ... 1-grams
            for n in orders:
                # Build the padded n-grams once per line and reuse the
                # same list for both the model input and the raw counts,
                # instead of running nltk.ngrams twice.
                line_grams = list(nltk.ngrams(words, n, **pad_kwargs))
                sentence_grams[n].append(line_grams)
                flat_grams[n].extend(line_grams)

    models, counts = {}, {}
    for n in orders:
        counts[n] = nltk.FreqDist(flat_grams[n])
        models[n] = nltk.lm.Lidstone(order=n, gamma=self.K)
        # `tokens` (the set of observed words) is supplied as the
        # vocabulary text for the model.
        models[n].fit(sentence_grams[n], tokens)

    self.counts = counts
    self._models = models
    self.n_words = n_words
no test coverage detected