(smoke=False, show=True)
| 49 | |
| 50 | |
def main(smoke=False, show=True):
    """Read an edge-list graph, train Struc2Vec on it, and return the embeddings.

    Args:
        smoke: When True, read ``SMOKE_GRAPH_PATH`` and use minimal walk and
            training settings so the run is quick; evaluation and plotting
            are skipped. When False, read ``FLIGHT_GRAPH_PATH`` and use the
            full settings.
        show: Forwarded to ``plot_embeddings``; only used when ``smoke`` is
            False.

    Returns:
        Whatever ``model.get_embeddings()`` returns.

    Raises:
        AssertionError: If the trained model yields no embeddings.
    """
    graph_path = SMOKE_GRAPH_PATH if smoke else FLIGHT_GRAPH_PATH
    # Directed graph; each edge line carries one extra field parsed as an int
    # "weight" attribute. nodetype=None leaves node labels unconverted.
    graph = nx.read_edgelist(
        str(graph_path),
        create_using=nx.DiGraph(),
        nodetype=None,
        data=[("weight", int)],
    )

    # Struc2Vec is given a throwaway temp_path so runs do not share state on
    # disk; TemporaryDirectory deletes it when the block exits. The trailing
    # "/" is kept as in the original -- presumably Struc2Vec appends file
    # names by string concatenation (TODO confirm).
    with tempfile.TemporaryDirectory(prefix="struc2vec-") as temp_dir:
        model = Struc2Vec(
            graph,
            walk_length=3 if smoke else 10,
            num_walks=1 if smoke else 80,
            workers=1 if smoke else 4,
            verbose=0 if smoke else 40,
            temp_path=temp_dir + "/",
        )
        model.train(
            embed_size=8 if smoke else 128,
            window_size=2 if smoke else 5,
            workers=1,
            iter=1 if smoke else 3,
        )
        embeddings = model.get_embeddings()

    # Explicit raise rather than a bare `assert` so the check still runs under
    # `python -O`; AssertionError is kept so callers see the same exception.
    if len(embeddings) == 0:
        raise AssertionError("Struc2Vec produced no embeddings")

    if not smoke:
        evaluate_embeddings(embeddings, FLIGHT_LABEL_PATH)
        plot_embeddings(embeddings, FLIGHT_LABEL_PATH, show=show)

    return embeddings
| 84 | |
| 85 | |
| 86 | if __name__ == "__main__": |
no test coverage detected