MCPcopy
hub / github.com/hustvl/Vim / assert_instances_allclose

Function assert_instances_allclose

det/detectron2/utils/testing.py:95–139  ·  view source on GitHub ↗

Args: `input`, `other` (Instances): the two objects to compare. `size_as_tensor`: compare `image_size` of the Instances as tensors (instead of tuples); useful for comparing outputs of tracing.

(input, other, *, rtol=1e-5, msg="", size_as_tensor=False)

Source from the content-addressed store, hash-verified

93
94
def assert_instances_allclose(input, other, *, rtol=1e-5, msg="", size_as_tensor=False):
    """
    Assert that two Instances have the same image size, the same set of fields,
    and field values that are equal (integer/bool tensors) or close (float tensors,
    Boxes, ROIMasks).

    Args:
        input, other (Instances): the two objects to compare. Arguments that are not
            ``Instances`` are first converted with ``convert_scripted_instances``.
        rtol (float): tolerance scale. ``Boxes``/``ROIMasks`` fields are compared with
            ``atol=100 * rtol``; floating-point tensor fields with
            ``atol=max(abs(input_field)) * rtol``.
        msg (str): optional prefix for every assertion message.
        size_as_tensor: compare image_size of the Instances as tensors (instead of tuples).
            Useful for comparing outputs of tracing.

    Raises:
        AssertionError: if image sizes, field names, field shapes or field values differ.
        ValueError: if a field has a type this function does not know how to compare.
    """
    if not isinstance(input, Instances):
        input = convert_scripted_instances(input)
    if not isinstance(other, Instances):
        other = convert_scripted_instances(other)

    if not msg:
        msg = "Two Instances are different! "
    else:
        msg = msg.rstrip() + " "

    size_error_msg = msg + f"image_size is {input.image_size} vs. {other.image_size}!"
    if size_as_tensor:
        assert torch.equal(
            torch.tensor(input.image_size), torch.tensor(other.image_size)
        ), size_error_msg
    else:
        assert input.image_size == other.image_size, size_error_msg
    fields = sorted(input.get_fields().keys())
    fields_other = sorted(other.get_fields().keys())
    assert fields == fields_other, msg + f"Fields are {fields} vs {fields_other}!"

    for f in fields:
        val1, val2 = input.get(f), other.get(f)
        if isinstance(val1, (Boxes, ROIMasks)):
            t1, t2 = val1.tensor, val2.tensor
            # torch.allclose broadcasts, so check shapes explicitly first.
            assert t1.shape == t2.shape, (
                msg + f"Field {f} has shape {tuple(t1.shape)} vs {tuple(t2.shape)}!"
            )
            # boxes in the range of O(100) and can have a larger tolerance
            assert torch.allclose(t1, t2, atol=100 * rtol), (
                msg + f"Field {f} differs too much!"
            )
        elif isinstance(val1, torch.Tensor):
            if val1.dtype.is_floating_point:
                # torch.allclose broadcasts, so check shapes explicitly first.
                assert val1.shape == val2.shape, (
                    msg + f"Field {f} has shape {tuple(val1.shape)} vs {tuple(val2.shape)}!"
                )
                # max() of an empty tensor raises; with no elements any atol is fine.
                mag = torch.abs(val1).max().cpu().item() if val1.numel() > 0 else 0.0
                assert torch.allclose(val1, val2, atol=mag * rtol), (
                    msg + f"Field {f} differs too much!"
                )
            else:
                # torch.equal already returns False for differing shapes.
                assert torch.equal(val1, val2), msg + f"Field {f} is different!"
        else:
            raise ValueError(f"Don't know how to compare type {type(val1)}")
140
141
142def reload_script_model(module):

Calls 4

max — Method · 0.80
get_fields — Method · 0.45
get — Method · 0.45

Tested by

no test coverage detected