| 819 | return word2idx, idx2word, tokens, doc_count |
| 820 | |
def _keep_top_n_tokens(self):
    """
    Restrict the vocabulary to the `max_tokens` most frequent tokens.

    Tokens are ranked by ``count`` (ties keep their current relative
    order). Every token that falls outside the top-N is recoded as
    ``"<unk>"``: its count is added to the ``"<unk>"`` token and its
    per-document counts in ``self.term_freq`` are merged into the
    ``"<unk>"`` entry. If ``"<unk>"`` is not itself among the top-N, it
    takes over the last (N-th) slot and the token it evicts is treated like
    any other dropped token.

    Reads ``self.hyperparameters["max_tokens"]``, ``self._tokens``,
    ``self.token2idx``, ``self.idx2token`` and ``self.term_freq``, and
    rebinds the last four to their re-indexed versions. Token objects are
    mutated in place (only the ``"<unk>"`` count changes).

    Raises
    ------
    ValueError
        If ``max_tokens`` is smaller than 1, since there would be no slot
        left to hold ``"<unk>"``.
    KeyError
        If ``"<unk>"`` is neither in the top-N nor in ``self.token2idx``.
    """
    N = self.hyperparameters["max_tokens"]
    if N < 1:
        raise ValueError("max_tokens must be at least 1, got {}".format(N))

    word2idx, idx2word = {}, {}
    tokens = sorted(self._tokens, key=lambda x: x.count, reverse=True)

    # reindex the top-N tokens...
    unk_ix = None
    for idx, tt in enumerate(tokens[:N]):
        word2idx[tt.word] = idx
        idx2word[idx] = tt.word

        if tt.word == "<unk>":
            unk_ix = idx

    dropped = tokens[N:]

    # ... if <unk> isn't in the top-N, add it, replacing the Nth
    # most-frequent word and adjust the <unk> count accordingly ...
    if unk_ix is None:
        # look <unk> up through the *old* index before anything is changed
        unk_token = self._tokens[self.token2idx["<unk>"]]

        # from here on `unk_ix` is an index into the *new* vocabulary
        unk_ix = N - 1
        evicted = tokens[unk_ix]

        # the evicted word must no longer resolve to slot N - 1, otherwise
        # `word2idx` and `idx2word` would disagree about that slot
        word2idx.pop(evicted.word, None)
        word2idx["<unk>"] = unk_ix
        idx2word[unk_ix] = "<unk>"

        tokens[unk_ix] = unk_token
        unk_token.count += evicted.count

        # <unk> itself sits somewhere in the tail; it is not a dropped
        # token, so do not fold its own count into itself below
        dropped = [tt for tt in dropped if tt is not unk_token]

    # ... and recode all dropped tokens as "<unk>"
    tokens[unk_ix].count += sum(tt.count for tt in dropped)

    # ... finally, reindex the word counts for each document
    doc_counts = {}
    for d_ix, old_counts in self.term_freq.items():
        new_counts = {}
        for old_ix, d_count in old_counts.items():
            word = self.idx2token[old_ix]
            # words without a slot in the new vocabulary collapse to <unk>
            new_ix = word2idx.get(word, unk_ix)
            new_counts[new_ix] = new_counts.get(new_ix, 0) + d_count
        doc_counts[d_ix] = new_counts

    self._tokens = tokens[:N]
    self.token2idx = word2idx
    self.idx2token = idx2word
    self.term_freq = doc_counts

    # internal sanity check on the post-condition
    assert len(self._tokens) <= N
| 864 | |
| 865 | def _drop_low_freq_tokens(self): |
| 866 | """ |