(
sampler_name,
enable_feature_fetch,
overlap_feature_fetch,
overlap_graph_fetch,
asynchronous,
num_gpu_cached_edges,
gpu_cache_threshold,
)
| 67 | @pytest.mark.parametrize("num_gpu_cached_edges", [0, 1024]) |
| 68 | @pytest.mark.parametrize("gpu_cache_threshold", [1, 3]) |
| 69 | def test_gpu_sampling_DataLoader( |
| 70 | sampler_name, |
| 71 | enable_feature_fetch, |
| 72 | overlap_feature_fetch, |
| 73 | overlap_graph_fetch, |
| 74 | asynchronous, |
| 75 | num_gpu_cached_edges, |
| 76 | gpu_cache_threshold, |
| 77 | ): |
| 78 | N = 40 |
| 79 | B = 4 |
| 80 | num_layers = 2 |
| 81 | itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seeds") |
| 82 | graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True) |
| 83 | graph = graph.pin_memory_() if overlap_graph_fetch else graph.to(F.ctx()) |
| 84 | features = {} |
| 85 | keys = [ |
| 86 | ("node", None, "a"), |
| 87 | ("node", None, "b"), |
| 88 | ("node", None, "c"), |
| 89 | ("edge", None, "d"), |
| 90 | ] |
| 91 | features[keys[0]] = dgl.graphbolt.TorchBasedFeature( |
| 92 | torch.randn(200, 4, pin_memory=True) |
| 93 | ) |
| 94 | features[keys[1]] = dgl.graphbolt.TorchBasedFeature( |
| 95 | torch.randn(200, 4, pin_memory=True) |
| 96 | ) |
| 97 | features[keys[2]] = dgl.graphbolt.TorchBasedFeature( |
| 98 | torch.randn(200, 4, device=F.ctx()) |
| 99 | ) |
| 100 | features[keys[3]] = dgl.graphbolt.TorchBasedFeature( |
| 101 | torch.randn(graph.total_num_edges, 1, device=F.ctx()) |
| 102 | ) |
| 103 | feature_store = dgl.graphbolt.BasicFeatureStore(features) |
| 104 | |
| 105 | dataloaders = [] |
| 106 | for i in range(2): |
| 107 | datapipe = dgl.graphbolt.ItemSampler(itemset, batch_size=B) |
| 108 | datapipe = datapipe.copy_to(F.ctx()) |
| 109 | kwargs = { |
| 110 | "overlap_fetch": overlap_graph_fetch, |
| 111 | "num_gpu_cached_edges": num_gpu_cached_edges, |
| 112 | "gpu_cache_threshold": gpu_cache_threshold, |
| 113 | "asynchronous": asynchronous, |
| 114 | } |
| 115 | if i != 0: |
| 116 | kwargs = {} |
| 117 | datapipe = getattr(dgl.graphbolt, sampler_name)( |
| 118 | datapipe, |
| 119 | graph, |
| 120 | fanouts=[torch.LongTensor([2]) for _ in range(num_layers)], |
| 121 | **kwargs |
| 122 | ) |
| 123 | if enable_feature_fetch: |
| 124 | datapipe = dgl.graphbolt.FeatureFetcher( |
| 125 | datapipe, |
| 126 | feature_store, |
Nothing calls this function directly (expected: it is a pytest test, collected and invoked by the test runner).
No test coverage detected.