Repository: github.com/hpcaitech/ColossalAI

Function build_dataloader

examples/tutorial/new_api/cifar_vit/train.py, lines 39–70
(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase)


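```python
# Imports assumed from the top of train.py (not part of the excerpt shown on this page)
import os

import torchvision
import torchvision.transforms as transforms
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
from colossalai.cluster import DistCoordinator


def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase):
    # Training transform: random crop (with 4-pixel padding) and horizontal flip for augmentation,
    # then per-channel normalization with the CIFAR-10 mean and standard deviation
    transform_train = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
        ]
    )
    # Evaluation transform: no augmentation, same normalization
    transform_test = transforms.Compose(
        [
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
        ]
    )

    # CIFAR-10 dataset; the root directory can be overridden with the DATA environment variable.
    # priority_execution() lets one rank (rank 0 by default) run this block, and download the
    # files, before the other ranks enter it, so the dataset is not downloaded concurrently.
    data_path = os.environ.get("DATA", "./data")
    with coordinator.priority_execution():
        train_dataset = torchvision.datasets.CIFAR10(
            root=data_path, train=True, transform=transform_train, download=True
        )
        test_dataset = torchvision.datasets.CIFAR10(
            root=data_path, train=False, transform=transform_test, download=True
        )

    # Data loaders: the plugin shards each dataset across the data-parallel ranks,
    # so batch_size is the per-rank batch size
    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_dataloader, test_dataloader
```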
Callers (1)

main · function

Calls (3)

priority_execution · method (DistCoordinator)
get · method (os.environ.get)
prepare_dataloader · method (DPPluginBase)
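`priority_execution` is what keeps all ranks from downloading CIFAR-10 into the same directory at once. The following sketch models the assumed ordering with threads: the executor rank (0 by default) runs the block first while the other ranks wait at a barrier; once it finishes, the others run the block and find the data already present. `ToyCoordinator` and `simulate` are hypothetical names, and `threading.Barrier` stands in for the distributed barrier that the real `DistCoordinator` is assumed to use.

```python
import threading
from contextlib import contextmanager
from typing import List, Tuple


class ToyCoordinator:
    """Hypothetical stand-in for DistCoordinator: threads play the role of ranks,
    and a threading.Barrier plays the role of a collective barrier."""

    def __init__(self, rank: int, barrier: threading.Barrier):
        self.rank = rank
        self._barrier = barrier

    @contextmanager
    def priority_execution(self, executor_rank: int = 0):
        is_executor = self.rank == executor_rank
        if not is_executor:
            self._barrier.wait()  # non-executors block until the executor has finished
        yield
        if is_executor:
            self._barrier.wait()  # executor is done; its arrival releases the others


def simulate(world_size: int = 4) -> List[Tuple[int, str]]:
    barrier = threading.Barrier(world_size)  # every rank calls wait() exactly once
    cache = set()  # stands in for the files under data_path
    lock = threading.Lock()
    log: List[Tuple[int, str]] = []

    def worker(rank: int) -> None:
        coord = ToyCoordinator(rank, barrier)
        with coord.priority_execution():
            with lock:
                if "cifar10" in cache:
                    log.append((rank, "load from cache"))
                else:
                    cache.add("cifar10")  # the simulated download
                    log.append((rank, "download"))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


if __name__ == "__main__":
    for rank, action in simulate():
        print(rank, action)  # rank 0 downloads first; the order of the other ranks varies
```

The barrier placement is the key design choice: non-executors wait before the block and the executor waits after it, so the barrier can only release once the executor's block has completed.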

Tested by: no test coverage detected
