(self, *inputs, **kwargs)
| 227 | """ |
| 228 | |
| 229 | def forward(self, *inputs, **kwargs): |
| 230 | if not self.device_ids: |
| 231 | return self.module(*inputs, **kwargs) |
| 232 | |
| 233 | for t in itertools.chain(self.module.parameters(), self.module.buffers()): |
| 234 | if t.device != self.src_device_obj: |
| 235 | raise RuntimeError("module must have its parameters and buffers " |
| 236 | "on device {} (device_ids[0]) but found one of " |
| 237 | "them on device: {}".format(self.src_device_obj, t.device)) |
| 238 | |
| 239 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) |
| 240 | if len(self.device_ids) == 1: |
| 241 | # lightning |
| 242 | if self.module.training: |
| 243 | return self.module.training_step(*inputs[0], **kwargs[0]) |
| 244 | elif self.module.testing: |
| 245 | return self.module.test_step(*inputs[0], **kwargs[0]) |
| 246 | else: |
| 247 | return self.module.validation_step(*inputs[0], **kwargs[0]) |
| 248 | |
| 249 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) |
| 250 | outputs = self.parallel_apply(replicas, inputs, kwargs) |
| 251 | return self.gather(outputs, self.output_device) |
| 252 | |
def parallel_apply(self, replicas, inputs, kwargs):
    """Run each replica on its own chunk of inputs.

    Delegates to the ``parallel_apply`` function found at module scope (the
    bare name resolves to that global, not to this method). The device list
    is trimmed to the first ``len(replicas)`` entries of ``self.device_ids``,
    so devices and replicas line up one-to-one.
    """
    active_devices = self.device_ids[:len(replicas)]
    return parallel_apply(replicas, inputs, kwargs, active_devices)
nothing calls this directly
no test coverage detected