r"""Data loader which merges data objects from a :class:`cogdl.data.dataset` to a mini-batch. Args: dataset (Dataset): The dataset from which to load the data. batch_size (int, optional): How many samples per batch to load. (default: :obj:`1`) shuffle (bool, optional): If set to :obj:`True`, the data will be reshuffled at every epoch (default: :obj:`True`) """
| 24 | |
| 25 | |
class DataLoader(torch.utils.data.DataLoader, metaclass=GenericRecordParameters):
    r"""Data loader which merges data objects from a
    :class:`cogdl.data.dataset` to a mini-batch.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)
        **kwargs: Forwarded unchanged to
            :class:`torch.utils.data.DataLoader`. If ``collate_fn`` is
            missing or ``None``, :meth:`DataLoader.collate_fn` is used.
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        # Use this class's collate function unless the caller supplied one;
        # an explicit ``collate_fn=None`` is treated the same as omitting it.
        if kwargs.get("collate_fn") is None:
            kwargs["collate_fn"] = self.collate_fn

        # Explicit two-argument super() is kept deliberately: the custom
        # metaclass is defined elsewhere, and zero-argument super() would
        # depend on it propagating ``__classcell__``.
        super(DataLoader, self).__init__(
            dataset,
            batch_size,
            shuffle,
            **kwargs,
        )

    @staticmethod
    def collate_fn(batch):
        """Merge a list of samples into a single batch object.

        Only the first sample is inspected to pick the merge strategy, so
        the batch is assumed to be homogeneous:

        * :class:`Graph` -> ``Batch.from_data_list(batch)``
        * :class:`torch.Tensor` -> ``default_collate(batch)``
        * ``float`` -> ``torch.tensor(batch, dtype=torch.float)``

        Args:
            batch (list): Samples produced by the dataset. Must be non-empty;
                ``batch[0]`` is indexed unconditionally.

        Returns:
            The collated batch, whose type depends on the branch taken above.

        Raises:
            TypeError: If the first sample is none of the supported types.
        """
        item = batch[0]
        if isinstance(item, Graph):
            return Batch.from_data_list(batch)
        elif isinstance(item, torch.Tensor):
            return default_collate(batch)
        elif isinstance(item, float):
            return torch.tensor(batch, dtype=torch.float)

        raise TypeError("DataLoader found invalid type: {}".format(type(item)))

    def get_parameters(self):
        """Return the value most recently stored by :meth:`record_parameters`.

        Raises:
            AttributeError: If :meth:`record_parameters` has not been called
                on this instance (``default_kwargs`` is only set there).
        """
        return self.default_kwargs

    def record_parameters(self, params):
        """Store ``params`` on the instance as ``default_kwargs``.

        NOTE(review): presumably invoked by the ``GenericRecordParameters``
        metaclass with the constructor arguments; confirm against its
        definition.
        """
        self.default_kwargs = params
no outgoing calls