(smoke=False, show=True)
| 48 | |
| 49 | |
def main(smoke=False, show=True):
    """Train a Node2Vec model on an edge-list graph and return its embeddings.

    With ``smoke`` set, the small graph at ``SMOKE_GRAPH_PATH`` is used with
    tiny walk/training settings and the evaluation/plotting step is skipped.
    Otherwise the graph at ``WIKI_GRAPH_PATH`` is used with the full settings,
    and the result is passed to ``evaluate_embeddings`` and ``plot_embeddings``
    together with ``WIKI_LABEL_PATH``.

    Args:
        smoke: Use the reduced graph and settings for a quick run.
        show: Forwarded to ``plot_embeddings``; ignored when ``smoke`` is set.

    Returns:
        Whatever ``model.get_embeddings()`` yields (presumably a mapping from
        node id to vector -- confirm against the Node2Vec implementation).

    Raises:
        AssertionError: if the embedding collection is empty (only while
            assertions are enabled).
    """
    # Choose every size-dependent knob in one place: (graph, walk length,
    # walks per node, skip-gram window, training iterations).
    if smoke:
        edge_file = SMOKE_GRAPH_PATH
        walk_len, walks_per_node = 3, 2
        window, epochs = 2, 1
    else:
        edge_file = WIKI_GRAPH_PATH
        walk_len, walks_per_node = 10, 80
        window, epochs = 5, 3

    # Directed graph; each edge line carries an integer "weight" attribute.
    # nodetype=None leaves node ids as they appear in the file (strings).
    digraph = nx.read_edgelist(
        str(edge_file),
        create_using=nx.DiGraph(),
        nodetype=None,
        data=[("weight", int)],
    )

    walker = Node2Vec(
        digraph,
        walk_length=walk_len,
        num_walks=walks_per_node,
        p=0.25,
        q=4,
        workers=1,
        use_rejection_sampling=False,
    )
    walker.train(window_size=window, iter=epochs, workers=1)

    vectors = walker.get_embeddings()
    # Sanity check that training produced at least one embedding.
    assert len(vectors) > 0

    if smoke:
        # The quick run stops here: no label-based evaluation or plotting.
        return vectors

    evaluate_embeddings(vectors, WIKI_LABEL_PATH)
    plot_embeddings(vectors, WIKI_LABEL_PATH, show=show)
    return vectors
| 77 | |
| 78 | |
| 79 | if __name__ == "__main__": |
no test coverage detected