(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase)
| 37 | |
| 38 | |
def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase):
    """Build the CIFAR-10 train/test dataloaders through the given plugin.

    The dataset root is read from the ``DATA`` environment variable, falling back
    to ``"./data"``. Both splits are constructed with ``download=True`` inside
    ``coordinator.priority_execution()`` (presumably so that one rank fetches the
    files before the others touch them -- confirm against ``DistCoordinator``).

    Args:
        batch_size: Batch size forwarded to ``plugin.prepare_dataloader`` for both
            the training and the test loader.
        coordinator: Provides the ``priority_execution()`` context wrapped around
            dataset construction.
        plugin: Its ``prepare_dataloader`` method turns each dataset into a loader.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. The training loader
        shuffles and drops the last incomplete batch; the test loader does neither.
    """
    # Per-channel normalisation statistics shared by both splits
    # (presumably the CIFAR-10 training-set mean/std).
    channel_mean = (0.49139968, 0.48215827, 0.44653124)
    channel_std = (0.24703233, 0.24348505, 0.26158768)

    # Training pipeline: random 32px crop with 4px padding plus horizontal flips.
    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    # Evaluation pipeline: deterministic, no augmentation.
    eval_pipeline = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    root_dir = os.environ.get("DATA", "./data")
    with coordinator.priority_execution():
        train_set = torchvision.datasets.CIFAR10(
            root=root_dir,
            train=True,
            transform=train_pipeline,
            download=True,
        )
        test_set = torchvision.datasets.CIFAR10(
            root=root_dir,
            train=False,
            transform=eval_pipeline,
            download=True,
        )

    train_loader = plugin.prepare_dataloader(
        train_set, batch_size=batch_size, shuffle=True, drop_last=True
    )
    test_loader = plugin.prepare_dataloader(
        test_set, batch_size=batch_size, shuffle=False, drop_last=False
    )
    return train_loader, test_loader
| 71 | |
| 72 | |
| 73 | @torch.no_grad() |
no test coverage detected
searching dependent graphs…