| 200 | |
| 201 | |
def test_mle():
    """Check ``MLENGram`` against the ``MLEGold`` reference on random text.

    Both models are built with the same randomly drawn order ``N``
    (``np.random.randint(2, 5)``, i.e. 2, 3 or 4) and trained on the same
    temporary file, which holds the space-joined output of
    ``random_paragraph(1000)``. For every order-``N`` n-gram counted by
    ``MLENGram`` (except the skipped boundary n-grams, see below) the raw
    count and the log-probability must agree with the reference.

    Raises
    ------
    AssertionError
        If a count differs between the two models, or if the
        log-probabilities are not close per ``np.testing.assert_allclose``.
    """
    N = np.random.randint(2, 5)
    gold = MLEGold(N, unk=True, filter_stopwords=False, filter_punctuation=False)
    mine = MLENGram(N, unk=True, filter_stopwords=False, filter_punctuation=False)

    with tempfile.NamedTemporaryFile() as temp:
        temp.write(bytes(" ".join(random_paragraph(1000)), encoding="utf-8-sig"))
        # The models re-open the file by name, so the buffered bytes must be
        # pushed to disk first; otherwise they may train on a truncated (or
        # empty) corpus and the comparison below checks much less than intended.
        temp.flush()
        gold.train(temp.name, encoding="utf-8-sig")
        mine.train(temp.name, encoding="utf-8-sig")

    err_str = "{}, mine: {}, gold: {}"
    for k in mine.counts[N]:
        # Skip n-grams whose first two tokens are the same boundary marker.
        # NOTE(review): presumably the two implementations pad sequence
        # boundaries differently, so these counts are not comparable - confirm.
        if k[0] == k[1] and k[0] in ("<bol>", "<eol>"):
            continue

        mine_count = mine.counts[N][k]
        gold_count = gold.counts[N][k]
        assert mine_count == gold_count, err_str.format(k, mine_count, gold_count)

        M = mine.log_prob(k, N)
        G = gold.log_prob(k, N) / np.log2(np.e)  # convert to log base e
        np.testing.assert_allclose(M, G)

    # Report success once, after every n-gram has been verified.
    print("PASSED")
| 225 | |
| 226 | |
| 227 | def test_additive(): |