hub / github.com/dmlc/dgl / test_stack_reduce

Function test_stack_reduce

tests/python/common/test_heterograph.py:2176–2209
(idtype)

Source from the content-addressed store, hash-verified

@parametrize_idtype
def test_stack_reduce(idtype):
    # edges = {
    #    'follows': ([0, 1], [1, 2]),
    #    'plays': ([0, 1, 2, 1], [0, 0, 1, 1]),
    #    'wishes': ([0, 2], [1, 0]),
    #    'develops': ([0, 1], [0, 1]),
    # }
    g = create_test_heterograph(idtype)
    g.nodes["user"].data["h"] = F.randn((3, 200))

    def rfunc(nodes):
        return {"y": F.sum(nodes.mailbox["m"], 1)}

    def rfunc2(nodes):
        return {"y": F.max(nodes.mailbox["m"], 1)}

    def mfunc(edges):
        return {"m": edges.src["h"]}

    g.multi_update_all(
        {"plays": (mfunc, rfunc), "wishes": (mfunc, rfunc2)}, "stack"
    )
    assert g.nodes["game"].data["y"].shape == (
        g.num_nodes("game"),
        2,
        200,
    )
    # only one type-wise update_all, stack still adds one dimension
    g.multi_update_all({"plays": (mfunc, rfunc)}, "stack")
    assert g.nodes["game"].data["y"].shape == (
        g.num_nodes("game"),
        1,
        200,
    )
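
The shape assertions in the test follow from what a "stack" cross-type reducer does: each relation yields one (N, D) result per destination node type, and stacking k such results gives (N, k, D), even when k is 1. The sketch below mimics this in pure Python and is not DGL's implementation. The helper names (reduce_per_relation, stack, shape) are hypothetical, the count of two game nodes is inferred from the edge lists in the test's comment, and the zero fill for nodes that receive no messages is an assumption that this example never exercises.

```python
import random


def reduce_per_relation(src, dst, feats, num_dst, reducer):
    """Send each source feature along its edge, then reduce per destination node."""
    mailbox = [[] for _ in range(num_dst)]
    for s, d in zip(src, dst):
        mailbox[d].append(feats[s])
    dim = len(feats[0])
    out = []
    for msgs in mailbox:
        if msgs:
            # Reduce elementwise across the messages a node received.
            out.append([reducer(col) for col in zip(*msgs)])
        else:
            # Assumption: nodes without incoming messages get a zero vector.
            out.append([0.0] * dim)
    return out


def stack(per_relation):
    """Combine k per-relation results of shape (N, D) into shape (N, k, D)."""
    return [list(rows) for rows in zip(*per_relation)]


def shape(x):
    """Return the nested-list shape, assuming a rectangular layout."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


random.seed(0)
# Three users with 200-dimensional features, as in the test.
user_h = [[random.gauss(0, 1) for _ in range(200)] for _ in range(3)]
num_games = 2  # inferred from the destination ids in the edge lists

# 'plays' uses a sum reducer, 'wishes' uses a max reducer.
plays = reduce_per_relation([0, 1, 2, 1], [0, 0, 1, 1], user_h, num_games, sum)
wishes = reduce_per_relation([0, 2], [1, 0], user_h, num_games, max)

two_relations = stack([plays, wishes])
one_relation = stack([plays])

print(shape(two_relations))
print(shape(one_relation))
```

Stacking keeps the per-relation results separate instead of merging them, which is why the second call in the test still produces a middle dimension of size 1.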

Callers

nothing calls this directly

Calls 3

multi_update_all · Method · 0.80
create_test_heterograph · Function · 0.70
num_nodes · Method · 0.45

Tested by

no test coverage detected