@pytest.mark.parametrize("backend", ["nccl", "gloo"])
@pytest.mark.parametrize("zero_comm", [True, False])
def test_multiprocess_sparse_adam(num_workers, backend, zero_comm):
    """Check that the sparse-Adam workers and the torch-Adam workers agree.

    Two groups of ``num_workers`` processes are spawned one after the other:
    first ``start_sparse_adam_worker``, then ``start_torch_adam_worker``.
    Each group receives its own ``(num_embs, emb_dim)`` buffer, and the two
    buffers are compared with ``F.allclose`` once every worker has exited.

    NOTE(review): the buffers are created with ``th.empty`` and only ever
    handed to the workers, so the comparison relies on the workers writing
    their result back into them (presumably through shared memory) --
    confirm against the worker implementations.

    Parameters
    ----------
    num_workers : int
        Number of processes spawned per group.
    backend : str
        Forwarded unchanged to ``start_sparse_adam_worker``.
    zero_comm : bool
        Forwarded unchanged to both worker entry points.
    """
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    num_embs = 128
    emb_dim = 10
    ctx = mp.get_context("spawn")

    dgl_weight = th.empty((num_embs, emb_dim))
    sparse_adam_args = []
    for rank in range(num_workers):
        device = F.ctx()
        if device.type == "cuda":
            # make sure each process has a unique GPU
            device = th.device(rank)
        # Positional arguments: their meaning is fixed by the signature of
        # start_sparse_adam_worker, which is defined elsewhere in this file.
        sparse_adam_args.append(
            (
                rank,
                device,
                num_workers,
                dgl_weight,
                th.device("cpu"),
                True,
                backend,
                num_embs,
                emb_dim,
                zero_comm,
            )
        )
    _run_adam_workers(ctx, start_sparse_adam_worker, sparse_adam_args)

    torch_weight = th.empty((num_embs, emb_dim))
    # Positional arguments: see the signature of start_torch_adam_worker.
    torch_adam_args = [
        (
            rank,
            num_workers,
            torch_weight,
            False,
            num_embs,
            emb_dim,
            zero_comm,
        )
        for rank in range(num_workers)
    ]
    _run_adam_workers(ctx, start_torch_adam_worker, torch_adam_args)

    assert F.allclose(dgl_weight, torch_weight)


def _run_adam_workers(ctx, target, args_per_rank):
    """Start one process per argument tuple, join them all, verify exit codes.

    Parameters
    ----------
    ctx : multiprocessing context
        Context whose ``Process`` class is used to create the workers.
    target : callable
        Worker entry point executed in every process.
    args_per_rank : list of tuple
        One positional-argument tuple per worker, in rank order.

    Raises
    ------
    AssertionError
        If any worker finished with a non-zero exit code.
    """
    workers = []
    for args in args_per_rank:
        proc = ctx.Process(target=target, args=args)
        proc.start()
        workers.append(proc)
    for proc in workers:
        proc.join()

    # A crashed worker would leave the caller's buffer untouched, and the
    # caller would then compare uninitialised memory; fail loudly here instead.
    failed = {
        rank: proc.exitcode
        for rank, proc in enumerate(workers)
        if proc.exitcode != 0
    }
    assert not failed, (
        f"{target.__name__} workers exited abnormally (rank -> exit code): "
        f"{failed}"
    )
| 411 | |
| 412 | |
| 413 | @unittest.skipIf(os.name == "nt", reason="Do not support windows yet") |