| 191 | [(torch.float16, 1e-3, 0.5), (torch.bfloat16, 4e-3, 2.0)], |
| 192 | ) |
def test_half_spmm(idtype, dtype, rtol, atol):
    """Check half-precision ``copy_u_sum`` against the fp32 reference.

    A 900-edge graph is built in which every edge points at node 0, the
    same random features are summed in fp32 and in ``dtype``, and row 0 of
    both outputs is compared with ``torch.allclose``.

    Parameters
    ----------
    idtype
        Index dtype forwarded to ``dgl.graph``.
    dtype
        Reduced-precision torch dtype under test (``torch.float16`` or
        ``torch.bfloat16`` in the skip checks below).
    rtol, atol
        Relative / absolute tolerances handed to ``torch.allclose``.
    """
    backend = F._default_context_str
    if backend == "cpu" and dtype == torch.float16:
        pytest.skip("float16 is not supported on CPU.")
    if backend == "gpu" and dtype == torch.bfloat16:
        if not torch.cuda.is_bf16_supported():
            pytest.skip("BF16 is not supported.")

    # Keep the edge count at 900 so the spmm result stays < 512, which is
    # what the rtol/atol values were chosen for.
    num_edges = 900
    src_ids = torch.arange(num_edges)
    dst_ids = torch.tensor([0] * num_edges)
    graph = dgl.graph(
        (src_ids, dst_ids),
        idtype=idtype,
        device=F.ctx(),
    )
    full_feat = torch.rand((graph.num_src_nodes(), 32)).to(F.ctx())
    low_feat = full_feat.to(dtype)

    # SpMMCSR path: restrict the graph to the "csc" format before reducing.
    graph = graph.formats(["csc"])
    reference = dgl.ops.copy_u_sum(graph, full_feat)[0]
    # Cast back to fp32 so both operands of allclose share a dtype.
    candidate = dgl.ops.copy_u_sum(graph, low_feat)[0].float()
    assert torch.allclose(reference, candidate, rtol=rtol, atol=atol)

    # test SpMMCOO
    # TODO(Xin): half-precision SpMMCoo is temporarily disabled.
    # graph = graph.formats(['coo'])
    # reference = dgl.ops.copy_u_sum(graph, full_feat)[0]
    # candidate = dgl.ops.copy_u_sum(graph, low_feat)[0].float()
    # assert torch.allclose(reference, candidate, rtol=rtol, atol=atol)
| 224 | |
| 225 | |
| 226 | @pytest.mark.parametrize("g", graphs) |