MCPcopy
hub / github.com/dmlc/dgl / test_subframes

Function test_subframes

tests/python/common/test_subgraph.py:794–826  ·  view source on GitHub ↗
(parent_idx_device, child_device)

Source from the content-addressed store, hash-verified

792)
793@pytest.mark.parametrize("child_device", [F.cpu(), F.cuda()])
def test_subframes(parent_idx_device, child_device):
    """Check that a sampled subgraph's feature frames follow the graph's device.

    Builds a graph with edges 1->2, 2->3, 3->4 (5 nodes), node feature ``"x"``
    and edge feature ``"a"``, places it according to ``parent_device``, samples
    neighbors of seed nodes ``[1, 2]`` and asserts that the subgraph, its
    ``ndata["x"]`` and its ``edata["a"]`` all live on ``child_device``.

    Args:
        parent_idx_device: ``(parent_device, idx_device)`` pair.
            ``parent_device`` selects where the parent graph lives and is
            compared against ``"cuda"``, ``"uva"`` (pinned CPU graph, PyTorch
            backend only) and ``"cpu"``. ``idx_device`` is the device the
            seed-node tensor is copied to before sampling.
        child_device: Device the sampled subgraph is expected to end up on.
    """
    parent_device, idx_device = parent_idx_device
    g = dgl.graph(
        (F.tensor([1, 2, 3], dtype=F.int64), F.tensor([2, 3, 4], dtype=F.int64))
    )
    g.ndata["x"] = F.randn((5, 4))
    g.edata["a"] = F.randn((3, 6))
    idx = F.tensor([1, 2], dtype=F.int64)
    if parent_device == "cuda":
        g = g.to(F.cuda())
    elif parent_device == "uva":
        if F.backend_name != "pytorch":
            pytest.skip("UVA only supported for PyTorch")
        g = g.to(F.cpu())
        g.create_formats_()
        g.pin_memory_()
    elif parent_device == "cpu":
        g = g.to(F.cpu())

    try:
        idx = F.copy_to(idx, idx_device)
        # Sample on the parent's device, then move the result to child_device.
        sg = g.sample_neighbors(idx, 2).to(child_device)
        _assert_subgraph_on_device(sg, child_device)
        if parent_device != "uva":
            # Sample directly on child_device: both graph and seeds moved first.
            sg = g.to(child_device).sample_neighbors(
                F.copy_to(idx, child_device), 2
            )
            _assert_subgraph_on_device(sg, child_device)
    finally:
        # Unpin even when an assertion above fails, so pinned host memory is
        # not leaked into subsequent tests.
        if parent_device == "uva":
            g.unpin_memory_()


def _assert_subgraph_on_device(sg, device):
    # Assert the subgraph and its "x"/"a" feature frames all live on ``device``.
    assert sg.device == F.context(sg.ndata["x"])
    assert sg.device == F.context(sg.edata["a"])
    assert sg.device == device
827
828
829@unittest.skipIf(

Callers

nothing calls this directly

Calls 10

cuda · Method · 0.80
context · Method · 0.80
graph · Method · 0.45
to · Method · 0.45
cpu · Method · 0.45
create_formats_ · Method · 0.45
pin_memory_ · Method · 0.45
copy_to · Method · 0.45
sample_neighbors · Method · 0.45
unpin_memory_ · Method · 0.45

Tested by

no test coverage detected