(smoke=False, show=True)
| 48 | |
| 49 | |
def main(smoke=False, show=True):
    """Train DeepWalk on an edge-list graph and return the embeddings.

    Args:
        smoke: When truthy, read the graph from ``SMOKE_GRAPH_PATH`` and use
            the reduced walk/training settings, then return without
            evaluating or plotting. Otherwise read ``WIKI_GRAPH_PATH`` and
            use the full settings.
        show: Forwarded as ``show=`` to ``plot_embeddings``; unused when
            ``smoke`` is truthy because plotting is skipped.

    Returns:
        Whatever ``model.get_embeddings()`` returns (checked to be
        non-empty via ``len``).
    """
    # Pick the input file and every size-related knob in one place.
    if smoke:
        edge_file = SMOKE_GRAPH_PATH
        walk_length, num_walks = 3, 2
        window_size, epochs = 2, 1
    else:
        edge_file = WIKI_GRAPH_PATH
        walk_length, num_walks = 10, 80
        window_size, epochs = 5, 3

    # Directed graph; edge data is parsed as an integer "weight" attribute.
    digraph = nx.read_edgelist(
        str(edge_file),
        create_using=nx.DiGraph(),
        nodetype=None,
        data=[("weight", int)],
    )

    walker = DeepWalk(
        digraph,
        walk_length=walk_length,
        num_walks=num_walks,
        workers=1,
    )
    walker.train(window_size=window_size, iter=epochs, workers=1)

    vectors = walker.get_embeddings()
    assert len(vectors) > 0

    # Smoke runs stop here: no label-based evaluation, no plotting.
    if smoke:
        return vectors

    evaluate_embeddings(vectors, WIKI_LABEL_PATH)
    plot_embeddings(vectors, WIKI_LABEL_PATH, show=show)
    return vectors
| 74 | |
| 75 | |
| 76 | if __name__ == "__main__": |
no test coverage detected