(i, module, input, kwargs, device=None)
| 107 | grad_enabled = torch.is_grad_enabled() |
| 108 | |
def _worker(i, module, input, kwargs, device=None):
    """Run the step method that matches one replica's mode and record the result.

    Uses ``grad_enabled``, ``lock``, ``results`` and ``get_a_var`` from the
    enclosing scope. ``lock`` guards every write to ``results``. Presumably
    one call runs per replica in its own thread; confirm against the caller.

    The step is chosen from the replica's mode flags:
    ``module.training`` -> ``training_step``; otherwise ``module.testing`` ->
    ``test_step``; otherwise ``validation_step``.

    Args:
        i: Key under which the step's output, or the exception it raised, is
            stored in ``results``.
        module: Replica to run. It must expose ``training`` and ``testing``
            flags and the ``*_step`` methods listed above.
        input: Positional argument(s) for the step. A value that is not a
            list/tuple is passed as a single positional argument.
        kwargs: Keyword arguments forwarded to the step.
        device: Device passed to ``torch.cuda.device``. If ``None``, it is
            taken from ``get_a_var(input).get_device()``.

    Returns:
        None. On success ``results[i]`` holds the step's output. On failure,
        including failure to resolve the device, it holds the raised
        exception so the caller can surface it.
    """
    # Grad mode is thread-local in PyTorch, so re-apply the caller's setting.
    torch.set_grad_enabled(grad_enabled)
    try:
        # Resolve the device inside the ``try``. Otherwise an error here would
        # escape the worker and leave ``results[i]`` unset, hiding the cause.
        if device is None:
            tensor = get_a_var(input)
            # NOTE(review): assumes get_a_var returns None when ``input`` holds
            # no tensor (as in torch.nn.parallel); confirm.
            if tensor is None:
                raise RuntimeError(
                    "in replica {}, no device was provided and no tensor input was found; "
                    "device cannot be resolved".format(i)
                )
            device = tensor.get_device()

        with torch.cuda.device(device):
            # this also avoids accidental slicing of `input` if it is a Tensor
            if not isinstance(input, (list, tuple)):
                input = (input,)

            # ---------------
            # CHANGE (relative to the implementation this was adapted from):
            # dispatch to the step method that matches the replica's mode.
            if module.training:
                output = module.training_step(*input, **kwargs)
            elif module.testing:
                output = module.test_step(*input, **kwargs)
            else:
                output = module.validation_step(*input, **kwargs)
            # ---------------

        with lock:
            results[i] = output
    except Exception as e:
        # Record the failure instead of letting it kill the worker silently.
        with lock:
            results[i] = e
| 136 | |
| 137 | # make sure each module knows what training state it's in... |
| 138 | # fixes weird bug where copies are out of sync |
no test coverage detected