(
optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]],
device: torch.device,
adamw: bool,
p_dtype: torch.dtype,
g_dtype: torch.dtype,
)
@pytest.mark.parametrize("adamw", [False, True])
@pytest.mark.parametrize("p_dtype, g_dtype", _ALLOWED_P_G_TYPES)
def test_adam_optim_on_bert(
    optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]],
    device: torch.device,
    adamw: bool,
    p_dtype: torch.dtype,
    g_dtype: torch.dtype,
) -> None:
    """Compare ``optim_cls`` against torch's Adam/AdamW on a BERT model.

    A reference model is built from the first entry of the
    ``transformers_bert_for_sequence_classification`` sub-registry and moved
    to ``device``; the model under test is a deep copy of it cast to
    ``p_dtype``. Both are optimised for ``N_STEPS`` steps with identical
    hyperparameters, and after every step each parameter of the model under
    test must be NaN-free and close to its reference counterpart.

    Args:
        optim_cls: Optimizer class under test; constructed with
            ``adamw_mode=adamw``.
        device: Device the reference model is moved to before it is copied.
        adamw: Use ``AdamW`` (True) or ``Adam`` (False) as the reference.
        p_dtype: Dtype the copied model is cast to.
        g_dtype: Dtype forwarded to ``set_grad``.
    """
    registry = model_zoo.get_sub_registry("transformers_bert_for_sequence_classification")
    model_fn, *_ = next(iter(registry.values()))

    reference_model = model_fn().to(device)
    tested_model = deepcopy(reference_model).to(p_dtype)

    # Both optimizers receive exactly the same hyperparameters.
    hyperparams = dict(lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
    if adamw:
        reference_optim_cls = AdamW
    else:
        reference_optim_cls = Adam
    reference_optim = reference_optim_cls(setup_param_groups(reference_model), **hyperparams)
    tested_optim = optim_cls(setup_param_groups(tested_model), **hyperparams, adamw_mode=adamw)

    # Reduced-precision dtypes get a looser tolerance; bfloat16 takes
    # precedence over float16 when both are involved.
    dtypes_in_play = (p_dtype, g_dtype)
    if any(dtype is torch.bfloat16 for dtype in dtypes_in_play):
        tolerance = 4e-3
    elif any(dtype is torch.float16 for dtype in dtypes_in_play):
        tolerance = 2e-3
    else:
        tolerance = 1e-5

    for _step in range(N_STEPS):
        # NOTE(review): set_grad is defined elsewhere; presumably it assigns
        # matching gradients (in g_dtype) to both models -- confirm.
        set_grad(tested_model, reference_model, g_dtype)
        reference_optim.step()
        tested_optim.step()
        reference_optim.zero_grad()
        tested_optim.zero_grad()

        param_pairs = zip(tested_model.parameters(), reference_model.parameters())
        for tested_param, reference_param in param_pairs:
            # On overflow the weight is not updated, so no NaN should ever
            # appear in the parameter under test.
            assert not torch.isnan(tested_param).any()
            assert torch.allclose(tested_param.float(), reference_param, rtol=tolerance, atol=tolerance)
nothing calls this directly
no test coverage detected
searching dependent graphs…