The step function. The step function is invoked at the end of every batch to update embeddings.
(self)
| 65 | assert not self._world_size is None |
| 66 | |
def step(self):
    """Run one optimizer update over the tracked embeddings.

    Meant to be called once at the end of every batch. The first call
    also inspects the device of every traced gradient, records the CUDA
    device (if any) in ``self._device``, and runs the matching one-time
    setup routine before the update itself.

    Raises
    ------
    AssertionError
        If, during the first call, a gradient is found on a device that
        disagrees with the device recorded so far.
    """
    if self._first_step:
        # One-time scan: find out whether the gradients live on a GPU.
        # NOTE(review): the check is order-dependent -- a CPU gradient seen
        # before the first CUDA gradient passes, because ``_device`` is
        # still unset at that point.
        for embedding in self._params:
            for _, traced in embedding._trace:
                grad_device = traced.grad.device
                if grad_device.type != "cuda":
                    assert (
                        not self._device
                    ), "All gradients must be on the same device"
                    continue
                if not self._device:
                    # First CUDA gradient seen: remember its device.
                    self._device = grad_device
                    continue
                assert (
                    self._device == grad_device
                ), "All gradients must be on the same device"

        # ``_device`` is only set when the grads are on a GPU. The
        # communicator path is taken for GPU grads unless torch.distributed
        # is already initialized with a backend other than nccl.
        use_shared = not self._device or (
            th.distributed.is_initialized()
            and th.distributed.get_backend() != "nccl"
        )
        if use_shared:
            self._shared_setup()
        else:
            self._comm_setup()
        self._first_step = False

    # Dispatch the actual update to whichever path was set up.
    run_update = self._comm_step if self._comm else self._shared_step
    run_update()
| 104 | |
| 105 | @abstractmethod |
| 106 | def setup(self, params): |
no test coverage detected