MCPcopy
hub / github.com/dmlc/dgl / test_module_sign

Function test_module_sign

tests/python/common/transforms/test_transform.py:3142–3235  ·  view source on GitHub ↗
(g)

Source from the content-addressed store, hash-verified

3140)
@pytest.mark.parametrize("g", get_cases(["has_scalar_e_feature"]))
def test_module_sign(g):
    """Check ``dgl.SIGNDiffusion`` (k=1) against dense-matrix references.

    Each check applies the transform and compares its ``out_feat_1`` node
    feature (``torch.allclose`` with the ``atol`` below) to a reference
    product ``A @ g.ndata["h"]``. ``A`` is built from the graph's dense
    adjacency matrix. The diffusion operators checked here are:

    * ``"raw"``: plain adjacency, without and then with the ``"scalar_w"``
      edge weights (``eweight_name="scalar_w"``);
    * ``"rw"``: adjacency scaled by the inverse of its row sums, without
      and then with the ``"scalar_w"`` edge weights;
    * ``"gcn"``: adjacency whose edge entries are replaced by the weights
      produced by ``dgl.GCNNorm()`` (unweighted transform).

    Parameters
    ----------
    g : DGLGraph
        Test graph from ``get_cases(["has_scalar_e_feature"])``. The body
        reads ``g.ndata["h"]`` and ``g.edata["scalar_w"]`` from it.
    """
    # Local import. Presumably this keeps the module importable under
    # non-PyTorch backends -- confirm.
    import torch

    atol = 1e-06

    # ``F`` is the test-suite backend shim, imported outside this excerpt.
    ctx = F.ctx()
    g = g.to(ctx)
    # Dense transposed adjacency. The ``[dst, src]`` writes below treat rows
    # as destination nodes and columns as source nodes, so ``adj @ h`` sums
    # source-node features into each destination row.
    adj = g.adj_external(transpose=True, scipy_fmt="coo").todense()
    adj = torch.tensor(adj).float().to(ctx)

    # Same layout as ``adj``, but each edge entry is then overwritten with
    # that edge's ``"scalar_w"`` weight.
    weight_adj = (
        g.adj_external(transpose=True, scipy_fmt="coo").astype(float).todense()
    )
    weight_adj = torch.tensor(weight_adj).float().to(ctx)
    src, dst = g.edges()
    # int64 so the edge endpoints can be used as tensor indices.
    src, dst = src.long(), dst.long()
    weight_adj[dst, src] = g.edata["scalar_w"]

    # raw
    transform = dgl.SIGNDiffusion(k=1, in_feat_name="h", diffuse_op="raw")
    # The transform returns a graph carrying the diffused features. ``g`` is
    # rebound so later checks run on (and read "h" from) the returned graph.
    g = transform(g)
    target = torch.matmul(adj, g.ndata["h"])
    # "out_feat_1" is the key the assertions read for the single (k=1)
    # diffusion step.
    assert torch.allclose(g.ndata["out_feat_1"], target, atol=atol)

    transform = dgl.SIGNDiffusion(
        k=1, in_feat_name="h", eweight_name="scalar_w", diffuse_op="raw"
    )
    g = transform(g)
    target = torch.matmul(weight_adj, g.ndata["h"])
    assert torch.allclose(g.ndata["out_feat_1"], target, atol=atol)

    # rw
    # Scale each row by 1 / row-sum. Rows are destinations, so the row sums
    # are (weighted) in-degrees.
    # NOTE(review): a node with no incoming edges gives 1/0 = inf here. The
    # product then yields a NaN reference row, which allclose rejects.
    # Presumably the parametrized graphs have no such node -- confirm.
    adj_rw = torch.matmul(torch.diag(1 / adj.sum(dim=1)), adj)
    transform = dgl.SIGNDiffusion(k=1, in_feat_name="h", diffuse_op="rw")
    g = transform(g)
    target = torch.matmul(adj_rw, g.ndata["h"])
    assert torch.allclose(g.ndata["out_feat_1"], target, atol=atol)

    weight_adj_rw = torch.matmul(
        torch.diag(1 / weight_adj.sum(dim=1)), weight_adj
    )
    transform = dgl.SIGNDiffusion(
        k=1, in_feat_name="h", eweight_name="scalar_w", diffuse_op="rw"
    )
    g = transform(g)
    target = torch.matmul(weight_adj_rw, g.ndata["h"])
    assert torch.allclose(g.ndata["out_feat_1"], target, atol=atol)

    # gcn
    # NOTE(review): not used in this excerpt, which stops at source line
    # 3199 of a function listed as spanning 3142-3235. Presumably it is read
    # by the remaining checks, so do not remove it without viewing them.
    raw_eweight = g.edata["scalar_w"]
    gcn_norm = dgl.GCNNorm()
    g = gcn_norm(g)
    # Reference: the plain adjacency with each edge entry replaced by the
    # normalized weight that GCNNorm() left in edata "w". The weight is
    # popped, so "w" is gone before the next transform runs.
    adj_gcn = adj.clone()
    adj_gcn[dst, src] = g.edata.pop("w")
    transform = dgl.SIGNDiffusion(k=1, in_feat_name="h", diffuse_op="gcn")
    g = transform(g)
    target = torch.matmul(adj_gcn, g.ndata["h"])
    assert torch.allclose(g.ndata["out_feat_1"], target, atol=atol)

Callers

nothing calls this directly (it is a pytest test; the test runner collects and runs it)

Calls 9

transform · Function · 0.85
adj_external · Method · 0.80
ctx · Method · 0.45
to · Method · 0.45
float · Method · 0.45
astype · Method · 0.45
edges · Method · 0.45
long · Method · 0.45
clone · Method · 0.45

Tested by

no test coverage detected (expected: this function is itself a test)