MCPcopy
hub / github.com/hpcaitech/ColossalAI / run_dist_lamb_basic

Function run_dist_lamb_basic

tests/test_optimizer/test_dist_lamb.py:89–152  ·  view source on GitHub ↗

Test without forward

(
    bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
)

Source from the content-addressed store, hash-verified

87@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)])
88@clear_cache_before_run()
89def run_dist_lamb_basic(
90 bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
91) -> None:
92 """Test without forward"""
93 p_dtype, g_dtype = p_g_dtype
94 tp_size, zero_size = tp_zero_size
95
96 # Set distributed groups
97 rank = dist.get_rank()
98 clear_layout_converter() # Ensure correct sharding
99 proc_mesh = ProcessGroupMesh(tp_size, zero_size)
100 tp_group = proc_mesh.get_group_along_axis(0)
101
102 tp_rank = dist.get_rank(tp_group)
103 seed_all(_SEED) # Fix model init
104 torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True).to(rank)
105 tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group).to(rank)
106 # Ensure equal weight init
107 assert_close(
108 torch_model.fc1.weight[tp_rank * _HID_DIM // tp_size : (tp_rank + 1) * _HID_DIM // tp_size],
109 tp_model.fc1.weight,
110 )
111 assert_close(
112 torch_model.fc2.weight[:, tp_rank * _HID_DIM // tp_size : (tp_rank + 1) * _HID_DIM // tp_size],
113 tp_model.fc2.weight,
114 )
115
116 # Set up optimizers
117 lr = 1e-3
118 beta1, beta2 = 0.9, 0.999
119 eps = 1e-8
120 torch_optim = Lamb(
121 setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps, bias_correction=bias_correction
122 )
123 optim = DistributedLamb(
124 setup_param_groups(tp_model),
125 lr=lr,
126 betas=(beta1, beta2),
127 eps=eps,
128 bias_correction=bias_correction,
129 )
130 optim.setup_distributed(tp_group)
131
132 rtol, atol = 8e-7, 8e-7
133 if p_dtype is torch.float16 or g_dtype is torch.float16:
134 rtol, atol = 1e-6, 1e-6
135 if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
136 rtol, atol = 2e-6, 2e-6
137
138 for i in range(_N_STEP):
139 seed_all(_SEED + i) # NOTE: having only one manual_seed above doesn't work?
140 set_dist_grad(tp_model, torch_model, g_dtype, tp_group)
141
142 torch_optim.step()
143 optim.step()
144 torch_optim.zero_grad()
145 optim.zero_grad()
146 try:

Callers 1

check_dist_lamb (Function) · 0.85

Calls 15

get_group_along_axis (Method) · 0.95
setup_distributed (Method) · 0.95
step (Method) · 0.95
step (Method) · 0.95
clear_layout_converter (Function) · 0.90
ProcessGroupMesh (Class) · 0.90
seed_all (Function) · 0.90
Lamb (Class) · 0.90
setup_param_groups (Function) · 0.90
DistributedLamb (Class) · 0.90
TPNet (Class) · 0.85
get_rank (Method) · 0.80

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…