(smoke=False, show=True)
| 54 | |
| 55 | |
def main(smoke=False, show=True):
    """Train second-order LINE embeddings on an edge-list graph and return them.

    Args:
        smoke: When truthy, read the graph at ``SMOKE_GRAPH_PATH`` and train
            with minimal settings (embedding size 8, batch size 2, one epoch,
            verbosity 0), skipping evaluation and plotting. Otherwise read
            ``WIKI_GRAPH_PATH`` and train with embedding size 128, batch
            size 1024, 50 epochs and verbosity 2.
        show: Passed through to ``plot_embeddings`` as its ``show`` keyword;
            only used when ``smoke`` is falsy.

    Returns:
        Whatever ``model.get_embeddings()`` returns; asserted to be non-empty.
    """
    # Resolve every mode-dependent setting in one place.
    if smoke:
        edge_file = SMOKE_GRAPH_PATH
        dim, batch, n_epochs, verbosity = 8, 2, 1, 0
    else:
        edge_file = WIKI_GRAPH_PATH
        dim, batch, n_epochs, verbosity = 128, 1024, 50, 2

    # Directed graph; each edge line carries an integer "weight" attribute.
    graph = nx.read_edgelist(
        str(edge_file),
        create_using=nx.DiGraph(),
        nodetype=None,
        data=[("weight", int)],
    )

    model = LINE(graph, embedding_size=dim, order="second")
    model.train(batch_size=batch, epochs=n_epochs, verbose=verbosity)

    embeddings = model.get_embeddings()
    # Sanity check that training produced at least one embedding.
    assert len(embeddings) > 0

    # Smoke runs stop here: no evaluation or plotting against the wiki labels.
    if smoke:
        return embeddings

    evaluate_embeddings(embeddings, WIKI_LABEL_PATH)
    plot_embeddings(embeddings, WIKI_LABEL_PATH, show=show)
    return embeddings
| 75 | |
| 76 | |
| 77 | if __name__ == "__main__": |
no test coverage detected