Converts a given indices tensor to a TensorizedDataset (or a DDPTensorizedDataset when `use_ddp` is true), an IterableDataset that returns views of the original tensor, to reduce the overhead of having a list of scalar tensors in the default PyTorch DataLoader implementation.
(
indices,
batch_size,
drop_last,
use_ddp,
ddp_seed,
shuffle,
use_shared_memory,
)
| 755 | |
| 756 | |
def create_tensorized_dataset(
    indices,
    batch_size,
    drop_last,
    use_ddp,
    ddp_seed,
    shuffle,
    use_shared_memory,
):
    """Wrap an indices tensor in a tensorized iterable dataset.

    The returned dataset yields views of the original tensor instead of
    a list of scalar tensors, which reduces overhead compared with the
    default PyTorch DataLoader implementation.

    Parameters
    ----------
    indices
        The indices tensor to wrap.
    batch_size
        Forwarded unchanged to the dataset constructor.
    drop_last
        Forwarded unchanged to the dataset constructor.
    use_ddp
        When truthy, a ``DDPTensorizedDataset`` is built; otherwise a
        ``TensorizedDataset``.
    ddp_seed
        Forwarded only to ``DDPTensorizedDataset``; ignored otherwise.
    shuffle
        Forwarded unchanged to the dataset constructor.
    use_shared_memory
        Forwarded only to ``TensorizedDataset``; ignored when ``use_ddp``
        is truthy.

    Returns
    -------
    DDPTensorizedDataset or TensorizedDataset
    """
    if not use_ddp:
        return TensorizedDataset(
            indices, batch_size, drop_last, shuffle, use_shared_memory
        )

    # DDP always uses shared memory, so ``use_shared_memory`` is not passed.
    return DDPTensorizedDataset(
        indices, batch_size, drop_last, ddp_seed, shuffle
    )
| 779 | |
| 780 | |
| 781 | def _get_device(device): |
no test coverage detected