(graph, gpu_caches)
| 344 | |
| 345 | |
def _init_gpu_caches(graph, gpu_caches):
    """Set up the GPU feature caches requested for ``graph``.

    ``graph._gpu_caches`` is created as ``{"node": {}, "edge": {}}`` when the
    graph does not carry it yet.  Then, for every column of every node/edge
    frame whose name is listed in the matching ``gpu_caches["node"]`` /
    ``gpu_caches["edge"]`` mapping with a value greater than zero, a
    ``GPUCache`` is built and stored under ``(feature name, type)`` together
    with the column's shape.  Entries that are already registered are kept.

    Parameters
    ----------
    graph
        Object exposing ``_node_frames``, ``_edge_frames``, ``ntypes``,
        ``canonical_etypes`` and ``idtype``.
    gpu_caches : dict or None
        ``None`` means there is nothing to configure.  Otherwise a dict that
        may hold ``"node"`` and/or ``"edge"`` mappings from feature name to
        the number handed to ``GPUCache`` as its first argument (presumably
        the cache capacity -- confirm against ``GPUCache``).
    """
    if not hasattr(graph, "_gpu_caches"):
        graph._gpu_caches = {"node": {}, "edge": {}}
    if gpu_caches is None:
        return
    assert isinstance(gpu_caches, dict), "GPU cache argument should be a dict"

    # One row per feature kind: cache-dict key, the frames to scan, and the
    # name of the graph attribute that lists the type for each frame index.
    sections = (
        ("node", graph._node_frames, "ntypes"),
        ("edge", graph._edge_frames, "canonical_etypes"),
    )
    for kind, frames, types_attr in sections:
        requested = gpu_caches.get(kind, {})
        for tid, frame in enumerate(frames):
            type_ = getattr(graph, types_attr)[tid]
            for key in frame.keys():
                # Skip features that were not requested or whose requested
                # value is not positive.
                if key not in requested or not requested[key] > 0:
                    continue
                column = frame._columns[key]
                registry = graph._gpu_caches[kind]
                # Never replace a cache that has already been built.
                if (key, type_) in registry:
                    continue
                registry[key, type_] = (
                    GPUCache(
                        requested[key],
                        _numel_of_shape(column.shape),
                        graph.idtype,
                    ),
                    column.shape,
                )
| 370 | |
| 371 | |
| 372 | def _prefetch_update_feats( |
no test coverage detected