MCPcopy
hub / github.com/dmlc/dgl / test_frame_device

Function test_frame_device

tests/python/common/test_heterograph.py:3378–3426  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

3376)
@parametrize_idtype
def test_frame_device(idtype):
    """Check that moving a graph between devices copies feature storage lazily.

    After ``g.to(ctx)``, each feature column keeps its previous storage until
    it is first read through ``ndata``/``edata`` (or until an operation such
    as ``dgl.add_nodes`` needs it). At that point only that column is
    materialized on the target device. The per-column state must carry over
    into ``dgl.node_subgraph`` results and must also hold for a round trip
    back to CPU.
    """

    def node_ctx(graph, name):
        # Inspect the raw column storage directly: reading through
        # ``graph.ndata`` would itself trigger the lazy copy under test.
        return F.context(graph._node_frames[0]._columns[name].storage)

    def edge_ctx(graph, name):
        # Edge-frame counterpart of ``node_ctx``.
        return F.context(graph._edge_frames[0]._columns[name].storage)

    g = dgl.graph(([0, 1, 2], [2, 3, 1]))
    g.ndata["h"] = F.copy_to(F.tensor([1, 1, 1, 2], dtype=idtype), ctx=F.cpu())
    g.ndata["hh"] = F.copy_to(F.ones((4, 3), dtype=idtype), ctx=F.cpu())
    g.edata["h"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.cpu())

    g = g.to(F.ctx())
    # Lazy device copy: moving the graph leaves the columns where they were...
    assert node_ctx(g, "h") == F.cpu()
    assert node_ctx(g, "hh") == F.cpu()
    # ...until a feature is read; only the column that was read moves.
    _ = g.ndata["h"]
    assert node_ctx(g, "h") == F.ctx()
    assert node_ctx(g, "hh") == F.cpu()
    assert edge_ctx(g, "h") == F.cpu()

    # Lazy device copy in subgraph: each column keeps its current device.
    sg = dgl.node_subgraph(g, [0, 1, 2])
    assert node_ctx(sg, "h") == F.ctx()
    assert node_ctx(sg, "hh") == F.cpu()
    assert edge_ctx(sg, "h") == F.cpu()
    _ = sg.ndata["hh"]
    assert node_ctx(sg, "hh") == F.ctx()
    assert edge_ctx(sg, "h") == F.cpu()

    # Back to CPU: again nothing moves until each feature is read.
    sg = sg.to(F.cpu())
    assert node_ctx(sg, "h") == F.ctx()
    assert node_ctx(sg, "hh") == F.ctx()
    assert edge_ctx(sg, "h") == F.cpu()
    _ = sg.ndata["h"]
    _ = sg.ndata["hh"]
    _ = sg.edata["h"]
    assert node_ctx(sg, "h") == F.cpu()
    assert node_ctx(sg, "hh") == F.cpu()
    assert edge_ctx(sg, "h") == F.cpu()

    # Writing into a feature goes through ``ndata``, so it materializes that
    # column (and only that column) on the graph's device.
    sg = sg.to(F.ctx())
    assert node_ctx(sg, "h") == F.cpu()
    sg.ndata["h"][0] = 5
    assert node_ctx(sg, "h") == F.ctx()
    assert node_ctx(sg, "hh") == F.cpu()
    assert edge_ctx(sg, "h") == F.cpu()

    # add_nodes brings every node column onto the graph's device, while the
    # edge column is left on CPU.
    ng = dgl.add_nodes(sg, 3)
    assert node_ctx(ng, "h") == F.ctx()
    assert node_ctx(ng, "hh") == F.ctx()
    assert edge_ctx(ng, "h") == F.cpu()
3427
3428
3429@parametrize_idtype

Callers

nothing calls this directly

Calls 8

context · method · 0.80
graph · method · 0.45
copy_to · method · 0.45
cpu · method · 0.45
to · method · 0.45
ctx · method · 0.45
node_subgraph · method · 0.45
add_nodes · method · 0.45

Tested by

no test coverage detected