MCP · copy · Index your code
hub / github.com/dmlc/dgl / test_gpu_sampling_DataLoader

Function test_gpu_sampling_DataLoader

tests/python/pytorch/graphbolt/test_dataloader.py:69–173  ·  view source on GitHub ↗
(
    sampler_name,
    enable_feature_fetch,
    overlap_feature_fetch,
    overlap_graph_fetch,
    asynchronous,
    num_gpu_cached_edges,
    gpu_cache_threshold,
)

Source from the content-addressed store, hash-verified

67@pytest.mark.parametrize("num_gpu_cached_edges", [0, 1024])
68@pytest.mark.parametrize("gpu_cache_threshold", [1, 3])
69def test_gpu_sampling_DataLoader(
70 sampler_name,
71 enable_feature_fetch,
72 overlap_feature_fetch,
73 overlap_graph_fetch,
74 asynchronous,
75 num_gpu_cached_edges,
76 gpu_cache_threshold,
77):
78 N = 40
79 B = 4
80 num_layers = 2
81 itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seeds")
82 graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)
83 graph = graph.pin_memory_() if overlap_graph_fetch else graph.to(F.ctx())
84 features = {}
85 keys = [
86 ("node", None, "a"),
87 ("node", None, "b"),
88 ("node", None, "c"),
89 ("edge", None, "d"),
90 ]
91 features[keys[0]] = dgl.graphbolt.TorchBasedFeature(
92 torch.randn(200, 4, pin_memory=True)
93 )
94 features[keys[1]] = dgl.graphbolt.TorchBasedFeature(
95 torch.randn(200, 4, pin_memory=True)
96 )
97 features[keys[2]] = dgl.graphbolt.TorchBasedFeature(
98 torch.randn(200, 4, device=F.ctx())
99 )
100 features[keys[3]] = dgl.graphbolt.TorchBasedFeature(
101 torch.randn(graph.total_num_edges, 1, device=F.ctx())
102 )
103 feature_store = dgl.graphbolt.BasicFeatureStore(features)
104
105 dataloaders = []
106 for i in range(2):
107 datapipe = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
108 datapipe = datapipe.copy_to(F.ctx())
109 kwargs = {
110 "overlap_fetch": overlap_graph_fetch,
111 "num_gpu_cached_edges": num_gpu_cached_edges,
112 "gpu_cache_threshold": gpu_cache_threshold,
113 "asynchronous": asynchronous,
114 }
115 if i != 0:
116 kwargs = {}
117 datapipe = getattr(dgl.graphbolt, sampler_name)(
118 datapipe,
119 graph,
120 fanouts=[torch.LongTensor([2]) for _ in range(num_layers)],
121 **kwargs
122 )
123 if enable_feature_fetch:
124 datapipe = dgl.graphbolt.FeatureFetcher(
125 datapipe,
126 feature_store,

Callers

nothing calls this directly

Calls 7

find_dps · Function · 0.90
append · Method · 0.80
num_layers · Method · 0.80
pin_memory_ · Method · 0.45
to · Method · 0.45
ctx · Method · 0.45
copy_to · Method · 0.45

Tested by

no test coverage detected