(we_file, w2i_file, use_brown=True, n_files=100)
| 237 | |
| 238 | |
def main(we_file, w2i_file, use_brown=True, n_files=100):
    """Train a GloVe model and save the result to ``we_file``.

    The co-occurrence matrix file name depends on the corpus: the Brown
    corpus uses a fixed name, while the Wikipedia variant embeds
    ``n_files`` in the name. If that matrix file and ``w2i_file`` both
    already exist, the word -> index mapping is read back from
    ``w2i_file`` and the raw corpus is not reloaded. Otherwise the corpus
    is loaded, the vocabulary is built, and the mapping is written to
    ``w2i_file``.

    Args:
        we_file: path passed to ``model.save`` after training.
        w2i_file: JSON file holding the word -> index mapping. It is read
            when the cached matrix is reused and (re)written otherwise.
        use_brown: if True, use the Brown corpus (``n_vocab=5000`` plus
            ``keep_words``); otherwise use ``n_files`` Wikipedia files
            (``n_vocab=2000``).
        n_files: number of Wikipedia files to load. It is also part of the
            cached matrix file name. Ignored when ``use_brown`` is True.
    """
    if use_brown:
        cc_matrix = "cc_matrix_brown.npy"
    else:
        cc_matrix = "cc_matrix_%s.npy" % n_files

    # Hacky way of checking whether the raw data must be re-loaded: only the
    # co-occurrence matrix is needed for training. A cached matrix is only
    # usable together with its word -> index mapping, so require the mapping
    # file as well. If it is missing, rebuild it instead of crashing in
    # json.load with FileNotFoundError.
    if os.path.exists(cc_matrix) and os.path.exists(w2i_file):
        with open(w2i_file) as f:
            word2idx = json.load(f)
        sentences = []  # dummy - we won't actually use it
    else:
        if use_brown:
            # Words forced into the limited vocabulary, presumably so that
            # analogy checks on them are possible later (assumption - confirm).
            keep_words = {
                'king', 'man', 'woman',
                'france', 'paris', 'london', 'rome', 'italy', 'britain', 'england',
                'french', 'english', 'japan', 'japanese', 'chinese', 'italian',
                'australia', 'australian', 'december', 'november', 'june',
                'january', 'february', 'march', 'april', 'may', 'july', 'august',
                'september', 'october',
            }
            sentences, word2idx = get_sentences_with_word2idx_limit_vocab(
                n_vocab=5000, keep_words=keep_words
            )
        else:
            sentences, word2idx = get_wikipedia_data(n_files=n_files, n_vocab=2000)

        with open(w2i_file, 'w') as f:
            json.dump(word2idx, f)

    V = len(word2idx)
    # NOTE(review): the positional args are presumably (embedding dim, vocab
    # size, context window) - confirm against Glove.__init__.
    model = Glove(100, V, 10)

    # alternating least squares method
    model.fit(sentences, cc_matrix=cc_matrix, epochs=20)

    # gradient descent method (alternative training configuration)
    # model.fit(
    #     sentences,
    #     cc_matrix=cc_matrix,
    #     learning_rate=5e-4,
    #     reg=0.1,
    #     epochs=500,
    #     gd=True,
    # )
    model.save(we_file)
| 284 | |
| 285 | |
| 286 | if __name__ == '__main__': |
no test coverage detected