(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase)
| 28 | |
| 29 | |
def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase):
    """Create the CIFAR-10 training and test dataloaders.

    Args:
        batch_size: Batch size forwarded to ``plugin.prepare_dataloader`` for
            both loaders.
        coordinator: Object whose ``priority_execution()`` context manager
            wraps construction of the two datasets (which is where the
            download is requested).
        plugin: Object whose ``prepare_dataloader`` turns each dataset into a
            dataloader.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.
    """
    # Dataset root comes from the DATA environment variable, defaulting to ./data.
    dataset_root = os.environ.get("DATA", "./data")

    # Training images are padded, randomly flipped and randomly cropped to 32
    # before tensor conversion; test images are only converted to tensors.
    train_steps = [
        transforms.Pad(4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
        transforms.ToTensor(),
    ]
    train_pipeline = transforms.Compose(train_steps)
    test_pipeline = transforms.ToTensor()

    # NOTE(review): presumably priority_execution serialises the download so
    # processes do not race on the same files -- confirm against DistCoordinator.
    with coordinator.priority_execution():
        train_split = torchvision.datasets.CIFAR10(
            root=dataset_root,
            train=True,
            transform=train_pipeline,
            download=True,
        )
        test_split = torchvision.datasets.CIFAR10(
            root=dataset_root,
            train=False,
            transform=test_pipeline,
            download=True,
        )

    # Only the training loader shuffles and drops the final incomplete batch.
    train_loader = plugin.prepare_dataloader(
        train_split,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    test_loader = plugin.prepare_dataloader(
        test_split,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )
    return train_loader, test_loader
| 51 | |
| 52 | |
| 53 | @torch.no_grad() |
no test coverage detected
searching dependent graphs…