MCPcopy
hub / github.com/dmlc/dgl / GCNNorm

Class GCNNorm

python/dgl/transforms/module.py:1119–1202  ·  view source on GitHub ↗

Apply symmetric adjacency normalization to an input graph and save the resulting edge weights, as described in `Semi-Supervised Classification with Graph Convolutional Networks <https://arxiv.org/abs/1609.02907>`__. For a heterogeneous graph, this only applies to symmetric canonical edge types, whose source and destination node types are identical.

Source from the content-addressed store, hash-verified

1117
1118
1119class GCNNorm(BaseTransform):
1120 r"""Apply symmetric adjacency normalization to an input graph and save the result edge
1121 weights, as described in `Semi-Supervised Classification with Graph Convolutional Networks
1122 <https://arxiv.org/abs/1609.02907>`__.
1123
1124 For a heterogeneous graph, this only applies to symmetric canonical edge types, whose source
1125 and destination node types are identical.
1126
1127 Parameters
1128 ----------
1129 eweight_name : str, optional
1130 :attr:`edata` name to retrieve and store edge weights. The edge weights are optional.
1131
1132 Example
1133 -------
1134
1135 >>> import dgl
1136 >>> import torch
1137 >>> from dgl import GCNNorm
1138 >>> transform = GCNNorm()
1139 >>> g = dgl.graph(([0, 1, 2], [0, 0, 1]))
1140
1141 Case1: Transform an unweighted graph
1142
1143 >>> g = transform(g)
1144 >>> print(g.edata['w'])
1145 tensor([0.5000, 0.7071, 0.0000])
1146
1147 Case2: Transform a weighted graph
1148
1149 >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3])
1150 >>> g = transform(g)
1151 >>> print(g.edata['w'])
1152 tensor([0.3333, 0.6667, 0.0000])
1153 """
1154
1155 def __init__(self, eweight_name="w"):
1156 self.eweight_name = eweight_name
1157
1158 def calc_etype(self, c_etype, g):
1159 r"""
1160
1161 Description
1162 -----------
1163 Get edge weights for an edge type.
1164 """
1165 ntype = c_etype[0]
1166 with g.local_scope():
1167 if self.eweight_name in g.edges[c_etype].data:
1168 g.update_all(
1169 fn.copy_e(self.eweight_name, "m"),
1170 fn.sum("m", "deg"),
1171 etype=c_etype,
1172 )
1173 deg_inv_sqrt = 1.0 / F.sqrt(g.nodes[ntype].data["deg"])
1174 g.nodes[ntype].data["w"] = F.replace_inf_with_zero(deg_inv_sqrt)
1175 g.apply_edges(
1176 lambda edge: {

Callers 2

gcnMethod · 0.85
pprMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected