(model, inputs)
| 200 | |
| 201 | |
| 202 | def _check_input_backprop(model, inputs): |
| 203 | if isinstance(inputs, list): |
| 204 | requires_grad = list() |
| 205 | for inp in inputs: |
| 206 | requires_grad.append(inp.requires_grad) |
| 207 | inp.requires_grad_(True) |
| 208 | else: |
| 209 | requires_grad = inputs.requires_grad |
| 210 | inputs.requires_grad_(True) |
| 211 | |
| 212 | out = model(inputs) |
| 213 | |
| 214 | if isinstance(out, dict): |
| 215 | out["out"].sum().backward() |
| 216 | else: |
| 217 | if isinstance(out[0], dict): |
| 218 | out[0]["scores"].sum().backward() |
| 219 | else: |
| 220 | out[0].sum().backward() |
| 221 | |
| 222 | if isinstance(inputs, list): |
| 223 | for i, inp in enumerate(inputs): |
| 224 | assert inputs[i].grad is not None |
| 225 | inp.requires_grad_(requires_grad[i]) |
| 226 | else: |
| 227 | assert inputs.grad is not None |
| 228 | inputs.requires_grad_(requires_grad) |
| 229 | |
| 230 | |
| 231 | # If 'unwrapper' is provided it will be called with the script model outputs |
no test coverage detected
searching dependent graphs…