MCPcopy Index your code
hub / github.com/shenweichen/GraphEmbedding / DeepWalk

Class DeepWalk

ge/models/deepwalk.py:25–64  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

23
24
class DeepWalk:
    """DeepWalk graph embedding: random walks over ``graph`` fed to Word2Vec.

    The walk corpus is generated eagerly in ``__init__`` by ``RandomWalker``
    with ``p=1, q=1`` (presumably the unbiased, DeepWalk-style walk setting —
    confirm against ``RandomWalker``). ``train`` fits a skip-gram ``Word2Vec``
    model on that corpus, and ``get_embeddings`` reads one vector per graph
    node back out of the trained model.
    """

    def __init__(self, graph, walk_length, num_walks, workers=1):
        """Store ``graph`` and pre-compute the random-walk corpus.

        Args:
            graph: graph to embed; ``get_embeddings`` iterates ``graph.nodes()``
                (presumably a networkx-style graph — TODO confirm).
            walk_length: forwarded as ``walk_length`` to ``simulate_walks``.
            num_walks: forwarded as ``num_walks`` to ``simulate_walks``.
            workers: parallelism for walk simulation only; ``train`` takes
                its own ``workers`` argument.
        """
        self.graph = graph
        self.w2v_model = None   # populated by train()
        self._embeddings = {}   # rebuilt by get_embeddings()

        walker = RandomWalker(graph, p=1, q=1)
        self.walker = walker
        # Walks are simulated once, up front; every train() call reuses them.
        self.sentences = walker.simulate_walks(
            num_walks=num_walks,
            walk_length=walk_length,
            workers=workers,
            verbose=1,
        )

    def train(self, embed_size=128, window_size=5, workers=3, iter=5, **kwargs):
        """Fit a skip-gram Word2Vec model on the pre-computed walks.

        Args:
            embed_size: embedding dimensionality (Word2Vec ``vector_size``).
            window_size: context window size (Word2Vec ``window``).
            workers: Word2Vec training worker count.
            iter: number of training epochs (Word2Vec ``epochs``). The name
                shadows the builtin but is kept for caller compatibility.
            **kwargs: extra keyword arguments forwarded to ``Word2Vec``.
                ``min_count`` defaults to 0 when not given; ``sentences``,
                ``vector_size``, ``sg``, ``hs``, ``workers``, ``window`` and
                ``epochs`` are always set by this method, overriding any
                same-named entries here.

        Returns:
            The trained Word2Vec model, which is also stored on
            ``self.w2v_model``.
        """
        w2v_params = dict(kwargs)
        w2v_params.setdefault("min_count", 0)
        w2v_params.update(
            sentences=self.sentences,
            vector_size=embed_size,
            sg=1,  # skip-gram
            hs=1,  # DeepWalk uses hierarchical softmax
            workers=workers,
            window=window_size,
            epochs=iter,
        )
        # NOTE(review): hs=1 on its own may not switch off negative sampling —
        # gensim documents a non-zero default for Word2Vec's ``negative``.
        # Callers wanting pure hierarchical softmax can pass negative=0 via
        # kwargs; confirm against the installed gensim version.

        print("Learning embedding vectors...")
        model = Word2Vec(**w2v_params)
        print("Learning embedding vectors done!")

        self.w2v_model = model
        return model

    def get_embeddings(self):
        """Return a dict mapping every node of ``self.graph`` to its vector.

        If ``train`` has not been called yet, prints "model not train" and
        returns a fresh empty dict, leaving the ``self._embeddings`` cache
        untouched. Otherwise the cache is reset and refilled node by node from
        ``self.w2v_model.wv`` and then returned. If a node is missing from the
        model's vocabulary (possible when ``min_count`` was raised), the
        lookup presumably raises ``KeyError``, leaving the cache partially
        filled — confirm.
        """
        if self.w2v_model is not None:
            embeddings = self._embeddings = {}
            embeddings.update(
                (node, self.w2v_model.wv[node]) for node in self.graph.nodes()
            )
            return embeddings

        print("model not train")
        return {}

Callers 2

test_DeepWalk (function) · 0.90
main (function) · 0.90

Calls

no outgoing calls

Tested by 1

test_DeepWalk (function) · 0.72