MCPcopy
hub / github.com/dmlc/dgl / test_node_dataloader

Function test_node_dataloader

tests/python/pytorch/dataloading/test_dataloader.py:330–434  ·  view source on GitHub ↗
(idtype, sampler_name, mode, use_ddp)

Source from the content-addressed store, hash-verified

328)
329@pytest.mark.parametrize("use_ddp", [False, True])
330def test_node_dataloader(idtype, sampler_name, mode, use_ddp):
331 if mode != "cpu" and F.ctx() == F.cpu():
332 pytest.skip("UVA and GPU sampling require a GPU.")
333 if use_ddp:
334 if os.name == "nt":
335 pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
336 dist.init_process_group(
337 "gloo" if F.ctx() == F.cpu() else "nccl",
338 "tcp://127.0.0.1:12347",
339 world_size=1,
340 rank=0,
341 )
342 g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
343 g1.ndata["feat"] = F.copy_to(F.randn((5, 8)), F.cpu())
344 g1.ndata["label"] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())
345 if mode in ("cpu", "uva_cpu_indices"):
346 indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cpu())
347 else:
348 indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cuda())
349 if mode == "pure_gpu":
350 g1 = g1.to(F.cuda())
351
352 use_uva = mode.startswith("uva")
353
354 sampler = {
355 "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
356 "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
357 "neighbor2": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
358 "labor": dgl.dataloading.LaborSampler([3, 3]),
359 }[sampler_name]
360 for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
361 dataloader = dgl.dataloading.DataLoader(
362 g1,
363 indices,
364 sampler,
365 device=F.ctx(),
366 batch_size=g1.num_nodes(),
367 num_workers=num_workers,
368 use_uva=use_uva,
369 use_ddp=use_ddp,
370 )
371 for input_nodes, output_nodes, blocks in dataloader:
372 _check_device(input_nodes)
373 _check_device(output_nodes)
374 _check_device(blocks)
375 _check_dtype(input_nodes, idtype, "dtype")
376 _check_dtype(output_nodes, idtype, "dtype")
377 _check_dtype(blocks, idtype, "idtype")
378
379 g2 = dgl.heterograph(
380 {
381 ("user", "follow", "user"): (
382 [0, 0, 0, 1, 1, 1, 2],
383 [1, 2, 3, 0, 2, 3, 0],
384 ),
385 ("user", "followed-by", "user"): (
386 [1, 2, 3, 0, 2, 3, 0],
387 [0, 0, 0, 1, 1, 1, 2],

Callers

nothing calls this directly

Calls 12

_check_device · Function · 0.85
_check_dtype · Function · 0.85
cuda · Method · 0.80
max · Function · 0.50
ctx · Method · 0.45
cpu · Method · 0.45
astype · Method · 0.45
graph · Method · 0.45
copy_to · Method · 0.45
num_nodes · Method · 0.45
to · Method · 0.45
nodes · Method · 0.45

Tested by

no test coverage detected