hub / github.com/THUDM/CogDL / DataLoader

Class DataLoader

cogdl/data/dataloader.py:26–65

Data loader which merges data objects from a :class:`cogdl.data.dataset` into a mini-batch.


class DataLoader(torch.utils.data.DataLoader, metaclass=GenericRecordParameters):
    r"""Data loader which merges data objects from a
    :class:`cogdl.data.dataset` to a mini-batch.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        # Fall back to the class's own collate_fn when the caller supplies none.
        if "collate_fn" not in kwargs or kwargs["collate_fn"] is None:
            kwargs["collate_fn"] = self.collate_fn

        super(DataLoader, self).__init__(
            dataset,
            batch_size,
            shuffle,
            **kwargs,
        )

    @staticmethod
    def collate_fn(batch):
        # Dispatch on the type of the first item in the batch.
        item = batch[0]
        if isinstance(item, Graph):
            return Batch.from_data_list(batch)
        elif isinstance(item, torch.Tensor):
            return default_collate(batch)
        elif isinstance(item, float):
            return torch.tensor(batch, dtype=torch.float)

        raise TypeError("DataLoader found invalid type: {}".format(type(item)))

    def get_parameters(self):
        return self.default_kwargs

    def record_parameters(self, params):
        self.default_kwargs = params
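
The listing combines two ideas: a constructor that substitutes its own `collate_fn` when none is supplied, and a collate function that picks a merge strategy from the type of the first item. The following is a minimal pure-Python sketch of that pattern, not CogDL code. `ToyGraph`, `ToyBatch`, and `ToyLoader` are hypothetical stand-ins for `Graph`, `Batch`, and `DataLoader`, and a plain list stands in for the tensor the real code returns, so the sketch runs without torch or cogdl installed.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterator, List, Optional, Sequence


@dataclass
class ToyGraph:
    """Hypothetical stand-in for cogdl's Graph; holds only a node count."""
    num_nodes: int


@dataclass
class ToyBatch:
    """Hypothetical stand-in for cogdl's Batch."""
    graphs: List[ToyGraph]

    @classmethod
    def from_data_list(cls, data_list: Sequence[ToyGraph]) -> "ToyBatch":
        return cls(graphs=list(data_list))

    @property
    def num_nodes(self) -> int:
        # Total node count over all graphs merged into this batch.
        return sum(g.num_nodes for g in self.graphs)


class ToyLoader:
    """Pure-Python sketch of the collate fallback and type dispatch."""

    def __init__(
        self,
        dataset: Sequence[Any],
        batch_size: int = 1,
        collate_fn: Optional[Callable[[List[Any]], Any]] = None,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        # Same fallback rule as the listing: a missing or None collate_fn
        # is replaced by the class's own static method.
        self.collate_fn = collate_fn if collate_fn is not None else self.default_collate

    @staticmethod
    def default_collate(batch: List[Any]) -> Any:
        # Only the first item's type is inspected, as in the listing.
        item = batch[0]
        if isinstance(item, ToyGraph):
            return ToyBatch.from_data_list(batch)
        if isinstance(item, float):
            return list(batch)  # the real code builds a float tensor here
        raise TypeError("ToyLoader found invalid type: {}".format(type(item)))

    def __iter__(self) -> Iterator[Any]:
        # Yield consecutive slices of the dataset, each passed through collate_fn.
        for start in range(0, len(self.dataset), self.batch_size):
            yield self.collate_fn(list(self.dataset[start:start + self.batch_size]))
```

In this sketch, `list(ToyLoader([ToyGraph(3), ToyGraph(5), ToyGraph(2)], batch_size=2))` yields two `ToyBatch` objects, the second holding the single leftover graph. A plain Python `int` item raises `TypeError`, because `isinstance(1, float)` is `False`. The same holds for the `float` check in the listing, so integer items there would need to arrive as tensors or be handled by a caller-supplied `collate_fn`.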

Callers 15

1graph.py · File · 0.90
1graph_cn.py · File · 0.90
main · Function · 0.90
test_step · Method · 0.90
train_wrapper · Method · 0.90
val_wrapper · Method · 0.90
test_wrapper · Method · 0.90
predict_wrapper · Method · 0.90
train_wrapper · Method · 0.90
val_wrapper · Method · 0.90
test_wrapper · Method · 0.90
predict_wrapper · Method · 0.90

Calls

no outgoing calls

Tested by 6

test_step · Method · 0.72
test_wrapper · Method · 0.72
test_wrapper · Method · 0.72
test_wrapper · Method · 0.72
test_wrapper · Method · 0.72
test_wrapper · Method · 0.72