| 31 | |
| 32 | |
class SAGE(nn.Module):
    """Two-layer GraphSAGE network built from DGL ``SAGEConv`` layers.

    The network maps ``in_size`` -> ``hidden_size`` -> ``out_size``. It
    applies ReLU and dropout between the two convolutions but not after the
    last one.

    Args:
        in_size: Dimensionality of the input node features.
        hidden_size: Output width of the first ``SAGEConv`` layer.
        out_size: Output width of the final ``SAGEConv`` layer. This is
            presumably the number of classes; confirm against the caller.
        aggregator_type: Neighbourhood aggregator forwarded to each
            ``SAGEConv``. The default ``"gcn"`` preserves the previously
            hard-coded behaviour.
        dropout_rate: Drop probability of the dropout applied between
            layers. The default ``0.5`` preserves the previously hard-coded
            behaviour.
    """

    def __init__(
        self,
        in_size,
        hidden_size,
        out_size,
        *,
        aggregator_type="gcn",
        dropout_rate=0.5,
    ):
        super().__init__()
        self.layers = nn.ModuleList()
        # Two-layer GraphSAGE (default aggregator: "gcn").
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, aggregator_type))
        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, aggregator_type))
        # nn.Dropout is only active in training mode (model.train()).
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, graph, x):
        """Run the stacked SAGEConv layers over ``graph``.

        Args:
            graph: Graph passed to every ``SAGEConv`` layer. This is assumed
                to be a DGL graph; confirm against the caller.
            x: Input node features. The shape is assumed to be
                ``(num_nodes, in_size)``; TODO confirm.

        Returns:
            The raw output of the last layer, with no activation or dropout
            applied.
        """
        hidden_x = x
        last_idx = len(self.layers) - 1
        for layer_idx, layer in enumerate(self.layers):
            hidden_x = layer(graph, hidden_x)
            # Non-linearity and regularisation only between layers; the final
            # layer's output is returned untouched.
            if layer_idx != last_idx:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
        return hidden_x
| 51 | |
| 52 | |
| 53 | def evaluate(g, features, labels, mask, model): |
no outgoing calls
no test coverage detected