@unittest.skipIf(F.ctx().type != "cpu", reason="cpu only test")
@pytest.mark.parametrize("num_workers", [2, 4])
def test_multiprocess_cpu_sparse_adam(num_workers):
    """Compare multi-process DGL SparseAdam against multi-process torch Adam on CPU.

    Two rounds of ``num_workers`` processes are spawned (``spawn`` start
    method):

    1. ``start_sparse_adam_worker`` processes, each handed ``dgl_weight``;
    2. ``start_torch_adam_worker`` processes, each handed ``torch_weight``.

    The two weight tensors are then compared with ``F.allclose``.

    Both tensors are allocated with ``th.empty``, and this process never
    writes them. The comparison is therefore only meaningful if the workers
    fill them in, so every worker must exit cleanly.

    NOTE(review): this relies on ``mp`` being ``torch.multiprocessing``.
    Its pickling moves CPU tensor storage into shared memory, which makes
    writes done in the children visible here. Confirm against the file's
    imports.
    """
    backend = "gloo"
    num_embs = 128
    emb_dim = 10
    ctx = mp.get_context("spawn")

    def _run_workers(target, make_args):
        """Spawn one process per rank, wait for all of them, and fail on any crash."""
        workers = []
        for rank in range(num_workers):
            proc = ctx.Process(target=target, args=make_args(rank))
            proc.start()
            workers.append(proc)
        for proc in workers:
            proc.join()
        # Without this check a crashed worker goes unnoticed. The final
        # comparison would then run on uninitialized (th.empty) memory.
        failed = [
            (rank, proc.exitcode)
            for rank, proc in enumerate(workers)
            if proc.exitcode != 0
        ]
        assert not failed, (
            f"{target.__name__} workers failed (rank, exitcode): {failed}"
        )

    # Loop-invariant: the same device is handed to every DGL worker.
    device = F.ctx()
    dgl_weight = th.empty((num_embs, emb_dim))
    # The positional args mirror the original call. Their meaning is defined
    # by start_sparse_adam_worker (not visible here).
    _run_workers(
        start_sparse_adam_worker,
        lambda rank: (
            rank,
            device,
            num_workers,
            dgl_weight,
            th.device("cpu"),
            True,
            backend,
        ),
    )

    torch_weight = th.empty((num_embs, emb_dim))
    _run_workers(
        start_torch_adam_worker,
        lambda rank: (rank, num_workers, torch_weight, False),
    )

    assert F.allclose(dgl_weight, torch_weight)
| 350 | |
| 351 | @unittest.skipIf(os.name == "nt", reason="Do not support windows yet") |