hub / github.com/MoonInTheRiver/DiffSinger / parallel_apply

Function parallel_apply

utils/pl_utils.py:80–163

Applies each `module` in `modules` in parallel on the arguments contained in `inputs` (positional) and `kwargs_tup` (keyword), on each of `devices`. `modules`, `inputs`, `kwargs_tup` (if given), and `devices` (if given) should all have the same length.

(modules, inputs, kwargs_tup=None, devices=None)
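The signature above takes several parallel sequences, and the function's docstring says they should all have the same length, with each element of `inputs` being either a single object or a collection of positional arguments. A minimal sketch of that normalization in plain Python follows. The helper name `normalize_args` is hypothetical and not part of the source, and it raises `ValueError` where the original uses `assert`.

```python
def normalize_args(modules, inputs, kwargs_tup=None, devices=None):
    """Hypothetical helper: mirrors the length checks and defaults
    described for parallel_apply, without touching torch or CUDA."""
    n = len(modules)
    if len(inputs) != n:
        raise ValueError("modules and inputs must have the same length")

    if kwargs_tup is None:
        # One empty kwargs dict per module, as in the original default.
        kwargs_tup = ({},) * n
    elif len(kwargs_tup) != n:
        raise ValueError("modules and kwargs_tup must have the same length")

    if devices is None:
        # None means "infer the device later from the input".
        devices = [None] * n
    elif len(devices) != n:
        raise ValueError("modules and devices must have the same length")

    # A single object becomes a one-element tuple so it can be
    # unpacked as positional arguments with *input.
    inputs = [inp if isinstance(inp, (list, tuple)) else (inp,) for inp in inputs]
    return inputs, kwargs_tup, devices


inputs, kwargs_tup, devices = normalize_args(["m0", "m1"], [1, (2, 3)])
```

Note that `({},) * n` repeats the same dict object `n` times. That is harmless as long as no worker mutates its kwargs.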


def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):  # pragma: no cover
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)

                # ---------------
                # CHANGE
                if module.training:
                    output = module.training_step(*input, **kwargs)

                elif module.testing:
                    output = module.test_step(*input, **kwargs)

                else:
                    output = module.validation_step(*input, **kwargs)
                # ---------------

            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    # make sure each module knows what training state it's in...
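The visible part of the listing shows a worker that picks `training_step`, `test_step`, or `validation_step` from the module's flags and stores either the output or the raised exception in a lock-protected dict. The listing is cut off before the threads are launched, so the start, join, and re-raise steps below are an assumption about how such results are usually gathered, not the repository's code. `StubModule` and `run_steps_in_threads` are hypothetical names, and the sketch leaves out torch, CUDA devices, and kwargs so that it runs anywhere.

```python
import threading


class StubModule:
    """Hypothetical stand-in for a module exposing the three step methods."""

    def __init__(self, training=False, testing=False):
        self.training = training
        self.testing = testing

    def training_step(self, x):
        return ("train", x)

    def test_step(self, x):
        return ("test", x)

    def validation_step(self, x):
        return ("val", x)


def run_steps_in_threads(modules, inputs):
    lock = threading.Lock()
    results = {}

    def _worker(i, module, input):
        try:
            # Wrap a single object so it can be unpacked with *input.
            if not isinstance(input, (list, tuple)):
                input = (input,)
            # Same three-way dispatch as the "CHANGE" block in the listing.
            if module.training:
                output = module.training_step(*input)
            elif module.testing:
                output = module.test_step(*input)
            else:
                output = module.validation_step(*input)
            with lock:
                results[i] = output
        except Exception as e:
            # Store the exception so the caller can re-raise it.
            with lock:
                results[i] = e

    # Assumed continuation: one thread per module, then wait for all.
    threads = [
        threading.Thread(target=_worker, args=(i, module, inp))
        for i, (module, inp) in enumerate(zip(modules, inputs))
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # Return outputs in module order, surfacing any worker failure.
    outputs = []
    for i in range(len(modules)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs


outputs = run_steps_in_threads(
    [StubModule(training=True), StubModule(testing=True), StubModule()],
    [1, 2, 3],
)
```

Keying `results` by index rather than appending to a list keeps the outputs in module order no matter which thread finishes first.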

Callers 2

parallel_apply (Method) · 0.85
parallel_apply (Method) · 0.85

Calls 2

_worker (Function) · 0.85
start (Method) · 0.80

Tested by

no test coverage detected