MCPcopy Index your code
hub / github.com/MoonInTheRiver/DiffSinger / forward

Method forward

utils/pl_utils.py:187–221  ·  view source on GitHub ↗
(self, *inputs, **kwargs)

Source from the content-addressed store, hash-verified

185 return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
186
def forward(self, *inputs, **kwargs):  # pragma: no cover
    """Run one distributed forward pass, routed through Lightning step hooks.

    With exactly one device id, the wrapped module's ``training_step``,
    ``test_step`` or ``validation_step`` is called. Where the unmodified
    version would call ``self.module(*inputs[0], **kwargs[0])``, the hook is
    chosen by ``self.module.training``, then ``self.module.testing``. With
    several device ids, the first ``len(inputs)`` module copies are run via
    ``self.parallel_apply`` and combined with ``self.gather``. With no device
    ids, the wrapped module is called directly.

    Args:
        *inputs: Positional arguments for the step hook / module call.
        **kwargs: Keyword arguments for the step hook / module call.

    Returns:
        The result of the selected call (or of ``self.gather`` when several
        device ids are configured), returned verbatim.
    """
    # Base-class hook (defined outside this view); runs before any input is
    # scattered, exactly as in the unmodified version.
    self._sync_params()

    if not self.device_ids:
        # No device ids: plain module call, not a Lightning step hook.
        output = self.module(*inputs, **kwargs)
    else:
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) > 1:
            # Several devices: pair each scattered chunk with a module copy,
            # then combine the per-copy results toward ``self.output_device``.
            replicas = self._module_copies[:len(inputs)]
            per_copy = self.parallel_apply(replicas, inputs, kwargs)
            output = self.gather(per_copy, self.output_device)
        else:
            # --- LIGHTNING MOD ---
            # Dispatch to the step hook matching the module's current mode.
            net = self.module
            if net.training:
                step_fn = net.training_step
            elif net.testing:
                step_fn = net.test_step
            else:
                step_fn = net.validation_step
            output = step_fn(*inputs[0], **kwargs[0])

    if torch.is_grad_enabled():
        # The output is a free-form object returned verbatim. Any tensors
        # inside it identify which parameters took part in this pass, so
        # reduction can be short-circuited for unused ones. Those tensors
        # are only searched for when ``find_unused_parameters`` is set.
        if self.find_unused_parameters:
            used_tensors = list(_find_tensors(output))
        else:
            used_tensors = []
        self.reducer.prepare_for_backward(used_tensors)
    return output
222
223
224class DP(DataParallel):

Callers

nothing calls this directly

Calls 5

parallel_apply · Method · 0.95
_find_tensors · Function · 0.85
training_step · Method · 0.80
test_step · Method · 0.45
validation_step · Method · 0.45

Tested by

no test coverage detected