Repository: github.com/hpcaitech/ColossalAI

Function test_adam_optim_on_bert

tests/test_optimizer/test_adam_optim.py:40–72
(
    optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]],
    device: torch.device,
    adamw: bool,
    p_dtype: torch.dtype,
    g_dtype: torch.dtype,
)
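pytest supplies these five arguments. Stacked `@pytest.mark.parametrize` decorators generate the cartesian product of their value lists, so the number of cases is the product of the list lengths. Here is a minimal sketch of that expansion in plain Python. Only `[False, True]` for `adamw` comes from the source. The optimizer/device pairs and dtype pairs are hypothetical placeholders, because the real `_ALLOWED_P_G_TYPES` values and the way `optim_cls`/`device` are supplied are not shown on this page. The sketch also assumes those two are parametrized together as pairs. It models the set of cases, not pytest's exact ordering or test IDs.

```python
from itertools import product

# Hypothetical (optim_cls, device) pairs; the real list is not shown on this page.
OPTIM_DEVICES = [("FusedAdam", "cuda"), ("CPUAdam", "cpu")]
# Taken from the source: @pytest.mark.parametrize("adamw", [False, True])
ADAMW_VALUES = [False, True]
# Hypothetical stand-in for _ALLOWED_P_G_TYPES as (p_dtype, g_dtype) name pairs.
P_G_TYPES = [("float32", "float32"), ("float32", "float16")]


def expand_cases(*axes):
    """Return every combination of the axes, flattening tuple entries.

    A tuple entry models a multi-name parametrize such as "p_dtype, g_dtype",
    which unpacks into several arguments of the test function.
    """
    cases = []
    for combo in product(*axes):
        flat = []
        for value in combo:
            flat.extend(value if isinstance(value, tuple) else (value,))
        cases.append(tuple(flat))
    return cases


cases = expand_cases(OPTIM_DEVICES, ADAMW_VALUES, P_G_TYPES)
print(len(cases))  # 8 = 2 * 2 * 2
for optim_cls, device, adamw, p_dtype, g_dtype in cases:
    print(optim_cls, device, adamw, p_dtype, g_dtype)
```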

Source

```python
from copy import deepcopy
from typing import Type, Union

import pytest
import torch
from torch.optim import Adam, AdamW

from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from tests.kit.model_zoo import model_zoo

# N_STEPS, _ALLOWED_P_G_TYPES, setup_param_groups and set_grad are defined
# elsewhere in this module; optim_cls and device are supplied outside this excerpt.


@pytest.mark.parametrize("adamw", [False, True])
@pytest.mark.parametrize("p_dtype, g_dtype", _ALLOWED_P_G_TYPES)
def test_adam_optim_on_bert(
    optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]],
    device: torch.device,
    adamw: bool,
    p_dtype: torch.dtype,
    g_dtype: torch.dtype,
) -> None:
    # torch_model is the reference; model is a deep copy cast to p_dtype
    model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_bert_for_sequence_classification").values()))
    torch_model = model_fn().to(device)
    model = deepcopy(torch_model).to(p_dtype)
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    # PyTorch reference optimizer vs. the ColossalAI optimizer under test
    torch_optim_cls = AdamW if adamw else Adam
    torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps)
    optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw)

    # Loosen the tolerance when fp16 or bf16 parameters or gradients are involved
    rtol, atol = 1e-5, 1e-5
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 2e-3, 2e-3
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol, atol = 4e-3, 4e-3

    for _ in range(N_STEPS):
        # Populate gradients on both models (g_dtype is the gradient dtype for `model`)
        set_grad(model, torch_model, g_dtype)
        torch_optim.step()
        optim.step()
        torch_optim.zero_grad()
        optim.zero_grad()
        for p, torch_p in zip(model.parameters(), torch_model.parameters()):
            # On overflow the weight is not updated, so p should never contain NaN
            assert not torch.isnan(p).any()
            assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol)
```
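The final assertion holds only if `optim_cls` reproduces PyTorch's update rule for the selected mode and the remaining drift fits inside the dtype-dependent tolerance. The pure-Python sketch here is an illustrative reference, not ColossalAI's implementation. It applies bias-corrected Adam or AdamW steps to a scalar. With `adamw=True` it uses decoupled weight decay, as `torch.optim.AdamW` does. With `adamw=False` it folds the decay into the gradient, as `torch.optim.Adam` does. The test passes the same flag to the ColossalAI optimizer as `adamw_mode`. Several things are assumptions: the helper names (`adam_step`, `run`, `pick_tolerance`, `is_close`), the gradients, the weight-decay value of 0.1, and the use of strings for dtypes. Any weight decay in the real test would come from the parameter groups built by `setup_param_groups`, which this page does not show.

```python
import math


def adam_step(p, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
              weight_decay=0.0, adamw=False):
    """One bias-corrected Adam (adamw=False) or AdamW (adamw=True) step on a scalar.

    Returns the updated (p, m, v); t is the 1-based step count.
    """
    if adamw:
        # AdamW: decoupled weight decay shrinks the weight directly.
        p = p * (1 - lr * weight_decay)
    else:
        # Adam: L2-style decay is folded into the gradient.
        g = g + weight_decay * p
    m = beta1 * m + (1 - beta1) * g          # first-moment estimate
    v = beta2 * v + (1 - beta2) * g * g      # second-moment estimate
    m_hat = m / (1 - beta1 ** t)             # bias corrections
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


def run(grads, p0=1.0, weight_decay=0.0, adamw=False):
    """Apply one step per gradient in grads and return the final parameter."""
    p, m, v = p0, 0.0, 0.0
    for t, g in enumerate(grads, start=1):
        p, m, v = adam_step(p, g, m, v, t, weight_decay=weight_decay, adamw=adamw)
    return p


def pick_tolerance(p_dtype, g_dtype):
    """Mirror the test's dtype-dependent (rtol, atol); dtypes are plain strings here."""
    rtol = atol = 1e-5
    if "float16" in (p_dtype, g_dtype):
        rtol = atol = 2e-3
    if "bfloat16" in (p_dtype, g_dtype):
        rtol = atol = 4e-3
    return rtol, atol


def is_close(a, b, rtol, atol):
    """Scalar form of the torch.allclose criterion: |a - b| <= atol + rtol * |b|."""
    return abs(a - b) <= atol + rtol * abs(b)


grads = [0.5, 0.5, 0.5]  # hypothetical constant gradient over three steps
# Without weight decay the Adam and AdamW paths give identical results.
print(run(grads, adamw=False) == run(grads, adamw=True))  # True
# With weight decay they differ by roughly 3e-4: outside the fp32 tolerance, inside the fp16 one.
adam_p = run(grads, weight_decay=0.1, adamw=False)
adamw_p = run(grads, weight_decay=0.1, adamw=True)
print(is_close(adamw_p, adam_p, *pick_tolerance("float32", "float32")))  # False
print(is_close(adamw_p, adam_p, *pick_tolerance("float32", "float16")))  # True
```

This shows why the test pairs `AdamW` with `adamw_mode=True` and `Adam` with `adamw_mode=False`. Once weight decay is non-zero, comparing across modes would fail the fp32 check.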

Callers

Nothing calls this directly; pytest collects and runs it as a test.

Calls (9)

setup_param_groups · Function · 0.90
set_grad · Function · 0.85
values · Method · 0.80
get_sub_registry · Method · 0.80
model_fn · Function · 0.50
to · Method · 0.45
step · Method · 0.45
zero_grad · Method · 0.45
parameters · Method · 0.45

Tested by

No test coverage detected (expected, since this function is itself a test).
