hub / github.com/hpcaitech/ColossalAI / test_fused_adam_kernel

Function test_fused_adam_kernel

tests/test_optimizer/test_adam_kernel.py:153–161
(adamw, weight_decay, p_dtype, g_dtype)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("weight_decay", [0.0, 0.1])
@pytest.mark.parametrize("p_dtype, g_dtype", _FUSED_ALLOWED_P_G_TYPES)
def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
    rtol, atol = 1e-5, 1e-8
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 1e-3, 1e-3
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol, atol = 4e-3, 4e-3
    check_adam_kernel(
        FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_accelerator().get_current_device(), 3, rtol, atol
    )


@pytest.mark.parametrize("adamw", [False, True])
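
The tolerance selection in the listing above can be illustrated in isolation. The sketch below is not part of the repository: `pick_tolerances` and `is_close` are hypothetical helper names, dtypes are represented as plain strings so the example runs without torch, and it assumes the tolerances are eventually applied with the standard closeness rule `|actual - expected| <= atol + rtol * |expected|` (the rule documented for `torch.allclose`). How `check_adam_kernel` actually compares results is not shown in this listing.

```python
from typing import Tuple


def pick_tolerances(p_dtype: str, g_dtype: str) -> Tuple[float, float]:
    """Mirror the dtype-based (rtol, atol) selection in test_fused_adam_kernel.

    Hypothetical helper: dtypes are strings here instead of torch dtypes.
    """
    rtol, atol = 1e-5, 1e-8  # default when neither dtype is float16/bfloat16
    if "float16" in (p_dtype, g_dtype):
        rtol, atol = 1e-3, 1e-3
    if "bfloat16" in (p_dtype, g_dtype):
        # Checked last, so bfloat16 overrides the float16 tolerances.
        rtol, atol = 4e-3, 4e-3
    return rtol, atol


def is_close(actual: float, expected: float, rtol: float, atol: float) -> bool:
    """Assumed closeness rule: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


if __name__ == "__main__":
    for pair in [("float32", "float32"), ("float16", "float32"), ("float16", "bfloat16")]:
        print(pair, pick_tolerances(*pair))
```

Because the bfloat16 check comes second, a mixed float16/bfloat16 pair ends up with the looser `4e-3` tolerances, which matches the order of the `if` statements in the test.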

Callers

nothing calls this directly

Calls 3

get_accelerator · Function · 0.90
check_adam_kernel · Function · 0.85
get_current_device · Method · 0.45

Tested by

no test coverage detected
