()
| 225 | |
| 226 | |
def test_additive():
    """Compare AdditiveNGram against the AdditiveGold reference model.

    Both models are built with the same random order ``N`` (an int in 2..4)
    and the same random additive-smoothing constant ``K`` (a float in
    [0, 1)). They are trained on the same text from
    ``random_paragraph(1000)``, written to a temporary file. For every
    ``N``-gram counted by ``mine``, the test then checks two things:

    * the raw ``N``-gram counts of both models are identical, and
    * ``mine.log_prob`` matches ``gold.log_prob`` after converting the gold
      value from log base 2 to natural log.

    Raises:
        AssertionError: if any count or log probability disagrees.
    """
    K = np.random.rand()
    N = np.random.randint(2, 5)
    gold = AdditiveGold(
        N, K, unk=True, filter_stopwords=False, filter_punctuation=False
    )
    mine = AdditiveNGram(
        N, K, unk=True, filter_stopwords=False, filter_punctuation=False
    )

    # NOTE(review): reopening a NamedTemporaryFile by name while it is still
    # open is not supported on Windows; this test assumes a POSIX platform.
    with tempfile.NamedTemporaryFile() as temp:
        temp.write(bytes(" ".join(random_paragraph(1000)), encoding="utf-8-sig"))
        # Flush so the text is actually on disk before both models reopen the
        # file by name; otherwise they can read an empty or truncated file.
        temp.flush()
        gold.train(temp.name, encoding="utf-8-sig")
        mine.train(temp.name, encoding="utf-8-sig")

    err_str = "{}, mine: {}, gold: {}"
    for ngram in mine.counts[N]:
        # N-grams whose first two tokens are the same <bol>/<eol> padding
        # token are excluded from the comparison (presumably the two
        # implementations pad sentence boundaries differently -- TODO confirm).
        if ngram[0] == ngram[1] and ngram[0] in ("<bol>", "<eol>"):
            continue

        mine_count = mine.counts[N][ngram]
        gold_count = gold.counts[N][ngram]
        assert mine_count == gold_count, err_str.format(
            ngram, mine_count, gold_count
        )

        M = mine.log_prob(ngram, N)
        # gold.log_prob is in log base 2; log2(x) / log2(e) == ln(x).
        G = gold.log_prob(ngram, N) / np.log2(np.e)
        np.testing.assert_allclose(M, G)
    print("PASSED")
nothing calls this directly
no test coverage detected