| 75 | |
| 76 | |
class SAGE(nn.Module):
    """Three stacked ``SAGEConv`` layers, each fed its own sampled matrix.

    Layer ``i`` receives ``sampled_matrices[i]`` together with the output of
    layer ``i - 1`` (the input features ``x`` for the first layer). ReLU and
    dropout are applied between layers, never after the final layer.

    Args:
        in_size: Feature size consumed by the first layer.
        hid_size: Feature size produced/consumed by the intermediate layers.
        out_size: Feature size produced by the last layer.
        dropout: Dropout probability applied after each non-final layer's
            ReLU. Defaults to 0.5, the value that was previously hard-coded.
    """

    def __init__(self, in_size, hid_size, out_size, dropout=0.5):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE; the neighbourhood aggregation itself is
        # whatever SAGEConv (defined elsewhere in this file) implements.
        self.layers.append(SAGEConv(in_size, hid_size))
        self.layers.append(SAGEConv(hid_size, hid_size))
        self.layers.append(SAGEConv(hid_size, out_size))
        self.dropout = nn.Dropout(dropout)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, sampled_matrices, x):
        """Propagate ``x`` through every layer, one sampled matrix per layer.

        Args:
            sampled_matrices: Iterable supplying at least one entry per layer;
                entry ``i`` is passed as the first argument to layer ``i``.
                Extra trailing entries are ignored. Presumably these are the
                per-hop matrices built by ``multilayer_sample`` — TODO confirm.
            x: Input features for the first layer.

        Returns:
            The raw output of the last layer (no ReLU or dropout applied).

        Raises:
            ValueError: If fewer sampled matrices than layers are supplied.
                ``zip`` would otherwise stop early and silently return an
                intermediate (activated, dropout-ed) hidden representation.
        """
        num_layers = len(self.layers)
        last_idx = num_layers - 1
        hidden_x = x
        num_applied = 0
        for layer_idx, (layer, sampled_matrix) in enumerate(
            zip(self.layers, sampled_matrices)
        ):
            hidden_x = layer(sampled_matrix, hidden_x)
            # Non-linearity and dropout only between layers, not on the output.
            if layer_idx != last_idx:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
            num_applied += 1
        if num_applied != num_layers:
            raise ValueError(
                f"expected at least {num_layers} sampled matrices "
                f"(one per layer), got {num_applied}"
            )
        return hidden_x
| 99 | |
| 100 | |
| 101 | def multilayer_sample(A, fanouts, seeds, ndata): |