(self,
dataset,
batch_size,
pin_memory,
local_rank,
tput_timer,
collate_fn=None,
num_local_io_workers=None,
data_sampler=None,
data_parallel_world_size=None,
data_parallel_rank=None,
dataloader_drop_last=False,
deepspeed_dataloader_config={})
class DeepSpeedDataLoader(object):
    """Record the dataset, sampler, batch size and worker settings for loading.

    The constructor only resolves and stores configuration; ``self.data`` is
    left as ``None``.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 pin_memory,
                 local_rank,
                 tput_timer,
                 collate_fn=None,
                 num_local_io_workers=None,
                 data_sampler=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 dataloader_drop_last=False,
                 deepspeed_dataloader_config=None):
        """Resolve the sampler, device count, batch size and worker count.

        Args:
            dataset: Dataset to draw from. Stored as-is; ``len(dataset)`` is
                taken when curriculum learning is enabled.
            batch_size: Requested batch size. Multiplied by the accelerator
                device count only when curriculum learning is disabled and
                ``local_rank`` is negative.
            pin_memory: Stored unchanged on the instance.
            local_rank: ``>= 0`` selects the ``DistributedSampler`` default
                and a device count of 1; a negative value selects the
                ``RandomSampler`` default and the accelerator device count.
            tput_timer: Stored unchanged on the instance.
            collate_fn: Stored unchanged on the instance.
            num_local_io_workers: Worker count. When ``None`` it becomes
                ``2 * device_count``. Replaced by the configured value when
                curriculum learning is enabled.
            data_sampler: Sampler to use. When ``None`` a default is built
                from ``local_rank``. Replaced by a ``DeepSpeedDataSampler``
                when curriculum learning is enabled.
            data_parallel_world_size: Forwarded to the sampler that is built.
            data_parallel_rank: Forwarded to the sampler that is built.
            dataloader_drop_last: Stored on the instance and forwarded as
                ``drop_last`` to ``DeepSpeedDataSampler``.
            deepspeed_dataloader_config: Mapping of loader options. ``None``
                (the default) means an empty, per-instance dict.

        Raises:
            KeyError: If curriculum learning is enabled and the config lacks
                one of the entries read for the curriculum sampler.
        """
        # A literal ``{}`` default would be one dict shared by every instance
        # constructed without a config; create a fresh dict per instance.
        if deepspeed_dataloader_config is None:
            deepspeed_dataloader_config = {}

        self.deepspeed_dataloader_config = deepspeed_dataloader_config
        self.tput_timer = tput_timer
        # Set early because the curriculum sampler below reads self.batch_size;
        # it is assigned again further down once any scaling has been applied.
        self.batch_size = batch_size
        self.curriculum_learning_enabled = deepspeed_dataloader_config.get(CURRICULUM_LEARNING, False)

        if self.curriculum_learning_enabled:
            # Curriculum learning overrides any caller-supplied sampler and
            # worker count with values taken from the config.
            data_sampler = DeepSpeedDataSampler(self.deepspeed_dataloader_config[DATA_EFFICIENCY],
                                                len(dataset),
                                                self.batch_size,
                                                data_parallel_rank,
                                                data_parallel_world_size,
                                                self.deepspeed_dataloader_config[DATA_PARALLEL_GROUP],
                                                self.deepspeed_dataloader_config[GRADIENT_ACCUMULATION_STEPS],
                                                self.deepspeed_dataloader_config[GLOBAL_RANK],
                                                drop_last=dataloader_drop_last)
            # Note: batch_size is not scaled by device_count on this path.
            device_count = get_accelerator().device_count()
            num_local_io_workers = self.deepspeed_dataloader_config[DATA_SAMPLING_NUM_WORKERS]
        elif local_rank >= 0:
            if data_sampler is None:
                data_sampler = DistributedSampler(dataset=dataset,
                                                  num_replicas=data_parallel_world_size,
                                                  rank=data_parallel_rank)
            device_count = 1
        else:
            if data_sampler is None:
                data_sampler = RandomSampler(dataset)
            device_count = get_accelerator().device_count()
            batch_size *= device_count

        if num_local_io_workers is None:
            num_local_io_workers = 2 * device_count

        self.num_local_io_workers = num_local_io_workers
        self.data_sampler = data_sampler
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.device_count = device_count
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.data = None
        self.dataloader_drop_last = dataloader_drop_last
        self.post_process_func = None
nothing calls this directly
no test coverage detected