hub / github.com/hpcaitech/ColossalAI / check_adam_kernel

Function check_adam_kernel

tests/test_optimizer/test_adam_kernel.py:117–147
(
    kernel: Type[AdamKernel],
    adamw: bool,
    weight_decay: float,
    p_dtype: torch.dtype,
    g_dtype: torch.dtype,
    device: torch.device,
    n_steps: int,
    rtol: float,
    atol: float,
)
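
The `rtol` and `atol` parameters are passed straight to `torch.allclose` in the function body. As a reminder of what that check means, here is a minimal pure-Python sketch of the documented criterion, `|input - other| <= atol + rtol * |other|`, applied elementwise. The helper name `allclose_sketch` and the sample values are illustrative assumptions, not part of the repository.

```python
def allclose_sketch(input_vals, other_vals, rtol, atol):
    """Pure-Python stand-in for the torch.allclose criterion (illustrative only).

    Every pair must satisfy |input - other| <= atol + rtol * |other|.
    Note that the relative term scales with the *second* argument.
    """
    return all(
        abs(a - b) <= atol + rtol * abs(b)
        for a, b in zip(input_vals, other_vals)
    )


# Hypothetical sample values, chosen only to show how atol matters near zero.
reference = [1.0, 0.5, 1e-4]
candidate = [1.0005, 0.5004, 1.2e-4]

loose = allclose_sketch(candidate, reference, rtol=1e-3, atol=1e-3)
strict = allclose_sketch(candidate, reference, rtol=1e-3, atol=0.0)
print(loose, strict)
```

With `atol=0.0` the smallest element fails, because a purely relative tolerance leaves almost no slack for values near zero. That is one plausible reason a test like this takes both tolerances as parameters.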

Source from the content-addressed store, hash-verified

def check_adam_kernel(
    kernel: Type[AdamKernel],
    adamw: bool,
    weight_decay: float,
    p_dtype: torch.dtype,
    g_dtype: torch.dtype,
    device: torch.device,
    n_steps: int,
    rtol: float,
    atol: float,
):
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    torch_adam = TorchAdamKernel(lr, beta1, beta2, eps, weight_decay, adamw)
    adam_kernel = kernel(lr, beta1, beta2, eps, weight_decay, adamw)
    master_p = torch.rand(64, device=device)
    master_g = torch.rand_like(master_p)
    master_exp_avg = torch.zeros_like(master_p)
    master_exp_avg_sq = torch.zeros_like(master_p)
    p = master_p.clone().to(p_dtype)
    g = master_g.clone().to(g_dtype)
    exp_avg = master_exp_avg.clone().to(p_dtype)
    exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)

    for step in range(1, 1 + n_steps):
        torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
        adam_kernel.update(step, p, g, exp_avg, exp_avg_sq)
        # if overflow, the weight won't be updated. so there will be no nan in p
        assert not torch.isnan(p).any()
        assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol)
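
The function above runs a full-precision reference and a kernel under test side by side, then compares the parameters after every step. The sketch below mimics that pattern in pure Python on a single scalar. It is a sketch under stated assumptions, not the repository's implementation: `adam_step` is the textbook Adam/AdamW update (the body of `TorchAdamKernel` is not shown here), `to_fp16` models a half-precision storage dtype by round-tripping through `struct`, and the initial values, gradient, and tolerances are hypothetical.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def adam_step(step, p, g, m, v, lr=1e-3, beta1=0.9, beta2=0.999,
              eps=1e-8, weight_decay=0.0, adamw=False, cast=float):
    """One textbook Adam/AdamW update on a scalar (assumed form, not the repo's kernel).

    `cast` models the dtype in which the parameter and moments are stored.
    """
    if weight_decay != 0.0:
        if adamw:
            p = p * (1.0 - lr * weight_decay)  # decoupled weight decay (AdamW)
        else:
            g = g + weight_decay * p           # L2 penalty folded into the gradient
    m = beta1 * m + (1.0 - beta1) * g          # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * g * g      # second-moment estimate
    m_hat = m / (1.0 - beta1 ** step)          # bias correction
    v_hat = v / (1.0 - beta2 ** step)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return cast(p), cast(m), cast(v)


def run(n_steps, rtol=1e-3, atol=1e-3):
    """Step a float reference and an fp16-storage copy together, checking each step."""
    g = 0.25                                   # hypothetical constant gradient
    master_p, master_m, master_v = 0.5, 0.0, 0.0
    p, m, v = to_fp16(0.5), 0.0, 0.0
    for step in range(1, 1 + n_steps):
        master_p, master_m, master_v = adam_step(step, master_p, g, master_m, master_v)
        p, m, v = adam_step(step, p, to_fp16(g), m, v, cast=to_fp16)
        assert not math.isnan(p)
        # Same criterion as torch.allclose(master_p, p, ...): scaled by |p|.
        assert abs(master_p - p) <= atol + rtol * abs(p)
    return master_p, p


master_p, low_p = run(n_steps=5)
print(master_p, low_p)
```

Checking inside the loop, rather than once at the end, pins any divergence to the first step where it appears.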

Callers 2

test_fused_adam_kernel · Function · 0.85
test_cpu_adam_kernel · Function · 0.85

Calls 5

update · Method · 0.95
TorchAdamKernel · Class · 0.85
to · Method · 0.45
clone · Method · 0.45
update · Method · 0.45

Tested by

no test coverage detected
