(self, *inputs, **kwargs)
| 185 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) |
| 186 | |
| 187 | def forward(self, *inputs, **kwargs): # pragma: no cover |
| 188 | self._sync_params() |
| 189 | if self.device_ids: |
| 190 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) |
| 191 | if len(self.device_ids) == 1: |
| 192 | # -------------- |
| 193 | # LIGHTNING MOD |
| 194 | # -------------- |
| 195 | # normal |
| 196 | # output = self.module(*inputs[0], **kwargs[0]) |
| 197 | # lightning |
| 198 | if self.module.training: |
| 199 | output = self.module.training_step(*inputs[0], **kwargs[0]) |
| 200 | elif self.module.testing: |
| 201 | output = self.module.test_step(*inputs[0], **kwargs[0]) |
| 202 | else: |
| 203 | output = self.module.validation_step(*inputs[0], **kwargs[0]) |
| 204 | else: |
| 205 | outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) |
| 206 | output = self.gather(outputs, self.output_device) |
| 207 | else: |
| 208 | # normal |
| 209 | output = self.module(*inputs, **kwargs) |
| 210 | |
| 211 | if torch.is_grad_enabled(): |
| 212 | # We'll return the output object verbatim since it is a freeform |
| 213 | # object. We need to find any tensors in this object, though, |
| 214 | # because we need to figure out which parameters were used during |
| 215 | # this forward pass, to ensure we short circuit reduction for any |
| 216 | # unused parameters. Only if `find_unused_parameters` is set. |
| 217 | if self.find_unused_parameters: |
| 218 | self.reducer.prepare_for_backward(list(_find_tensors(output))) |
| 219 | else: |
| 220 | self.reducer.prepare_for_backward([]) |
| 221 | return output |
| 222 | |
| 223 | |
| 224 | class DP(DataParallel): |
nothing calls this directly
no test coverage detected