(adamw, weight_decay, p_dtype, g_dtype)
| 165 | @pytest.mark.parametrize("weight_decay", [0.0, 0.1]) |
| 166 | @pytest.mark.parametrize("p_dtype, g_dtype", _CPU_ALLOWED_P_G_TYPES) |
| 167 | def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): |
| 168 | rtol, atol = 1e-5, 1e-8 |
| 169 | if p_dtype is torch.float16 or g_dtype is torch.float16: |
| 170 | rtol, atol = 1e-3, 1e-3 |
| 171 | check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device("cpu"), 3, rtol, atol) |
nothing calls this directly
no test coverage detected
searching dependent graphs…