Method
__init__
(
self,
params,
lr,
betas=(0.9, 0.999),
eps=1e-08,
use_uva=None,
dtype=th.float32,
)
Source from the content-addressed store, hash-verified
| 701 | """ |
| 702 | |
def __init__(
    self,
    params,
    lr,
    betas=(0.9, 0.999),
    eps=1e-08,
    use_uva=None,
    dtype=th.float32,
):
    """Store the optimizer hyper-parameters and set up its state tensors.

    Parameters
    ----------
    params
        Forwarded unchanged to the base-class constructor.
    lr
        Learning rate. Forwarded to the base-class constructor and also
        cached on this instance as ``self._lr``.
    betas
        Indexable with at least two entries; ``betas[0]`` and ``betas[1]``
        are stored as ``self._beta1`` and ``self._beta2`` (presumably the
        Adam decay rates for the first and second moment estimates --
        confirm against the update step).
    eps
        Stored as ``self._eps`` (presumably the numerical-stability term
        of the Adam update -- confirm against the update step).
    use_uva
        Stored as-is in ``self._use_uva``; how ``None`` is interpreted is
        decided by code outside this method.
    dtype
        Must be ``th.float16`` or ``th.float32``; stored as
        ``self._dtype`` (presumably the dtype of the optimizer state
        tensors -- confirm against ``setup``).

    Raises
    ------
    AssertionError
        If ``dtype`` is neither ``th.float16`` nor ``th.float32``.
    """
    super(SparseAdam, self).__init__(params, lr)
    self._lr = lr
    self._beta1 = betas[0]
    self._beta2 = betas[1]
    self._eps = eps
    self._use_uva = use_uva
    # Per-name bookkeeping dictionaries, filled in later by other methods.
    self._nd_handle = {}
    self._is_using_uva = {}
    # The message must list the dtypes that are actually accepted; it used
    # to name th.float32 twice and omit th.float16.
    assert dtype in (th.float16, th.float32), (
        "Unsupported dtype {}. Valid choices are th.float16 "
        "and th.float32".format(dtype)
    )
    self._dtype = dtype

    # setup tensors for optimizer states
    # NOTE(review): self._params is not assigned in this method; it is
    # presumably set by the base-class constructor -- confirm.
    self.setup(self._params)
| 728 | |
| 729 | def _setup_uva(self, name, mem, power): |
| 730 | self._is_using_uva[name] = True |
Tested by
no test coverage detected