MCPcopy
hub / github.com/yerfor/GeneFacePlusPlus / forward

Method forward

utils/commons/ddp_utils.py:23–150  ·  view source on GitHub ↗
(self, *inputs, **kwargs)

Source from the content-addressed store, hash-verified

21 """
22
23 def forward(self, *inputs, **kwargs): # pragma: no cover
24 torch_version = get_torch_version()
25 if version.parse(torch_version) < version.parse("1.11"):
26 self._sync_params()
27 inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
28 assert len(self.device_ids) == 1
29 if self.module.training:
30 output = self.module.training_step(*inputs[0], **kwargs[0])
31 elif self.module.testing:
32 output = self.module.test_step(*inputs[0], **kwargs[0])
33 else:
34 output = self.module.validation_step(*inputs[0], **kwargs[0])
35 if torch.is_grad_enabled():
36 # We'll return the output object verbatim since it is a freeform
37 # object. We need to find any tensors in this object, though,
38 # because we need to figure out which parameters were used during
39 # this forward pass, to ensure we short circuit reduction for any
40 # unused parameters. Only if `find_unused_parameters` is set.
41 if self.find_unused_parameters:
42 self.reducer.prepare_for_backward(list(_find_tensors(output)))
43 else:
44 self.reducer.prepare_for_backward([])
45 elif version.parse(torch_version) < version.parse("2.1"):
46 from torch.nn.parallel.distributed import \
47 logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref
48 with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
49 if torch.is_grad_enabled() and self.require_backward_grad_sync:
50 self.logger.set_runtime_stats_and_log()
51 self.num_iterations += 1
52 self.reducer.prepare_for_forward()
53
54 # Notify the join context that this process has not joined, if
55 # needed
56 work = Join.notify_join_context(self)
57 if work:
58 self.reducer._set_forward_pass_work_handle(
59 work, self._divide_by_initial_world_size
60 )
61
62 # Calling _rebuild_buckets before forward compuation,
63 # It may allocate new buckets before deallocating old buckets
64 # inside _rebuild_buckets. To save peak memory usage,
65 # call _rebuild_buckets before the peak memory usage increases
66 # during forward computation.
67 # This should be called only once during whole training period.
68 if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
69 logging.info("Reducer buckets have been rebuilt in this iteration.")
70 self._has_rebuilt_buckets = True
71
72 # sync params according to location (before/after forward) user
73 # specified as part of hook, if hook was specified.
74 buffer_hook_registered = hasattr(self, 'buffer_hook')
75 if self._check_sync_bufs_pre_fwd():
76 self._sync_buffers()
77
78 if self._join_config.enable:
79 # Notify joined ranks whether they should sync in backwards pass or not.
80 self._check_global_requires_backward_grad_sync(is_joined_rank=False)

Callers

nothing calls this directly

Calls 6

get_torch_versionFunction · 0.85
parseMethod · 0.80
training_stepMethod · 0.80
applyMethod · 0.80
test_stepMethod · 0.45
validation_stepMethod · 0.45

Tested by

no test coverage detected