hub / github.com/MoonInTheRiver/DiffSinger / _worker

Function _worker

utils/pl_utils.py:109–135
(i, module, input, kwargs, device=None)

Source from the content-addressed store, hash-verified

grad_enabled = torch.is_grad_enabled()

def _worker(i, module, input, kwargs, device=None):
    torch.set_grad_enabled(grad_enabled)
    if device is None:
        device = get_a_var(input).get_device()
    try:
        with torch.cuda.device(device):
            # this also avoids accidental slicing of `input` if it is a Tensor
            if not isinstance(input, (list, tuple)):
                input = (input,)

            # ---------------
            # CHANGE
            if module.training:
                output = module.training_step(*input, **kwargs)

            elif module.testing:
                output = module.test_step(*input, **kwargs)

            else:
                output = module.validation_step(*input, **kwargs)
            # ---------------

        with lock:
            results[i] = output
    except Exception as e:
        with lock:
            results[i] = e

# make sure each module knows what training state it's in...
# fixes weird bug where copies are out of sync
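
The listing relies on names from its enclosing scope (`lock`, `results`, `grad_enabled`), which suggests `_worker` is a closure run once per module replica. The following is a minimal, stdlib-only sketch of that pattern, not the code from `utils/pl_utils.py`: `ToyModule` and `toy_parallel_apply` are hypothetical names, the torch and CUDA calls are left out, and re-raising a stored exception after the threads join is an assumption about how a caller might consume `results`.

```python
import threading


class ToyModule:
    """Hypothetical stand-in for a module replica; not the DiffSinger class."""

    def __init__(self, training=False, testing=False):
        self.training = training
        self.testing = testing

    def training_step(self, x):
        return ("train", x)

    def test_step(self, x):
        return ("test", x)

    def validation_step(self, x):
        if x < 0:
            raise ValueError("negative input")
        return ("val", x)


def toy_parallel_apply(modules, inputs):
    """Run one worker thread per (module, input) pair and gather the results."""
    lock = threading.Lock()
    results = {}

    def _worker(i, module, input):
        try:
            # Wrap a single argument so it can be unpacked with *input.
            if not isinstance(input, (list, tuple)):
                input = (input,)
            # Same three-way dispatch as in the listing.
            if module.training:
                output = module.training_step(*input)
            elif module.testing:
                output = module.test_step(*input)
            else:
                output = module.validation_step(*input)
            with lock:
                results[i] = output
        except Exception as e:
            # Store the exception so the caller can see it after join().
            with lock:
                results[i] = e

    threads = [
        threading.Thread(target=_worker, args=(i, m, x))
        for i, (m, x) in enumerate(zip(modules, inputs))
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    outputs = []
    for i in range(len(threads)):
        if isinstance(results[i], Exception):
            raise results[i]  # assumed caller behaviour: surface worker errors
        outputs.append(results[i])
    return outputs
```

In the original listing the device lookup (`get_a_var(input).get_device()`) runs before the `try` block, so a failure there would not be stored in `results`; the sketch has no equivalent step.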

Callers 1

parallel_apply · Function · 0.85

Calls 4

get_a_var · Function · 0.85
training_step · Method · 0.80
test_step · Method · 0.45
validation_step · Method · 0.45
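
The body of `get_a_var` is not shown on this page. From its use in `_worker` it appears to return an object with a `get_device()` method taken from `input`. The sketch below shows one way such a helper could be written, assuming it searches nested lists, tuples and dicts for the first tensor-like value; `FakeTensor` and `find_first_tensor` are hypothetical names used so the example runs without torch.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor that knows its device index."""

    def __init__(self, device_index):
        self._device_index = device_index

    def get_device(self):
        return self._device_index


def find_first_tensor(obj):
    """Return the first FakeTensor in a nested list/tuple/dict, else None."""
    if isinstance(obj, FakeTensor):
        return obj
    if isinstance(obj, (list, tuple)):
        for item in obj:
            found = find_first_tensor(item)
            if found is not None:
                return found
    if isinstance(obj, dict):
        for value in obj.values():
            found = find_first_tensor(value)
            if found is not None:
                return found
    return None
```

A helper like this returns `None` when the input holds no tensor, in which case a chained `.get_device()` call would raise `AttributeError`.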

Tested by

no test coverage detected