(
kernel: Type[AdamKernel],
adamw: bool,
weight_decay: float,
p_dtype: torch.dtype,
g_dtype: torch.dtype,
device: torch.device,
n_steps: int,
rtol: float,
atol: float,
)
| 115 | |
| 116 | |
def check_adam_kernel(
    kernel: Type[AdamKernel],
    adamw: bool,
    weight_decay: float,
    p_dtype: torch.dtype,
    g_dtype: torch.dtype,
    device: torch.device,
    n_steps: int,
    rtol: float,
    atol: float,
):
    """Check an Adam kernel implementation against ``TorchAdamKernel``.

    Both kernels are constructed with the same hyper-parameters
    (lr=1e-3, betas=(0.9, 0.999), eps=1e-8, plus ``weight_decay`` and
    ``adamw``). The reference updates "master" tensors created with torch's
    default floating dtype; the kernel under test updates copies of them cast
    to ``p_dtype`` (param and both moment buffers) and ``g_dtype`` (gradient).
    The same gradient is reused on every step.

    Args:
        kernel: Kernel class under test; instantiated positionally with
            ``(lr, beta1, beta2, eps, weight_decay, adamw)``.
        adamw: Forwarded to both kernels' constructors.
        weight_decay: Forwarded to both kernels' constructors.
        p_dtype: Dtype of the tested param and its moment buffers.
        g_dtype: Dtype of the tested gradient.
        device: Device on which every tensor is allocated.
        n_steps: Number of update steps; steps are numbered from 1.
        rtol: Relative tolerance for the final ``torch.allclose`` comparison.
        atol: Absolute tolerance for the final ``torch.allclose`` comparison.

    Raises:
        AssertionError: If the tested param contains NaN after any step, or
            if the final params differ from the reference beyond tolerance.
    """
    lr, eps = 1e-3, 1e-8
    beta1 = 0.9
    beta2 = 0.999
    hyper_params = (lr, beta1, beta2, eps, weight_decay, adamw)
    reference = TorchAdamKernel(*hyper_params)
    candidate = kernel(*hyper_params)

    # Reference ("master") state: random param/grad, zero-initialised moments.
    ref_param = torch.rand(64, device=device)
    ref_grad = torch.rand_like(ref_param)
    ref_exp_avg = torch.zeros_like(ref_param)
    ref_exp_avg_sq = torch.zeros_like(ref_param)

    # Independent copies in the dtypes under test (clone() guarantees a copy
    # even when the target dtype equals the master dtype).
    param, exp_avg, exp_avg_sq = (
        t.clone().to(p_dtype) for t in (ref_param, ref_exp_avg, ref_exp_avg_sq)
    )
    grad = ref_grad.clone().to(g_dtype)

    for step in range(1, n_steps + 1):
        reference.update(step, ref_param, ref_grad, ref_exp_avg, ref_exp_avg_sq)
        candidate.update(step, param, grad, exp_avg, exp_avg_sq)
        # Per the original author: on overflow the weight is not updated,
        # so no NaN should ever appear in the tested param.
        assert not param.isnan().any()

    assert torch.allclose(ref_param, param.float(), rtol=rtol, atol=atol)
| 148 | |
| 149 | |
| 150 | @pytest.mark.parametrize("adamw", [False, True]) |
no test coverage detected
searching dependent graphs…