MCPcopy
hub / github.com/hustvl/Vim / _test_model

Method _test_model

det/tests/test_export_torchscript.py:165–200  ·  view source on GitHub ↗
(self, config_path, inference_func, batch=1)

Source from the content-addressed store, hash-verified

163 return testing_devices
164
def _test_model(self, config_path, inference_func, batch=1):
    """Trace a model-zoo model, round-trip it through TorchScript, and compare outputs.

    The trained model for ``config_path`` is wrapped in ``TracingAdapter`` and traced
    on randomly shrunk copies of the sample COCO image. The trace is then exercised on
    the full-size inputs, so it must not be specialized to the tracing size. The traced
    module is saved to, and reloaded from, a ``.jit`` file in a temporary directory.
    For every device case from ``self._get_device_casting_test_cases`` its outputs are
    checked against eager inference with ``assert_instances_allclose``
    (``size_as_tensor=True``).

    Args:
        config_path: config name forwarded to ``model_zoo.get(..., trained=True)``.
        inference_func: callable invoked as ``inference_func(model, *inputs)``; it is
            also handed to ``TracingAdapter``.
        batch: number of copies of the sample image fed to the model. When greater
            than 1, eager and traced outputs are compared element by element;
            otherwise they are compared as a whole.
    """
    model = model_zoo.get(config_path, trained=True)
    sample = get_sample_coco_image()
    full_size_inputs = tuple(sample.clone() for _ in range(batch))

    adapter = TracingAdapter(model, full_size_inputs, inference_func)
    adapter.eval()
    with torch.no_grad():
        # Trace on deliberately smaller images; the trace must still work on the
        # full-size inputs used below. One random scale is drawn per batch element.
        # NOTE(review): if ``sample`` is a 3-D (C, H, W) tensor, ``interpolate``
        # treats it as (N, C, L) and rescales only the last dim -- confirm intended.
        shrunk_inputs = tuple(
            [
                nn.functional.interpolate(sample, scale_factor=random.uniform(0.5, 0.7))
                for _ in range(batch)
            ]
        )
        scripted = torch.jit.trace(adapter, shrunk_inputs)

    device_cases = self._get_device_casting_test_cases(model)
    # Round-trip through a file so that TorchScript problems show up with a traceback.
    with tempfile.TemporaryDirectory(prefix="detectron2_test") as tmp_dir:
        basename = "model"
        jitfile = f"{tmp_dir}/{basename}.jit"
        torch.jit.save(scripted, jitfile)
        scripted = torch.jit.load(jitfile)

        wants_cuda = any(device and "cuda" in device for device in device_cases)
        if wants_cuda:
            # Only needed when some case targets CUDA; per its name, the helper checks
            # the saved file for a hard-coded "cuda" device.
            self._check_torchscript_no_hardcoded_device(jitfile, tmp_dir, "cuda")

    for device in device_cases:
        print(f"Testing casting to {device} for inference (traced on {model.device}) ...")
        with torch.no_grad():
            # Eager reference runs on a fresh copy so the original model is not moved.
            expected = inference_func(copy.deepcopy(model).to(device), *full_size_inputs)
            moved = scripted.to(device)
            actual = adapter.outputs_schema(moved(*full_size_inputs))
        # Batched runs yield one result per image; compare them pairwise.
        pairs = zip(expected, actual) if batch > 1 else [(expected, actual)]
        for expected_item, actual_item in pairs:
            assert_instances_allclose(expected_item, actual_item, size_as_tensor=True)
201
202 @skipIfOnCPUCI
203 def testMaskRCNNFPN_batched(self):

Callers 6

testMaskRCNNFPNMethod · 0.95
testMaskRCNNC4Method · 0.95
testCascadeRCNNMethod · 0.95
testRetinaNetMethod · 0.95

Calls 11

get_sample_coco_imageFunction · 0.90
TracingAdapterClass · 0.90
printFunction · 0.85
getMethod · 0.45
cloneMethod · 0.45
saveMethod · 0.45
loadMethod · 0.45
toMethod · 0.45

Tested by

no test coverage detected