(batch, dataloader, stream)
| 485 | |
| 486 | |
def _prefetch(batch, dataloader, stream):
    """Fetch the features requested by ``batch`` and move ``batch`` to the device.

    The returned features have the same nested structure as ``batch``,
    except that

    (1) each subgraph is replaced with a pair of node features and edge
        features, both being dictionaries whose keys are
        ``(type_id, column_name)`` and whose values are either tensors or
        futures;
    (2) each LazyFeature object is replaced with a tensor or future;
    (3) everything else is replaced with None.

    Once the futures have been issued, this function waits for them to
    complete by calling their ``wait()`` method.

    Parameters
    ----------
    batch
        Nested structure to prefetch for; every element is also moved to
        ``dataloader.device`` with ``non_blocking=True``.
    dataloader
        Forwarded to ``_prefetch_for``; its ``device`` attribute is the
        transfer target.
    stream
        CUDA stream on which the fetch and transfer are issued, or ``None``
        to issue them on whatever stream is already current.

    Returns
    -------
    tuple
        ``(batch, feats, stream_event)`` where ``batch`` is the transferred
        input, ``feats`` is the fetched feature structure and
        ``stream_event`` is the event recorded on ``stream`` (``None`` when
        ``stream`` is ``None``).
    """
    use_side_stream = stream is not None

    # Grab the caller's stream *before* entering the side stream, and order
    # it after whatever is already queued on ``stream``.
    current_stream = torch.cuda.current_stream() if use_side_stream else None
    if use_side_stream:
        current_stream.wait_stream(stream)

    def _to_device(item):
        return item.to(dataloader.device, non_blocking=True)

    # With ``stream=None`` this context manager leaves the current stream
    # untouched.
    with torch.cuda.stream(stream):
        # Fetch node/edge features: issue the requests, resolve any futures,
        # then hand the results to ``_record_stream`` together with the
        # caller's stream (presumably so the tensors are tracked as in use
        # on that stream -- confirm against ``_record_stream``).
        fetched = recursive_apply(batch, _prefetch_for, dataloader)
        fetched = recursive_apply(fetched, _await_or_return)
        fetched = recursive_apply(fetched, _record_stream, current_stream)

        # Transfer input nodes/seed nodes/subgraphs.
        batch = recursive_apply(batch, _to_device)
        batch = recursive_apply(batch, _record_stream, current_stream)

    # The event marks the point on ``stream`` after all of the work above.
    if use_side_stream:
        stream_event = stream.record_event()
    else:
        stream_event = None
    return batch, fetched, stream_event
| 514 | |
| 515 | |
| 516 | def _assign_for(item, feat): |
no test coverage detected