MCPcopy
hub / github.com/dmlc/dgl / TensorizedDataset

Class TensorizedDataset

python/dgl/dataloading/dataloader.py:191–242  ·  view source on GitHub ↗

Custom `Dataset` wrapper that returns each minibatch as a tensor or as a dict of tensors. When the dataset is on the GPU, this significantly reduces overhead.

Source from the content-addressed store, hash-verified

189
190
class TensorizedDataset(torch.utils.data.IterableDataset):
    """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.
    When the dataset is on the GPU, this significantly reduces the overhead.

    Parameters
    ----------
    indices : Tensor or Mapping
        The IDs to iterate over. A (non-empty) mapping is combined into one
        ID tensor by ``_get_id_tensor_from_mapping``; its keys are kept and
        forwarded to ``_TensorizedDatasetIter`` (presumably so minibatches
        can be returned as dicts keyed by them).
    batch_size : int
        Number of IDs per minibatch.
    drop_last : bool
        Whether an incomplete final minibatch is dropped. Used by ``__len__``
        and forwarded to the worker split and to the iterator.
    shuffle : bool
        Forwarded to ``_TensorizedDatasetIter``.
    use_shared_memory : bool
        If True, the index permutation used by :meth:`shuffle` is moved to
        shared memory with ``Tensor.share_memory_()``.

    Raises
    ------
    ValueError
        If ``indices`` is an empty mapping.
    """

    def __init__(
        self, indices, batch_size, drop_last, shuffle, use_shared_memory
    ):
        if isinstance(indices, Mapping):
            if not indices:
                # Without this check, ``next(iter(...))`` below would leak a
                # bare StopIteration, which says nothing about the real cause.
                raise ValueError(
                    "indices is an empty mapping; at least one key is required."
                )
            self._mapping_keys = list(indices.keys())
            # The device is taken from the first value; presumably all values
            # live on the same device.
            self._device = next(iter(indices.values())).device
            self._id_tensor = _get_id_tensor_from_mapping(
                indices, self._device, self._mapping_keys
            )
        else:
            self._id_tensor = indices
            self._device = indices.device
            self._mapping_keys = None
        # Use a shared memory array to permute indices for shuffling. This is
        # to make sure that the worker processes can see it when
        # persistent_workers=True, where self._indices would not be duplicated
        # every epoch.
        self._indices = torch.arange(
            self._id_tensor.shape[0], dtype=torch.int64
        )
        if use_shared_memory:
            self._indices.share_memory_()
        self.batch_size = batch_size
        self.drop_last = drop_last
        self._shuffle = shuffle

    def shuffle(self):
        """Shuffle the dataset.

        Permutes ``self._indices`` in place through a NumPy view of the
        (CPU) tensor using NumPy's global RNG, so the new order is visible to
        everything sharing that storage.
        """
        np.random.shuffle(self._indices.numpy())

    def __iter__(self):
        """Return an iterator over minibatches of IDs.

        The current index permutation is split by ``_divide_by_worker``
        (presumably one slice per DataLoader worker), the matching IDs are
        gathered from the ID tensor, and they are handed to
        ``_TensorizedDatasetIter`` for batching.
        """
        indices = _divide_by_worker(
            self._indices, self.batch_size, self.drop_last
        )
        id_tensor = self._id_tensor[indices]
        return _TensorizedDatasetIter(
            id_tensor,
            self.batch_size,
            self.drop_last,
            self._mapping_keys,
            self._shuffle,
        )

    def __len__(self):
        """Return the number of minibatches over the whole dataset.

        The count rounds up when ``drop_last`` is False and down when it is
        True. It covers all IDs, not just a single worker's share.
        """
        num_samples = self._id_tensor.shape[0]
        return (
            num_samples + (0 if self.drop_last else (self.batch_size - 1))
        ) // self.batch_size
243
244
245def _decompose_one_dimension(length, world_size, rank, drop_last):

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected