Repository: github.com/pytorch/vision

Function _check_input_backprop

test/test_models.py, lines 202–228
Signature: (model, inputs)


```python
def _check_input_backprop(model, inputs):
    # Remember each input's original requires_grad flag, then force it on so
    # autograd tracks operations on the inputs themselves.
    if isinstance(inputs, list):
        requires_grad = list()
        for inp in inputs:
            requires_grad.append(inp.requires_grad)
            inp.requires_grad_(True)
    else:
        requires_grad = inputs.requires_grad
        inputs.requires_grad_(True)

    out = model(inputs)

    # Reduce the output to a scalar and backpropagate. The branch depends on
    # the output structure:
    #   dict with "out"            -> segmentation models
    #   list of dicts ("scores")   -> detection models
    #   tensor or tuple            -> out[0]: first tuple element or first batch row
    if isinstance(out, dict):
        out["out"].sum().backward()
    else:
        if isinstance(out[0], dict):
            out[0]["scores"].sum().backward()
        else:
            out[0].sum().backward()

    # Every input must have received a gradient; then restore the original flags.
    if isinstance(inputs, list):
        for i, inp in enumerate(inputs):
            assert inputs[i].grad is not None
            inp.requires_grad_(requires_grad[i])
    else:
        assert inputs.grad is not None
        inputs.requires_grad_(requires_grad)
```
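A minimal sketch of the same check, assuming only torch is installed. check_input_backprop_safe is a hypothetical name, not a torchvision API: it keeps the output-selection rules of _check_input_backprop but restores each input's original requires_grad flag in a finally block, so a failing forward pass, backward pass, or assertion cannot leave the inputs modified.

```python
import torch
from torch import nn


def check_input_backprop_safe(model, inputs):
    """Hypothetical variant of _check_input_backprop (not a torchvision API).

    Asserts that a backward pass from the model output reaches every input,
    and always restores each input's original requires_grad flag.
    """
    # Treat a single tensor and a list of tensors through one code path.
    tensors = inputs if isinstance(inputs, list) else [inputs]
    saved = [t.requires_grad for t in tensors]
    try:
        for t in tensors:
            t.requires_grad_(True)
        out = model(inputs)
        # Same output selection as _check_input_backprop.
        if isinstance(out, dict):
            target = out["out"]  # segmentation-style dict output
        elif isinstance(out[0], dict):
            target = out[0]["scores"]  # detection-style list of dicts
        else:
            target = out[0]  # first tuple element, or first batch row of a tensor
        target.sum().backward()
        for t in tensors:
            assert t.grad is not None, "gradient did not reach an input"
    finally:
        # Runs even if the forward/backward pass or an assertion fails.
        for t, flag in zip(tensors, saved):
            t.requires_grad_(flag)


# Usage with a toy classifier: the input receives a gradient and its flag is restored.
x = torch.randn(3, 4)
check_input_backprop_safe(nn.Linear(4, 2), x)
print(x.grad is not None, x.requires_grad)  # prints: True False
```

The design trade-off: the original helper restores the flags only after its assertions pass, so a failed assertion or a raising backward() leaves requires_grad set to True on the inputs. The try/finally form removes that side effect at the cost of one extra level of nesting.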

Callers (9 total, 6 listed)

test_inception_v3_eval (function) · 0.85
test_fasterrcnn_double (function) · 0.85
test_googlenet_eval (function) · 0.85
test_segmentation_model (function) · 0.85
test_detection_model (function) · 0.85
test_video_model (function) · 0.85
test_video_modelFunction · 0.85

Calls (1)

backward (method) · 0.45

Tested by

no test coverage detected
