MCPcopy Index your code
hub / github.com/dmlc/dgl / SAGE

Class SAGE

examples/core/graphsage/node_classification.py:33–50  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

31
32
class SAGE(nn.Module):
    """Two-layer GraphSAGE network built from ``dglnn.SAGEConv`` with the "gcn" aggregator.

    Architecture: ``SAGEConv(in_size -> hidden_size)`` -> ReLU -> Dropout(p=0.5)
    -> ``SAGEConv(hidden_size -> out_size)``. The final layer's output is returned
    as-is, with no activation or dropout applied after it.

    Args:
        in_size: Input feature size passed to the first ``SAGEConv``.
        hidden_size: Output size of the first layer / input size of the second.
        out_size: Output feature size of the second (final) ``SAGEConv``.
    """

    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        # GraphSAGE-gcn stack: in_size -> hidden_size -> out_size.
        # Registered in the same order as indices 0 and 1 of the ModuleList.
        self.layers = nn.ModuleList(
            [
                dglnn.SAGEConv(in_size, hidden_size, "gcn"),
                dglnn.SAGEConv(hidden_size, out_size, "gcn"),
            ]
        )
        # A single dropout module, shared by every non-final layer.
        self.dropout = nn.Dropout(0.5)

    def forward(self, graph, x):
        """Run every ``SAGEConv`` in ``self.layers`` over ``graph``, starting from ``x``.

        Args:
            graph: The graph handed to each ``SAGEConv`` call.
            x: Input node features. Assumed to be (num_nodes, in_size);
                TODO confirm against the caller.

        Returns:
            The output of the last layer, without activation or dropout.
        """
        num_layers = len(self.layers)
        h = x
        for position, conv in enumerate(self.layers, start=1):
            h = conv(graph, h)
            # Only intermediate layers get the nonlinearity + dropout; the
            # final layer's raw output is what the caller receives.
            if position < num_layers:
                h = F.relu(h)
                h = self.dropout(h)
        return h
51
52
53def evaluate(g, features, labels, mask, model):

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected