MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / assert_close

Function assert_close

triton_kernels/testing.py:21–80  ·  view source on GitHub ↗
(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True)

Source from the content-addressed store, hash-verified

19
20
21def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
22 if tri.dtype.itemsize == 1:
23 ref_as_type = ref.to(tri.dtype)
24 if ref.dtype == tri.dtype:
25 assert torch.all(ref_as_type == tri)
26 return
27 ref = ref_as_type
28
29 if ref.numel() == 0:
30 return
31
32 if maxtol is None:
33 maxtol = 2e-2
34 if rmstol is None:
35 rmstol = 4e-3
36 """
37 Compare reference values against obtained values.
38 """
39
40 # cast to float32:
41 ref = ref.to(torch.float32).detach()
42 tri = tri.to(torch.float32).detach()
43 assert ref.shape == tri.shape, f"Tensors must have same size {ref.shape=} {tri.shape=}"
44
45 # deal with infinite elements:
46 inf_mask_ref = torch.isinf(ref)
47 inf_mask_tri = torch.isinf(tri)
48 assert torch.equal(inf_mask_ref, inf_mask_tri), "Tensor must have same infinite elements"
49 refn = torch.where(inf_mask_ref, 0, ref)
50 trin = torch.where(inf_mask_tri, 0, tri)
51
52 # normalise so that RMS calculation doesn't overflow:
53 eps = 1.0e-30
54 multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
55 refn *= multiplier
56 trin *= multiplier
57
58 ref_rms = torch.sqrt(torch.square(refn).mean()) + eps
59
60 rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
61 max_err = torch.max(rel_err).item()
62 rms_err = torch.sqrt(torch.square(rel_err).mean()).item()
63
64 if verbose:
65 print("%s maximum relative error = %s (threshold = %s)" % (description, max_err, maxtol))
66 print("%s RMS relative error = %s (threshold = %s)" % (description, rms_err, rmstol))
67
68 if max_err > maxtol:
69 bad_idxs = torch.nonzero(rel_err > maxtol)
70 num_nonzero = bad_idxs.size(0)
71 bad_idxs = bad_idxs[:1000]
72 print("%d / %d mismatched elements (shape = %s) at coords %s" %
73 (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist()))
74
75 bad_idxs = bad_idxs.unbind(-1)
76 print("ref values: ", ref[tuple(bad_idxs)].cpu())
77 print("tri values: ", tri[tuple(bad_idxs)].cpu())
78

Callers

nothing calls this directly

Calls 9

absMethod · 0.80
sqrtMethod · 0.80
meanMethod · 0.80
unbindMethod · 0.80
toMethod · 0.45
numelMethod · 0.45
whereMethod · 0.45
maxMethod · 0.45
sizeMethod · 0.45

Tested by

no test coverage detected