(idtype, sampler_name, mode, use_ddp)
| 328 | ) |
| 329 | @pytest.mark.parametrize("use_ddp", [False, True]) |
| 330 | def test_node_dataloader(idtype, sampler_name, mode, use_ddp): |
| 331 | if mode != "cpu" and F.ctx() == F.cpu(): |
| 332 | pytest.skip("UVA and GPU sampling require a GPU.") |
| 333 | if use_ddp: |
| 334 | if os.name == "nt": |
| 335 | pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...") |
| 336 | dist.init_process_group( |
| 337 | "gloo" if F.ctx() == F.cpu() else "nccl", |
| 338 | "tcp://127.0.0.1:12347", |
| 339 | world_size=1, |
| 340 | rank=0, |
| 341 | ) |
| 342 | g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype) |
| 343 | g1.ndata["feat"] = F.copy_to(F.randn((5, 8)), F.cpu()) |
| 344 | g1.ndata["label"] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu()) |
| 345 | if mode in ("cpu", "uva_cpu_indices"): |
| 346 | indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cpu()) |
| 347 | else: |
| 348 | indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cuda()) |
| 349 | if mode == "pure_gpu": |
| 350 | g1 = g1.to(F.cuda()) |
| 351 | |
| 352 | use_uva = mode.startswith("uva") |
| 353 | |
| 354 | sampler = { |
| 355 | "full": dgl.dataloading.MultiLayerFullNeighborSampler(2), |
| 356 | "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]), |
| 357 | "neighbor2": dgl.dataloading.MultiLayerNeighborSampler([3, 3]), |
| 358 | "labor": dgl.dataloading.LaborSampler([3, 3]), |
| 359 | }[sampler_name] |
| 360 | for num_workers in [0, 1, 2] if mode == "cpu" else [0]: |
| 361 | dataloader = dgl.dataloading.DataLoader( |
| 362 | g1, |
| 363 | indices, |
| 364 | sampler, |
| 365 | device=F.ctx(), |
| 366 | batch_size=g1.num_nodes(), |
| 367 | num_workers=num_workers, |
| 368 | use_uva=use_uva, |
| 369 | use_ddp=use_ddp, |
| 370 | ) |
| 371 | for input_nodes, output_nodes, blocks in dataloader: |
| 372 | _check_device(input_nodes) |
| 373 | _check_device(output_nodes) |
| 374 | _check_device(blocks) |
| 375 | _check_dtype(input_nodes, idtype, "dtype") |
| 376 | _check_dtype(output_nodes, idtype, "dtype") |
| 377 | _check_dtype(blocks, idtype, "idtype") |
| 378 | |
| 379 | g2 = dgl.heterograph( |
| 380 | { |
| 381 | ("user", "follow", "user"): ( |
| 382 | [0, 0, 0, 1, 1, 1, 2], |
| 383 | [1, 2, 3, 0, 2, 3, 0], |
| 384 | ), |
| 385 | ("user", "followed-by", "user"): ( |
| 386 | [1, 2, 3, 0, 2, 3, 0], |
| 387 | [0, 0, 0, 1, 1, 1, 2], |
nothing calls this directly
no test coverage detected