(smoke=False, show=True)
| 54 | |
| 55 | |
def main(smoke=False, show=True):
    """Train an SDNE model on an edge-list graph and return its embeddings.

    The graph is loaded as a directed graph whose edge ``weight`` column is
    parsed as ``int``.

    Args:
        smoke: When true, read ``SMOKE_GRAPH_PATH`` and train with minimal
            settings (hidden sizes ``[8, 4]``, batch size 2, one epoch,
            silent), skipping evaluation and plotting. Otherwise read
            ``WIKI_GRAPH_PATH`` and train with the full settings.
        show: Forwarded to ``plot_embeddings``; only used when ``smoke`` is
            false.

    Returns:
        Whatever ``SDNE.get_embeddings()`` returns for the trained model.

    Raises:
        AssertionError: If the model produced no embeddings.
    """
    # Pick every mode-dependent setting in one place.
    if smoke:
        edge_file = SMOKE_GRAPH_PATH
        layer_sizes = [8, 4]
        train_options = {"batch_size": 2, "epochs": 1, "verbose": 0}
    else:
        edge_file = WIKI_GRAPH_PATH
        layer_sizes = [256, 128]
        train_options = {"batch_size": 3000, "epochs": 40, "verbose": 2}

    digraph = nx.read_edgelist(
        str(edge_file),
        create_using=nx.DiGraph(),
        nodetype=None,
        data=[("weight", int)],
    )

    sdne = SDNE(digraph, hidden_size=layer_sizes)
    sdne.train(**train_options)

    node_vectors = sdne.get_embeddings()
    # Sanity check: training must yield at least one embedding.
    assert len(node_vectors) > 0

    # Smoke runs stop here; evaluation and plotting need the wiki labels.
    if smoke:
        return node_vectors

    evaluate_embeddings(node_vectors, WIKI_LABEL_PATH)
    plot_embeddings(node_vectors, WIKI_LABEL_PATH, show=show)
    return node_vectors
| 79 | |
| 80 | |
| 81 | if __name__ == "__main__": |
no test coverage detected