MCPcopy Index your code
hub / github.com/MoonInTheRiver/DiffSinger / forward

Method forward

utils/pl_utils.py:229–251  ·  view source on GitHub ↗
(self, *inputs, **kwargs)

Source from the content-addressed store, hash-verified

227 """
228
def forward(self, *inputs, **kwargs):
    """Run the wrapped module, dispatching to Lightning-style step hooks.

    Behaviour, as implemented below:

    * With no ``self.device_ids``, ``self.module(*inputs, **kwargs)`` is
      called directly, rather than one of the ``*_step`` hooks.
    * Otherwise every parameter and buffer of ``self.module`` must be on
      ``self.src_device_obj``. The call arguments are then scattered across
      ``self.device_ids``.
    * With exactly one device, the first scattered chunk is passed to one
      hook: ``training_step`` when ``self.module.training`` is true, else
      ``test_step`` when ``self.module.testing`` is true, else
      ``validation_step``.
    * With several devices, the module is replicated onto the first
      ``len(inputs)`` entries of ``self.device_ids``. The replicas are run
      through ``self.parallel_apply`` and the results are gathered onto
      ``self.output_device``.

    NOTE(review): ``scatter`` / ``replicate`` / ``gather`` /
    ``src_device_obj`` / ``output_device`` look like the
    ``torch.nn.DataParallel`` API. The enclosing class is not visible
    here — confirm.

    Args:
        *inputs: Positional arguments forwarded to the module / step hook.
        **kwargs: Keyword arguments forwarded to the module / step hook.

    Returns:
        Whatever the selected call returns. On the multi-device path this is
        the result of ``self.gather``.

    Raises:
        RuntimeError: If any parameter or buffer of ``self.module`` is not on
            ``self.src_device_obj``.
    """
    if not self.device_ids:
        return self.module(*inputs, **kwargs)

    # Replication assumes the source copy lives on the first device; fail
    # loudly instead of silently mixing devices.
    for t in itertools.chain(self.module.parameters(), self.module.buffers()):
        if t.device != self.src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(self.src_device_obj, t.device))

    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    # A call with neither positional nor keyword arguments may scatter into
    # zero per-device chunks. That made ``inputs[0]`` raise IndexError on a
    # single device, or replicate onto zero devices on several. Instead, run
    # one empty call on the first device, as torch.nn.DataParallel does.
    if not inputs and not kwargs:
        inputs = ((),)
        kwargs = ({},)

    if len(self.device_ids) == 1:
        # lightning: a single device skips replication and calls the step
        # hook matching the module's current mode directly.
        if self.module.training:
            return self.module.training_step(*inputs[0], **kwargs[0])
        elif self.module.testing:
            return self.module.test_step(*inputs[0], **kwargs[0])
        else:
            return self.module.validation_step(*inputs[0], **kwargs[0])

    # Only replicate onto as many devices as there are input chunks.
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
    outputs = self.parallel_apply(replicas, inputs, kwargs)
    return self.gather(outputs, self.output_device)
252
def parallel_apply(self, replicas, inputs, kwargs):
    """Execute each replica on its share of the scattered arguments.

    Delegates to the ``parallel_apply`` function resolved from global scope
    (a method body does not see the class attribute of the same name). The
    ``replicas``, ``inputs`` and ``kwargs`` are passed through unchanged,
    together with the leading ``len(replicas)`` entries of
    ``self.device_ids``.
    """
    replica_count = len(replicas)
    target_devices = self.device_ids[:replica_count]
    return parallel_apply(replicas, inputs, kwargs, target_devices)

Callers

nothing calls this directly

Calls 4

parallel_apply · Method · 0.95
training_step · Method · 0.80
test_step · Method · 0.45
validation_step · Method · 0.45

Tested by

no test coverage detected