(
rank,
device,
world_size,
weight,
tensor_dev="cpu",
has_zero_grad=False,
backend="gloo",
num_embs=128,
emb_dim=10,
zero_comm=True,
)
| 176 | |
| 177 | |
| 178 | def start_sparse_adam_worker( |
| 179 | rank, |
| 180 | device, |
| 181 | world_size, |
| 182 | weight, |
| 183 | tensor_dev="cpu", |
| 184 | has_zero_grad=False, |
| 185 | backend="gloo", |
| 186 | num_embs=128, |
| 187 | emb_dim=10, |
| 188 | zero_comm=True, |
| 189 | ): |
| 190 | print("start sparse worker for adam {}".format(rank)) |
| 191 | dist_init_method = "tcp://{master_ip}:{master_port}".format( |
| 192 | master_ip="127.0.0.1", master_port="12345" |
| 193 | ) |
| 194 | |
| 195 | if device.type == "cuda": |
| 196 | th.cuda.set_device(device) |
| 197 | |
| 198 | th.distributed.init_process_group( |
| 199 | backend=backend, |
| 200 | init_method=dist_init_method, |
| 201 | world_size=world_size, |
| 202 | rank=rank, |
| 203 | ) |
| 204 | |
| 205 | init_weight = th.empty((num_embs, emb_dim)) |
| 206 | th.manual_seed(0) |
| 207 | th.nn.init.uniform_(init_weight, -1.0, 1.0) |
| 208 | dgl_emb = NodeEmbedding( |
| 209 | num_embs, emb_dim, "test", init_func=initializer, device=tensor_dev |
| 210 | ) |
| 211 | dgl_emb.all_set_embedding(init_weight) |
| 212 | |
| 213 | if has_zero_grad: |
| 214 | dgl_emb_zero = NodeEmbedding( |
| 215 | num_embs, emb_dim, "zero", init_func=initializer, device=tensor_dev |
| 216 | ) |
| 217 | dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01) |
| 218 | else: |
| 219 | dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01) |
| 220 | |
| 221 | th.manual_seed(rank) |
| 222 | if zero_comm: |
| 223 | start = (num_embs // world_size) * rank |
| 224 | end = (num_embs // world_size) * (rank + 1) |
| 225 | idx = th.randint(start, end, size=(4,)).to(tensor_dev) |
| 226 | else: |
| 227 | idx = th.randint(0, num_embs, size=(4,)).to(tensor_dev) |
| 228 | dgl_value = dgl_emb(idx, device) |
| 229 | labels = th.ones((4,)).long().to(device) |
| 230 | dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels) |
| 231 | dgl_adam.zero_grad() |
| 232 | dgl_loss.backward() |
| 233 | dgl_adam.step() |
| 234 | th.distributed.barrier() |
| 235 | dgl_weight = dgl_emb.all_get_embedding().detach() |
nothing calls this directly
no test coverage detected