r"""Apply symmetric adjacency normalization to an input graph and save the resulting edge weights, as described in `Semi-Supervised Classification with Graph Convolutional Networks <https://arxiv.org/abs/1609.02907>`__. For a heterogeneous graph, this only applies to symmetric canonical edge types, whose source and destination node types are identical."""
| 1117 | |
| 1118 | |
| 1119 | class GCNNorm(BaseTransform): |
| 1120 | r"""Apply symmetric adjacency normalization to an input graph and save the resulting edge |
| 1121 | weights, as described in `Semi-Supervised Classification with Graph Convolutional Networks |
| 1122 | <https://arxiv.org/abs/1609.02907>`__. |
| 1123 | |
| 1124 | For a heterogeneous graph, this only applies to symmetric canonical edge types, whose source |
| 1125 | and destination node types are identical. |
| 1126 | |
| 1127 | Parameters |
| 1128 | ---------- |
| 1129 | eweight_name : str, optional |
| 1130 | :attr:`edata` name to retrieve and store edge weights. Default: ``"w"``. The edge weights are optional; if they are absent, the graph is treated as unweighted. |
| 1131 | |
| 1132 | Example |
| 1133 | ------- |
| 1134 | |
| 1135 | >>> import dgl |
| 1136 | >>> import torch |
| 1137 | >>> from dgl import GCNNorm |
| 1138 | >>> transform = GCNNorm() |
| 1139 | >>> g = dgl.graph(([0, 1, 2], [0, 0, 1])) |
| 1140 | |
| 1141 | Case1: Transform an unweighted graph |
| 1142 | |
| 1143 | >>> g = transform(g) |
| 1144 | >>> print(g.edata['w']) |
| 1145 | tensor([0.5000, 0.7071, 0.0000]) |
| 1146 | |
| 1147 | Case2: Transform a weighted graph |
| 1148 | |
| 1149 | >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3]) |
| 1150 | >>> g = transform(g) |
| 1151 | >>> print(g.edata['w']) |
| 1152 | tensor([0.3333, 0.6667, 0.0000]) |
| 1153 | """ |
| 1154 | |
| 1155 | def __init__(self, eweight_name="w"): |
| 1156 | self.eweight_name = eweight_name |
| 1157 | |
| 1158 | def calc_etype(self, c_etype, g): |
| 1159 | r""" |
| 1160 | |
| 1161 | Description |
| 1162 | ----------- |
| 1163 | Get edge weights for an edge type. |
| 1164 | """ |
| 1165 | ntype = c_etype[0] |
| 1166 | with g.local_scope(): |
| 1167 | if self.eweight_name in g.edges[c_etype].data: |
| 1168 | g.update_all( |
| 1169 | fn.copy_e(self.eweight_name, "m"), |
| 1170 | fn.sum("m", "deg"), |
| 1171 | etype=c_etype, |
| 1172 | ) |
| 1173 | deg_inv_sqrt = 1.0 / F.sqrt(g.nodes[ntype].data["deg"]) |
| 1174 | g.nodes[ntype].data["w"] = F.replace_inf_with_zero(deg_inv_sqrt) |
| 1175 | g.apply_edges( |
| 1176 | lambda edge: { |