Test without forward
(
bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
)
| 87 | @parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) |
| 88 | @clear_cache_before_run() |
| 89 | def run_dist_lamb_basic( |
| 90 | bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int] |
| 91 | ) -> None: |
| 92 | """Test without forward""" |
| 93 | p_dtype, g_dtype = p_g_dtype |
| 94 | tp_size, zero_size = tp_zero_size |
| 95 | |
| 96 | # Set distributed groups |
| 97 | rank = dist.get_rank() |
| 98 | clear_layout_converter() # Ensure correct sharding |
| 99 | proc_mesh = ProcessGroupMesh(tp_size, zero_size) |
| 100 | tp_group = proc_mesh.get_group_along_axis(0) |
| 101 | |
| 102 | tp_rank = dist.get_rank(tp_group) |
| 103 | seed_all(_SEED) # Fix model init |
| 104 | torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True).to(rank) |
| 105 | tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group).to(rank) |
| 106 | # Ensure equal weight init |
| 107 | assert_close( |
| 108 | torch_model.fc1.weight[tp_rank * _HID_DIM // tp_size : (tp_rank + 1) * _HID_DIM // tp_size], |
| 109 | tp_model.fc1.weight, |
| 110 | ) |
| 111 | assert_close( |
| 112 | torch_model.fc2.weight[:, tp_rank * _HID_DIM // tp_size : (tp_rank + 1) * _HID_DIM // tp_size], |
| 113 | tp_model.fc2.weight, |
| 114 | ) |
| 115 | |
| 116 | # Set up optimizers |
| 117 | lr = 1e-3 |
| 118 | beta1, beta2 = 0.9, 0.999 |
| 119 | eps = 1e-8 |
| 120 | torch_optim = Lamb( |
| 121 | setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps, bias_correction=bias_correction |
| 122 | ) |
| 123 | optim = DistributedLamb( |
| 124 | setup_param_groups(tp_model), |
| 125 | lr=lr, |
| 126 | betas=(beta1, beta2), |
| 127 | eps=eps, |
| 128 | bias_correction=bias_correction, |
| 129 | ) |
| 130 | optim.setup_distributed(tp_group) |
| 131 | |
| 132 | rtol, atol = 8e-7, 8e-7 |
| 133 | if p_dtype is torch.float16 or g_dtype is torch.float16: |
| 134 | rtol, atol = 1e-6, 1e-6 |
| 135 | if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16: |
| 136 | rtol, atol = 2e-6, 2e-6 |
| 137 | |
| 138 | for i in range(_N_STEP): |
| 139 | seed_all(_SEED + i) # NOTE: having only one manual_seed above doesn't work? |
| 140 | set_dist_grad(tp_model, torch_model, g_dtype, tp_group) |
| 141 | |
| 142 | torch_optim.step() |
| 143 | optim.step() |
| 144 | torch_optim.zero_grad() |
| 145 | optim.zero_grad() |
| 146 | try: |
no test coverage detected
searching dependent graphs…