MCPcopy
hub / github.com/dmlc/dgl / _init_gpu_caches

Function _init_gpu_caches

python/dgl/dataloading/dataloader.py:346–369  ·  view source on GitHub ↗
(graph, gpu_caches)

Source from the content-addressed store, hash-verified

344
345
def _init_gpu_caches(graph, gpu_caches):
    """Register GPU feature caches on ``graph`` as requested by ``gpu_caches``.

    ``graph._gpu_caches`` is created as ``{"node": {}, "edge": {}}`` when the
    graph does not have it yet, so calling this with ``gpu_caches=None`` only
    guarantees that attribute exists.

    Parameters
    ----------
    graph
        Graph object whose ``_node_frames`` / ``_edge_frames`` hold feature
        columns. Node frames are paired positionally with ``graph.ntypes`` and
        edge frames with ``graph.canonical_etypes``; ``graph.idtype`` is
        forwarded to each ``GPUCache``.
    gpu_caches : dict or None
        Optional mapping whose ``"node"`` / ``"edge"`` entries map a feature
        name to the value passed as ``GPUCache``'s first argument (presumably
        the cache capacity -- confirm against ``GPUCache``). Only features
        whose value is ``> 0`` get a cache. A feature name applies to every
        node (or edge) type that has a column with that name.

    Notes
    -----
    Each cache is stored as
    ``graph._gpu_caches[kind][(feature_name, type)] = (GPUCache, column.shape)``.
    An entry that already exists for a ``(feature_name, type)`` pair is left
    untouched rather than rebuilt.

    Raises
    ------
    AssertionError
        If ``gpu_caches`` is neither ``None`` nor a ``dict`` (via ``assert``,
        so this check is skipped under ``python -O``).
    """
    # Always make sure the per-graph cache registry exists.
    if not hasattr(graph, "_gpu_caches"):
        graph._gpu_caches = {"node": {}, "edge": {}}
    if gpu_caches is None:
        return
    assert isinstance(gpu_caches, dict), "GPU cache argument should be a dict"

    # Each frame list is indexed in the same order as its type list.
    layout = (
        ("node", graph._node_frames, graph.ntypes),
        ("edge", graph._edge_frames, graph.canonical_etypes),
    )
    for kind, frames, type_names in layout:
        requested = gpu_caches.get(kind, {})
        for type_idx, frame in enumerate(frames):
            type_name = type_names[type_idx]
            for feat_name in frame.keys():
                if feat_name not in requested or not requested[feat_name] > 0:
                    continue
                column = frame._columns[feat_name]
                registry = graph._gpu_caches[kind]
                slot = (feat_name, type_name)
                if slot in registry:
                    # Keep an already-built cache instead of replacing it.
                    continue
                registry[slot] = (
                    GPUCache(
                        requested[feat_name],
                        _numel_of_shape(column.shape),
                        graph.idtype,
                    ),
                    column.shape,
                )
370
371
372def _prefetch_update_feats(

Callers 1

__init__ (Method) · 0.85

Calls 4

GPUCache (Class) · 0.85
_numel_of_shape (Function) · 0.85
get (Method) · 0.45
keys (Method) · 0.45

Tested by

no test coverage detected