Wraps a collate function with :func:`remove_parent_storage_columns` so that the batches it produces can be serialized and sent back from PyTorch DataLoader workers.
| 720 | |
| 721 | # Make them classes to work with pickling in mp.spawn |
class CollateWrapper:
    """Picklable collate callable that samples a batch from a graph.

    The sampled batch is post-processed with
    :func:`remove_parent_storage_columns` so it can be serialized out of
    PyTorch DataLoader worker processes. It is written as a class rather
    than a closure so that instances survive pickling (e.g. under
    ``mp.spawn``).

    Parameters
    ----------
    sample_func : callable
        Invoked as ``sample_func(g, items)`` to produce the batch.
    g : object
        The graph passed to ``sample_func``. Its optional ``device``
        attribute decides whether incoming items are moved first.
    use_uva : bool
        When truthy, items are always moved to ``device`` before sampling.
    device : object
        Target passed to each item's ``.to()`` when items are moved.
    """

    def __init__(self, sample_func, g, use_uva, device):
        self.sample_func = sample_func
        self.g = g
        self.use_uva = use_uva
        self.device = device

    def __call__(self, items):
        graph = self.g
        target = self.device
        source_device = getattr(graph, "device", None)

        # The indices are left where they are only when UVA is off and the
        # graph reports a CPU device; a graph without a ``device`` attribute
        # does not count as "on CPU", so its items are moved as well.
        if self.use_uva or source_device != torch.device("cpu"):
            items = recursive_apply(items, lambda tensor: tensor.to(target))

        sampled = self.sample_func(graph, items)
        return recursive_apply(sampled, remove_parent_storage_columns, graph)
| 741 | |
| 742 | |
| 743 | class WorkerInitWrapper(object): |