Repository: github.com/hpcaitech/ColossalAI

Function build_dataloader

examples/tutorial/new_api/cifar_resnet/train.py:30–50
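build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase)

Builds the CIFAR-10 training and test datasets, wraps each in a dataloader prepared by the booster plugin, and returns (train_dataloader, test_dataloader).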

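```python
# Imports from the top of train.py (outside the lines 30–50 excerpt)
import os

import torchvision
import torchvision.transforms as transforms

from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
from colossalai.cluster import DistCoordinator


def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase):
    # transform: pad 32x32 -> 40x40, random flip, random 32x32 crop (shifts of up to 4 px)
    transform_train = transforms.Compose(
        [transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()]
    )
    transform_test = transforms.ToTensor()

    # CIFAR-10 dataset; the root directory comes from $DATA, defaulting to ./data
    data_path = os.environ.get("DATA", "./data")
    # The executor rank (rank 0 by default) runs this block first and downloads the data;
    # the other ranks wait, then read the files it has already written.
    with coordinator.priority_execution():
        train_dataset = torchvision.datasets.CIFAR10(
            root=data_path, train=True, transform=transform_train, download=True
        )
        test_dataset = torchvision.datasets.CIFAR10(
            root=data_path, train=False, transform=transform_test, download=True
        )

    # Data loader: the plugin prepares loaders that split each dataset across data-parallel ranks
    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_dataloader, test_dataloader
```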
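The two `prepare_dataloader` calls differ only in `shuffle` and `drop_last`. The sketch below illustrates what that means for each rank, assuming that `prepare_dataloader` behaves like PyTorch's `DistributedSampler` (with the sampler's own `drop_last=False`, so the index list is padded by wrapping around and each rank takes an interleaved slice) and that `drop_last` is forwarded to the `DataLoader`. `rank_indices` and `batches_per_rank` are hypothetical helpers, shuffling is ignored, and the world size of 4 and batch size of 128 are illustrative values, not taken from train.py.

```python
import math


def rank_indices(num_samples, world_size, rank):
    """Hypothetical helper: indices one rank sees, mimicking DistributedSampler
    with shuffle=False and the sampler's drop_last=False (pad by wrapping around)."""
    total_size = math.ceil(num_samples / world_size) * world_size
    indices = list(range(num_samples))
    padding = total_size - num_samples
    if padding <= len(indices):
        indices += indices[:padding]
    else:
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Each rank takes every world_size-th index, starting at its own rank
    return indices[rank:total_size:world_size]


def batches_per_rank(num_samples, world_size, batch_size, drop_last):
    """Hypothetical helper: batches per epoch on one rank when the DataLoader's
    drop_last decides whether a final partial batch is kept."""
    per_rank = math.ceil(num_samples / world_size)
    if drop_last:
        return per_rank // batch_size
    return math.ceil(per_rank / batch_size)


if __name__ == "__main__":
    # 10 samples over 4 ranks: padded to 12, so samples 0 and 1 appear twice
    for r in range(4):
        print(r, rank_indices(10, 4, r))
    # CIFAR-10 has 50,000 training and 10,000 test images
    print(batches_per_rank(50_000, 4, 128, drop_last=True))   # 97
    print(batches_per_rank(10_000, 4, 128, drop_last=False))  # 20
```

With these illustrative numbers, `drop_last=True` discards the 84 leftover training samples on each rank (12,500 - 97 × 128 = 84), keeping every training step at a full batch, while the test loader keeps a final partial batch of 68 samples so that no test image is skipped.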

Callers (1)

main (function) · 0.70

Calls (3)

priority_execution (method) · 0.80
get (method) · 0.45
prepare_dataloader (method) · 0.45
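`priority_execution` is what lets every rank call `CIFAR10(..., download=True)` without racing on the download. The following is a toy, single-process model of that behavior. It assumes, as we read `DistCoordinator`, that non-executor ranks block on a barrier before running the body and that the executor (rank 0 by default) joins the barrier only after its body finishes. Threads stand in for ranks and a `threading.Barrier` stands in for the process-group barrier; `ToyCoordinator` and `simulate` are hypothetical names, not ColossalAI APIs.

```python
import threading
from contextlib import contextmanager


class ToyCoordinator:
    """Hypothetical stand-in for DistCoordinator: one instance per simulated rank."""

    def __init__(self, rank, barrier):
        self.rank = rank
        self.barrier = barrier  # shared by all ranks, like a process-group barrier

    @contextmanager
    def priority_execution(self, executor_rank=0):
        should_block = self.rank != executor_rank
        if should_block:
            self.barrier.wait()  # non-executors wait until the executor's body is done
        yield
        if not should_block:
            self.barrier.wait()  # executor finishes its body, then releases the others


def simulate(world_size=4):
    barrier = threading.Barrier(world_size)
    lock = threading.Lock()
    cache = set()  # stands in for the dataset files under $DATA
    log = []

    def worker(rank):
        with ToyCoordinator(rank, barrier).priority_execution():
            with lock:  # ranks 1..N-1 run their bodies concurrently
                if "cifar10" in cache:
                    log.append((rank, "load from cache"))
                else:
                    cache.add("cifar10")
                    log.append((rank, "download"))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


if __name__ == "__main__":
    print(simulate())
```

Rank 0 always appears first with "download"; the remaining ranks follow in arbitrary order and only ever load from the cache. One limitation of this toy: if the executor's body raises, it never reaches the barrier and the waiting ranks hang.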

Tested by

no test coverage detected
