MCPcopy
hub / github.com/dmlc/dgl / CollateWrapper

Class CollateWrapper

python/dgl/dataloading/dataloader.py:722–740  ·  view source on GitHub ↗

Wraps a collate function with :func:`remove_parent_storage_columns` so that its output can be serialized back from PyTorch DataLoader worker processes.

Source from the content-addressed store, hash-verified

720
# Make them classes to work with pickling in mp.spawn
class CollateWrapper:
    """Picklable collate callable for PyTorch ``DataLoader`` workers.

    Runs ``sample_func`` on a batch of seed items, then post-processes every
    element of the resulting batch with :func:`remove_parent_storage_columns`
    so the batch can be serialized back from worker processes.

    Parameters
    ----------
    sample_func : callable
        Called as ``sample_func(g, items)``; its return value is the batch.
    g : object
        Graph handed to ``sample_func``. Its optional ``device`` attribute
        decides whether the seed items are moved before sampling.
    use_uva : bool
        When truthy, the seed items are always moved to ``device`` first.
    device : object
        Target passed to ``.to(...)`` when the seed items are moved
        (presumably a ``torch.device`` or device string -- confirm with caller).
    """

    def __init__(self, sample_func, g, use_uva, device):
        # Plain attribute storage only, so instances stay picklable.
        self.sample_func, self.g = sample_func, g
        self.use_uva, self.device = use_uva, device

    def __call__(self, items):
        # A graph without a ``device`` attribute yields None, which compares
        # unequal to the CPU device and is therefore treated as "not on CPU".
        graph_device = getattr(self.g, "device", None)
        graph_on_cpu = graph_device == torch.device("cpu")
        if self.use_uva or not graph_on_cpu:
            # Copy the seed indices to ``self.device`` only in UVA mode or when
            # the graph is not on the CPU; otherwise pass them through as-is.
            items = recursive_apply(items, lambda leaf: leaf.to(self.device))
        batch = self.sample_func(self.g, items)
        # Apply remove_parent_storage_columns (with the graph as extra
        # argument) across the batch before handing it back to the DataLoader.
        return recursive_apply(batch, remove_parent_storage_columns, self.g)
741
742
743class WorkerInitWrapper(object):

Callers 1

__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected