Return a copy of the whole optimizer states stored in CPU memory. If this is a multi-processing instance, the states will be returned in shared memory. If the underlying embedding is currently stored on multiple GPUs, all processes must call this method in the same order.
(self, **kwargs)
| 453 | self._clean_grad = True |
| 454 | |
def state_dict(self, **kwargs):  # pylint: disable=unused-argument
    """Collect the complete optimizer states as a copy held in CPU memory.

    For a multi-processing instance the states are returned in shared
    memory. When the underlying embedding is currently stored on multiple
    GPUs, every process has to invoke this method in the same order.

    NOTE: All processes sharing the underlying embedding must call this
    method; otherwise a deadlock may occur.

    Parameters
    ----------
    **kwargs
        Accepted but not used by this method.

    Returns
    -------
    dictionary of optimizer states
        The optimizer states stored in CPU memory. The ``"state"`` entry
        maps each embedding's name to the value returned by its
        ``_all_get_optm_state()``; the ``"param_groups"`` entry holds
        ``self.param_groups``.
    """
    states = {}
    # Visit the embeddings strictly in ``self._params`` order: the docstring
    # above requires all processes to make these calls in the same order.
    for emb in self._params:
        key = emb.name
        states[key] = emb._all_get_optm_state()
    return {"state": states, "param_groups": self.param_groups}
| 475 | |
| 476 | def load_state_dict( |
| 477 | self, state_dict, **kwargs |