| 3376 | ) |
| 3377 | @parametrize_idtype |
| 3378 | def test_frame_device(idtype): |
| 3379 | g = dgl.graph(([0, 1, 2], [2, 3, 1])) |
| 3380 | g.ndata["h"] = F.copy_to(F.tensor([1, 1, 1, 2], dtype=idtype), ctx=F.cpu()) |
| 3381 | g.ndata["hh"] = F.copy_to(F.ones((4, 3), dtype=idtype), ctx=F.cpu()) |
| 3382 | g.edata["h"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.cpu()) |
| 3383 | |
| 3384 | g = g.to(F.ctx()) |
| 3385 | # lazy device copy |
| 3386 | assert F.context(g._node_frames[0]._columns["h"].storage) == F.cpu() |
| 3387 | assert F.context(g._node_frames[0]._columns["hh"].storage) == F.cpu() |
| 3388 | print(g.ndata["h"]) |
| 3389 | assert F.context(g._node_frames[0]._columns["h"].storage) == F.ctx() |
| 3390 | assert F.context(g._node_frames[0]._columns["hh"].storage) == F.cpu() |
| 3391 | assert F.context(g._edge_frames[0]._columns["h"].storage) == F.cpu() |
| 3392 | |
| 3393 | # lazy device copy in subgraph |
| 3394 | sg = dgl.node_subgraph(g, [0, 1, 2]) |
| 3395 | assert F.context(sg._node_frames[0]._columns["h"].storage) == F.ctx() |
| 3396 | assert F.context(sg._node_frames[0]._columns["hh"].storage) == F.cpu() |
| 3397 | assert F.context(sg._edge_frames[0]._columns["h"].storage) == F.cpu() |
| 3398 | print(sg.ndata["hh"]) |
| 3399 | assert F.context(sg._node_frames[0]._columns["hh"].storage) == F.ctx() |
| 3400 | assert F.context(sg._edge_frames[0]._columns["h"].storage) == F.cpu() |
| 3401 | |
| 3402 | # back to cpu |
| 3403 | sg = sg.to(F.cpu()) |
| 3404 | assert F.context(sg._node_frames[0]._columns["h"].storage) == F.ctx() |
| 3405 | assert F.context(sg._node_frames[0]._columns["hh"].storage) == F.ctx() |
| 3406 | assert F.context(sg._edge_frames[0]._columns["h"].storage) == F.cpu() |
| 3407 | print(sg.ndata["h"]) |
| 3408 | print(sg.ndata["hh"]) |
| 3409 | print(sg.edata["h"]) |
| 3410 | assert F.context(sg._node_frames[0]._columns["h"].storage) == F.cpu() |
| 3411 | assert F.context(sg._node_frames[0]._columns["hh"].storage) == F.cpu() |
| 3412 | assert F.context(sg._edge_frames[0]._columns["h"].storage) == F.cpu() |
| 3413 | |
| 3414 | # set some field |
| 3415 | sg = sg.to(F.ctx()) |
| 3416 | assert F.context(sg._node_frames[0]._columns["h"].storage) == F.cpu() |
| 3417 | sg.ndata["h"][0] = 5 |
| 3418 | assert F.context(sg._node_frames[0]._columns["h"].storage) == F.ctx() |
| 3419 | assert F.context(sg._node_frames[0]._columns["hh"].storage) == F.cpu() |
| 3420 | assert F.context(sg._edge_frames[0]._columns["h"].storage) == F.cpu() |
| 3421 | |
| 3422 | # add nodes |
| 3423 | ng = dgl.add_nodes(sg, 3) |
| 3424 | assert F.context(ng._node_frames[0]._columns["h"].storage) == F.ctx() |
| 3425 | assert F.context(ng._node_frames[0]._columns["hh"].storage) == F.ctx() |
| 3426 | assert F.context(ng._edge_frames[0]._columns["h"].storage) == F.cpu() |
| 3427 | |
| 3428 | |
| 3429 | @parametrize_idtype |