| 17 | |
| 18 | |
class SAGE(nn.Module):
    """Three-layer GraphSAGE model using mean aggregation.

    Stacks three ``dglnn.SAGEConv`` layers
    (``in_size -> hid_size -> hid_size -> out_size``). ReLU followed by
    dropout (p=0.5) is applied after every layer except the last one.

    Args:
        in_size: Size of the input node features.
        hid_size: Size of the two hidden representations.
        out_size: Size of the output representation.
    """

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        # three-layer GraphSAGE-mean
        self.layers = nn.ModuleList(
            [
                dglnn.SAGEConv(in_size, hid_size, "mean"),
                dglnn.SAGEConv(hid_size, hid_size, "mean"),
                dglnn.SAGEConv(hid_size, out_size, "mean"),
            ]
        )
        self.dropout = nn.Dropout(0.5)
        # Kept so inference() can size its per-layer output buffers.
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        """Apply the layers to a sequence of sampled blocks.

        Args:
            blocks: One block per layer, in layer order. Layers and blocks
                are paired with ``zip``, so surplus entries on either side
                are silently ignored.
            x: Input features fed to the first layer.

        Returns:
            The output of the last layer that was applied.
        """
        last = len(self.layers) - 1
        h = x
        for idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if idx != last:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size):
        """Conduct layer-wise inference to get all the node embeddings.

        Every layer is evaluated for all nodes of ``g`` (in mini-batches
        produced by a one-layer full-neighbor sampler) before the next layer
        starts. Each layer's output is collected in a CPU buffer that then
        serves as the input of the following layer.

        Args:
            g: Graph whose ``ndata["feat"]`` holds the input features.
            device: Device the layers are evaluated on, either a
                ``torch.device`` or a device string such as ``"cuda:0"``.
            batch_size: Number of output nodes per mini-batch.

        Returns:
            A CPU tensor of shape ``(g.num_nodes(), out_size)`` holding the
            output of the last layer for every node.

        Note:
            ``self.dropout`` is applied between layers here as well, so the
            result is only deterministic when the module is in eval mode.
            NOTE(review): presumably callers invoke this under
            ``model.eval()`` and ``torch.no_grad()`` -- confirm.
        """
        # A torch.device never compares equal to a plain string, so a string
        # "cpu" would wrongly request pinned memory below. Normalise first.
        if isinstance(device, str):
            device = torch.device(device)

        feat = g.ndata["feat"]
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
        dataloader = DataLoader(
            g,
            torch.arange(g.num_nodes()).to(g.device),
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,
        )
        buffer_device = torch.device("cpu")
        # Pin the host buffer only when computing on a different device.
        pin_memory = buffer_device != device
        last = len(self.layers) - 1

        for idx, layer in enumerate(self.layers):
            out_dim = self.out_size if idx == last else self.hid_size
            y = torch.empty(
                g.num_nodes(),
                out_dim,
                dtype=feat.dtype,
                device=buffer_device,
                pin_memory=pin_memory,
            )
            # The previous layer's full output becomes this layer's input.
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = feat[input_nodes]
                h = layer(blocks[0], x)  # len(blocks) = 1
                if idx != last:
                    h = F.relu(h)
                    h = self.dropout(h)
                # by design, our output nodes are contiguous
                y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)
            feat = y
        return y
| 76 |
no outgoing calls
no test coverage detected