MCPcopy Index your code
hub / github.com/dmlc/dgl / SAGE

Class SAGE

examples/pytorch/graphsage/node_classification.py:19–75  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

17
18
class SAGE(nn.Module):
    """Three-layer GraphSAGE model using mean aggregation.

    Layers map ``in_size -> hid_size -> hid_size -> out_size``. ReLU followed
    by dropout (p=0.5) is applied after every layer except the last.

    Args:
        in_size: Dimensionality of the input node features.
        hid_size: Dimensionality of the two hidden layers.
        out_size: Dimensionality of the final layer's output.
    """

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        # Kept so ``inference`` can size its per-layer output buffers.
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        """Run the mini-batch forward pass.

        Args:
            blocks: Sequence of sampled graphs, presumably one per layer
                (confirm against the caller's sampler). ``zip`` stops at the
                shorter of ``self.layers`` and ``blocks``.
            x: Input features for the first block (presumably its source
                nodes).

        Returns:
            The output of the last layer that was applied.
        """
        h = x
        last = len(self.layers) - 1
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != last:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    @torch.no_grad()
    def inference(self, g, device, batch_size):
        """Conduct layer-wise inference to get all the node embeddings.

        Each layer is evaluated for every node of ``g`` before moving on to the
        next layer, so every mini-batch only needs a one-hop, full-neighbour
        block.

        Runs under ``torch.no_grad()``. Without it, copying batch outputs into
        the preallocated buffer would record an autograd graph spanning every
        batch. Dropout is still invoked and is a no-op only when the module is
        in eval mode, so call ``.eval()`` on the model first.

        Args:
            g: The whole graph; input features are read from ``g.ndata["feat"]``.
            device: Anything accepted by ``torch.device`` (for example
                ``"cuda"`` or ``torch.device("cpu")``). Each layer is computed
                on this device.
            batch_size: Number of output nodes per mini-batch.

        Returns:
            CPU tensor of shape ``(g.num_nodes(), self.out_size)``.
        """
        # Normalize first: a torch.device never compares equal to a plain
        # string, so passing "cpu" would otherwise request pinned memory
        # (which needs CUDA) on a CPU-only run.
        device = torch.device(device)
        feat = g.ndata["feat"]
        # Inputs are gathered from ``feat`` below (after the first pass it holds
        # the previous layer's output), so prefetching g.ndata["feat"] into
        # the blocks would be an unused per-batch gather.
        sampler = MultiLayerFullNeighborSampler(1)
        dataloader = DataLoader(
            g,
            torch.arange(g.num_nodes(), device=g.device),
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,
        )
        buffer_device = torch.device("cpu")
        pin_memory = buffer_device != device

        last = len(self.layers) - 1
        for l, layer in enumerate(self.layers):
            y = torch.empty(
                g.num_nodes(),
                self.hid_size if l != last else self.out_size,
                dtype=feat.dtype,
                device=buffer_device,
                pin_memory=pin_memory,
            )
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = feat[input_nodes]
                h = layer(blocks[0], x)  # len(blocks) = 1
                if l != last:
                    h = F.relu(h)
                    h = self.dropout(h)
                # Seeds are arange(num_nodes) with shuffle=False, so by design
                # each batch's output nodes form a contiguous ID range.
                y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)
            feat = y
        return y
76

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected