Repository: github.com/hpcaitech/ColossalAI

Function test_cpu_adam_kernel

tests/test_optimizer/test_adam_kernel.py:167–171
(adamw, weight_decay, p_dtype, g_dtype)


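```python
import pytest
import torch

# check_adam_kernel, CPUAdamKernel and _CPU_ALLOWED_P_G_TYPES are defined or imported elsewhere
# in the module; the `adamw` argument is supplied by a decorator or fixture outside this excerpt.
@pytest.mark.parametrize("weight_decay", [0.0, 0.1])
@pytest.mark.parametrize("p_dtype, g_dtype", _CPU_ALLOWED_P_G_TYPES)
def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
    # Strict tolerances by default; relaxed when params or grads are fp16.
    rtol, atol = 1e-5, 1e-8
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 1e-3, 1e-3
    # Check CPUAdamKernel on the CPU device (3 appears to be the number of steps checked).
    check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device("cpu"), 3, rtol, atol)
```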
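
The tolerance rule in the test can be restated on its own. The sketch below uses two hypothetical helpers, `pick_tolerances` and `is_close`, which take dtype names as strings so they run without torch. It pairs the rule with the elementwise check |actual - expected| <= atol + rtol * |expected| used by `torch.allclose` and `torch.testing.assert_close`. Whether `check_adam_kernel` compares with exactly this formula is an assumption, since its body is not shown here.

```python
# Hypothetical helpers illustrating the test's tolerance rule; not ColossalAI code.

def pick_tolerances(p_dtype: str, g_dtype: str) -> tuple[float, float]:
    """Return (rtol, atol): strict for fp32, relaxed if either dtype is fp16."""
    if p_dtype == "float16" or g_dtype == "float16":
        return 1e-3, 1e-3
    return 1e-5, 1e-8


def is_close(actual: float, expected: float, rtol: float, atol: float) -> bool:
    """allclose-style check: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# A 5e-4 discrepancy on a value of 0.1:
print(is_close(0.1005, 0.1, *pick_tolerances("float32", "float32")))  # False
print(is_close(0.1005, 0.1, *pick_tolerances("float16", "float32")))  # True
```

With the strict defaults the allowed error here is about 1e-6, so the discrepancy fails. The relaxed fp16 bound is 1e-3 + 1e-3 * 0.1 = 1.1e-3, so the same discrepancy passes.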

Callers

nothing calls this directly

Calls (2)

check_adam_kernel · Function · 0.85
device · Method · 0.45
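
The body of `check_adam_kernel` is not shown on this page. As a hedged illustration of what the `adamw` and `weight_decay` parameters switch between, here is a scalar sketch of the standard Adam update with an L2 penalty and the AdamW update with decoupled weight decay. It follows the update rules documented for `torch.optim.Adam` and `torch.optim.AdamW`. The names `adam_step` and `run` and the default hyperparameters are hypothetical, and this is not ColossalAI's reference implementation.

```python
import math

# Hypothetical scalar sketch of Adam vs AdamW, following the update rules of
# torch.optim.Adam / torch.optim.AdamW; not ColossalAI's reference code.

def adam_step(p, g, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
              weight_decay=0.0, adamw=False):
    """Apply one update to scalar parameter p; return the new (p, m, v)."""
    if weight_decay != 0.0:
        if adamw:
            p = p * (1 - lr * weight_decay)   # decoupled decay on the parameter
        else:
            g = g + weight_decay * p          # L2 penalty folded into the gradient
    m = beta1 * m + (1 - beta1) * g           # first-moment estimate
    v = beta2 * v + (1 - beta2) * g * g       # second-moment estimate
    m_hat = m / (1 - beta1 ** step)           # bias corrections (step starts at 1)
    v_hat = v / (1 - beta2 ** step)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


def run(p, grads, n_steps, **kwargs):
    """Run n_steps updates from zero moment state, one gradient per step."""
    m = v = 0.0
    for step in range(1, n_steps + 1):
        p, m, v = adam_step(p, grads[step - 1], m, v, step, **kwargs)
    return p


grads = [0.5, 0.5, 0.5]
# With weight_decay=0.0 the adamw flag changes nothing...
assert run(1.0, grads, 3, adamw=True) == run(1.0, grads, 3, adamw=False)
# ...but with weight_decay=0.1 the two variants diverge.
print(run(1.0, grads, 3, weight_decay=0.1, adamw=True),
      run(1.0, grads, 3, weight_decay=0.1, adamw=False))
```

Under this standard formulation the `adamw` flag only matters when `weight_decay` is nonzero. A nonzero value, such as the test's 0.1, is therefore needed to tell the two modes apart.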

Tested by

no test coverage detected
