(adamw, weight_decay, p_dtype, g_dtype)
| 151 | @pytest.mark.parametrize("weight_decay", [0.0, 0.1]) |
| 152 | @pytest.mark.parametrize("p_dtype, g_dtype", _FUSED_ALLOWED_P_G_TYPES) |
def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
    """Run ``check_adam_kernel`` against ``FusedAdamKernel`` on the current device.

    The comparison tolerances depend on the parameter/gradient dtypes:

    * either dtype is ``torch.bfloat16`` -> ``rtol = atol = 4e-3``
    * otherwise, either dtype is ``torch.float16`` -> ``rtol = atol = 1e-3``
    * otherwise -> ``rtol = 1e-5``, ``atol = 1e-8``

    Args:
        adamw: Forwarded unchanged to ``check_adam_kernel``.
        weight_decay: Forwarded unchanged to ``check_adam_kernel``.
        p_dtype: Parameter dtype; also used to pick the tolerances.
        g_dtype: Gradient dtype; also used to pick the tolerances.
    """
    dtype_pair = (p_dtype, g_dtype)

    # bfloat16 is checked first: when both reduced-precision dtypes are present,
    # the looser bfloat16 tolerance is the one that applies.
    if any(dt is torch.bfloat16 for dt in dtype_pair):
        rtol = atol = 4e-3
    elif any(dt is torch.float16 for dt in dtype_pair):
        rtol = atol = 1e-3
    else:
        rtol, atol = 1e-5, 1e-8

    device = get_accelerator().get_current_device()
    # NOTE(review): the literal 3 is passed positionally; its meaning is defined
    # by check_adam_kernel's signature, which is not visible here -- confirm.
    check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, device, 3, rtol, atol)
| 162 | |
| 163 | |
| 164 | @pytest.mark.parametrize("adamw", [False, True]) |
nothing calls this directly
no test coverage detected
searching dependent graphs…