Repository: github.com/dmlc/dgl

Function test_half_spmm

tests/python/common/ops/test_ops.py:193–223
(idtype, dtype, rtol, atol)

```python
import backend as F  # assumed: DGL's test-suite backend helper module
import dgl
import pytest
import torch


# `idtype` is provided outside this excerpt (its decorator is not shown).
@pytest.mark.parametrize(
    "dtype, rtol, atol",
    [(torch.float16, 1e-3, 0.5), (torch.bfloat16, 4e-3, 2.0)],
)
def test_half_spmm(idtype, dtype, rtol, atol):
    if F._default_context_str == "cpu" and dtype == torch.float16:
        pytest.skip("float16 is not supported on CPU.")
    if (
        F._default_context_str == "gpu"
        and dtype == torch.bfloat16
        and not torch.cuda.is_bf16_supported()
    ):
        pytest.skip("BF16 is not supported.")

    # All 900 source nodes send to node 0, so its sum is roughly
    # 900 * 0.5 = 450. Keeping the SpMM result below 512 keeps
    # half-precision rounding within the rtol/atol set above.
    g = dgl.graph(
        (torch.arange(900), torch.tensor([0] * 900)),
        idtype=idtype,
        device=F.ctx(),
    )
    feat_fp32 = torch.rand((g.num_src_nodes(), 32)).to(F.ctx())
    feat_half = feat_fp32.to(dtype)

    # Test SpMMCSR: restrict the graph to the CSC format.
    # Row [0] is node 0, the only destination with incoming edges.
    g = g.formats(["csc"])
    res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]
    res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()
    assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)

    # Test SpMMCOO
    # TODO(Xin): half-precision SpMMCoo is temporarily disabled.
    # g = g.formats(['coo'])
    # res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]
    # res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()
    # assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)
```
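For intuition about what the assertion compares, here is a minimal pure-Python sketch of a copy_u_sum-style reduction, assuming features are plain lists of floats. `copy_u_sum_reference` is a hypothetical helper, not DGL's API or kernel; it only reproduces the math (each destination row is the sum of the rows of its source nodes). On the same 900-edge star graph, only row 0 receives messages, and each of its column sums lands near 900 × 0.5 = 450.

```python
import random
from typing import List, Sequence


def copy_u_sum_reference(
    src: Sequence[int],
    dst: Sequence[int],
    feat: List[List[float]],
    num_dst: int,
) -> List[List[float]]:
    """Hypothetical reference: out[v] = sum of feat[u] over all edges u -> v."""
    dim = len(feat[0])
    out = [[0.0] * dim for _ in range(num_dst)]
    for u, v in zip(src, dst):
        row = out[v]
        for k in range(dim):
            row[k] += feat[u][k]
    return out


# Same star shape as the test: 900 source nodes, every edge points to node 0.
num_nodes = 900
src = list(range(num_nodes))
dst = [0] * num_nodes
random.seed(0)
feat = [[random.random() for _ in range(32)] for _ in range(num_nodes)]

out = copy_u_sum_reference(src, dst, feat, num_dst=num_nodes)
print(max(out[0]))  # near 450, comfortably below 512
```

Rows 1 through 899 stay zero, which is why the test only compares row `[0]`.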
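The atol values appear to track the gap between neighbouring half-precision numbers just below 512. One rough way to check this is to emulate rounding in pure Python, assuming round-to-nearest-even: `to_float16` relies on the `struct` module's `e` (IEEE 754 half) format, and `to_bfloat16` keeps the top 16 bits of a float32 encoding. Both helpers and `spacing` are hypothetical stand-ins for `tensor.to(dtype)`, and they measure only the rounding of a single value, not the error a kernel accumulates over 900 additions.

```python
import math
import struct


def to_float16(x: float) -> float:
    # struct's "e" format packs an IEEE 754 half, rounding to nearest even.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x: float) -> float:
    # bfloat16 is the top 16 bits of a float32; round to nearest even
    # on the 16 dropped bits (finite, non-NaN inputs only).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def spacing(x: float, mantissa_bits: int) -> float:
    # Gap between adjacent representable values in x's binade [2^e, 2^(e+1)).
    e = math.frexp(x)[1] - 1
    return 2.0 ** (e - mantissa_bits)


value = 450.0  # roughly the expected sum at node 0
print(spacing(value, 10))  # float16 (10 mantissa bits): 0.25
print(spacing(value, 7))   # bfloat16 (7 mantissa bits): 2.0

# Worst single-value rounding error over [256, 512), sampled every 1/64.
samples = [256 + i / 64 for i in range(256 * 64)]
err16 = max(abs(to_float16(x) - x) for x in samples)
errbf = max(abs(to_bfloat16(x) - x) for x in samples)
print(err16, errbf)  # half the spacing: 0.125 1.0
```

In [256, 512) the float16 step is 0.25 and the bfloat16 step is 2.0, which is consistent with the test's atol of 0.5 and 2.0 allowing about two float16 steps and one bfloat16 step (plus the rtol term).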

Callers

nothing calls this directly

Calls 6

num_src_nodes (method) · 0.80
graph (method) · 0.45
ctx (method) · 0.45
to (method) · 0.45
formats (method) · 0.45
float (method) · 0.45

Tested by

no test coverage detected