(self)
| 200 | self.update(idx, grad, emb) |
| 201 | |
| 202 | def _shared_step(self): |
| 203 | with th.no_grad(): |
| 204 | # Frequently alloc and free shared memory to hold intermediate tensor is expensive |
| 205 | # We cache shared memory buffers in shared_emb. |
| 206 | shared_emb = {emb.name: ([], []) for emb in self._params} |
| 207 | |
| 208 | # Go through all sparse embeddings |
| 209 | for emb in self._params: # pylint: disable=too-many-nested-blocks |
| 210 | emb_name = emb.name |
| 211 | |
| 212 | # we need to combine gradients from multiple forward paths |
| 213 | idx = [] |
| 214 | grad = [] |
| 215 | for i, data in emb._trace: |
| 216 | idx.append(i) |
| 217 | grad.append(data.grad.data) |
| 218 | # If the sparse embedding is not used in the previous forward step |
| 219 | # The idx and grad will be empty, initialize them as empty tensors to |
| 220 | # avoid crashing the optimizer step logic. |
| 221 | # |
| 222 | # Note: we cannot skip the gradient exchange and update steps as other |
| 223 | # working processes may send gradient update requests corresponding |
| 224 | # to certain embedding to this process. |
| 225 | idx = ( |
| 226 | th.cat(idx, dim=0) |
| 227 | if len(idx) != 0 |
| 228 | else th.zeros((0,), dtype=th.long, device=th.device("cpu")) |
| 229 | ) |
| 230 | grad = ( |
| 231 | th.cat(grad, dim=0) |
| 232 | if len(grad) != 0 |
| 233 | else th.zeros( |
| 234 | (0, emb.embedding_dim), |
| 235 | dtype=th.float32, |
| 236 | device=th.device("cpu"), |
| 237 | ) |
| 238 | ) |
| 239 | |
| 240 | device = grad.device |
| 241 | idx_dtype = idx.dtype |
| 242 | grad_dtype = grad.dtype |
| 243 | grad_dim = grad.shape[1] |
| 244 | if self._world_size > 1: |
| 245 | if emb_name not in self._shared_cache: |
| 246 | self._shared_cache[emb_name] = {} |
| 247 | |
| 248 | # Each training process takes the resposibility of updating a range |
| 249 | # of node embeddings, thus we can parallel the gradient update. |
| 250 | # The overall progress includes: |
| 251 | # 1. In each training process: |
| 252 | # 1.a Deciding which process a node embedding belongs to according |
| 253 | # to the formula: process_id = node_idx mod num_of_process(N) |
| 254 | # 1.b Split the node index tensor and gradient tensor into N parts |
| 255 | # according to step 1. |
| 256 | # 1.c Write each node index sub-tensor and gradient sub-tensor into |
| 257 | # different DGL shared memory buffers. |
| 258 | # 2. Cross training process synchronization |
| 259 | # 3. In each traning process: |
no test coverage detected