MCPcopy
hub / github.com/dmlc/dgl / test_to_device

Function test_to_device

tests/python/common/test_heterograph.py:1135–1173  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

1133)
@parametrize_idtype
def test_to_device(idtype):
    """Check that ``DGLGraph.to`` moves graph structure and features across devices.

    The test graph gets node features on "user" and "game" and an edge
    feature on "plays", and is then copied to CPU.  The following must all
    report the CPU context:

    * the graph device,
    * those features,
    * the per-node-type ``batch_num_nodes``,
    * the per-canonical-edge-type ``batch_num_edges``.

    When CUDA is available, the CPU graph is also copied to CUDA:

    * The copy must report the CUDA context everywhere.
    * The CPU graph's features and batch counts must remain on CPU.
    * Assigning CPU tensors as features of the CUDA graph must raise
      ``DGLError``.
    """
    # TODO: rewrite this test case to accept different graphs so we
    # can test reverse graph and batched graph

    def _assert_on_device(graph, ctx):
        # Shared by the CPU and CUDA checks so the two assertion sets stay in
        # sync; checks the features set below plus the per-type batch counts.
        assert F.context(graph.nodes["user"].data["h"]) == ctx
        assert F.context(graph.nodes["game"].data["i"]) == ctx
        assert F.context(graph.edges["plays"].data["e"]) == ctx
        for ntype in graph.ntypes:
            assert F.context(graph.batch_num_nodes(ntype)) == ctx
        for etype in graph.canonical_etypes:
            assert F.context(graph.batch_num_edges(etype)) == ctx

    g = create_test_heterograph(idtype)
    g.nodes["user"].data["h"] = F.ones((3, 5))
    g.nodes["game"].data["i"] = F.ones((2, 5))
    g.edges["plays"].data["e"] = F.ones((4, 4))
    assert g.device == F.ctx()

    g = g.to(F.cpu())
    assert g.device == F.cpu()
    _assert_on_device(g, F.cpu())

    if F.is_cuda_available():
        g1 = g.to(F.cuda())
        assert g1.device == F.cuda()
        _assert_on_device(g1, F.cuda())
        # Copying to CUDA must leave the source (CPU) graph untouched.
        _assert_on_device(g, F.cpu())
        # Features living on a different device than the graph are rejected.
        with pytest.raises(DGLError):
            g1.nodes["user"].data["h"] = F.copy_to(F.ones((3, 5)), F.cpu())
        with pytest.raises(DGLError):
            g1.edges["plays"].data["e"] = F.copy_to(F.ones((4, 4)), F.cpu())
1174
1175
1176@unittest.skipIf(

Callers

nothing calls this directly

Calls 9

context — Method · 0.80
batch_num_nodes — Method · 0.80
batch_num_edges — Method · 0.80
cuda — Method · 0.80
create_test_heterograph — Function · 0.70
ctx — Method · 0.45
to — Method · 0.45
cpu — Method · 0.45
copy_to — Method · 0.45

Tested by

no test coverage detected