| 19 | |
| 20 | |
| 21 | def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True): |
| 22 | if tri.dtype.itemsize == 1: |
| 23 | ref_as_type = ref.to(tri.dtype) |
| 24 | if ref.dtype == tri.dtype: |
| 25 | assert torch.all(ref_as_type == tri) |
| 26 | return |
| 27 | ref = ref_as_type |
| 28 | |
| 29 | if ref.numel() == 0: |
| 30 | return |
| 31 | |
| 32 | if maxtol is None: |
| 33 | maxtol = 2e-2 |
| 34 | if rmstol is None: |
| 35 | rmstol = 4e-3 |
| 36 | """ |
| 37 | Compare reference values against obtained values. |
| 38 | """ |
| 39 | |
| 40 | # cast to float32: |
| 41 | ref = ref.to(torch.float32).detach() |
| 42 | tri = tri.to(torch.float32).detach() |
| 43 | assert ref.shape == tri.shape, f"Tensors must have same size {ref.shape=} {tri.shape=}" |
| 44 | |
| 45 | # deal with infinite elements: |
| 46 | inf_mask_ref = torch.isinf(ref) |
| 47 | inf_mask_tri = torch.isinf(tri) |
| 48 | assert torch.equal(inf_mask_ref, inf_mask_tri), "Tensor must have same infinite elements" |
| 49 | refn = torch.where(inf_mask_ref, 0, ref) |
| 50 | trin = torch.where(inf_mask_tri, 0, tri) |
| 51 | |
| 52 | # normalise so that RMS calculation doesn't overflow: |
| 53 | eps = 1.0e-30 |
| 54 | multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps) |
| 55 | refn *= multiplier |
| 56 | trin *= multiplier |
| 57 | |
| 58 | ref_rms = torch.sqrt(torch.square(refn).mean()) + eps |
| 59 | |
| 60 | rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn)) |
| 61 | max_err = torch.max(rel_err).item() |
| 62 | rms_err = torch.sqrt(torch.square(rel_err).mean()).item() |
| 63 | |
| 64 | if verbose: |
| 65 | print("%s maximum relative error = %s (threshold = %s)" % (description, max_err, maxtol)) |
| 66 | print("%s RMS relative error = %s (threshold = %s)" % (description, rms_err, rmstol)) |
| 67 | |
| 68 | if max_err > maxtol: |
| 69 | bad_idxs = torch.nonzero(rel_err > maxtol) |
| 70 | num_nonzero = bad_idxs.size(0) |
| 71 | bad_idxs = bad_idxs[:1000] |
| 72 | print("%d / %d mismatched elements (shape = %s) at coords %s" % |
| 73 | (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist())) |
| 74 | |
| 75 | bad_idxs = bad_idxs.unbind(-1) |
| 76 | print("ref values: ", ref[tuple(bad_idxs)].cpu()) |
| 77 | print("tri values: ", tri[tuple(bad_idxs)].cpu()) |
| 78 | |