MCPcopy Index your code
hub / github.com/deepspeedai/DeepSpeed / __init__

Method __init__

deepspeed/runtime/dataloader.py:43–106  ·  view source on GitHub ↗
(self,
                 dataset,
                 batch_size,
                 pin_memory,
                 local_rank,
                 tput_timer,
                 collate_fn=None,
                 num_local_io_workers=None,
                 data_sampler=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 dataloader_drop_last=False,
                 deepspeed_dataloader_config={})

Source from the content-addressed store, hash-verified

41class DeepSpeedDataLoader(object):
42
43 def __init__(self,
44 dataset,
45 batch_size,
46 pin_memory,
47 local_rank,
48 tput_timer,
49 collate_fn=None,
50 num_local_io_workers=None,
51 data_sampler=None,
52 data_parallel_world_size=None,
53 data_parallel_rank=None,
54 dataloader_drop_last=False,
55 deepspeed_dataloader_config={}):
56 self.deepspeed_dataloader_config = deepspeed_dataloader_config
57 self.tput_timer = tput_timer
58 self.batch_size = batch_size
59 self.curriculum_learning_enabled = False
60 if CURRICULUM_LEARNING in deepspeed_dataloader_config:
61 self.curriculum_learning_enabled = deepspeed_dataloader_config[CURRICULUM_LEARNING]
62
63 if self.curriculum_learning_enabled:
64 data_sampler = DeepSpeedDataSampler(self.deepspeed_dataloader_config[DATA_EFFICIENCY],
65 len(dataset),
66 self.batch_size,
67 data_parallel_rank,
68 data_parallel_world_size,
69 self.deepspeed_dataloader_config[DATA_PARALLEL_GROUP],
70 self.deepspeed_dataloader_config[GRADIENT_ACCUMULATION_STEPS],
71 self.deepspeed_dataloader_config[GLOBAL_RANK],
72 drop_last=dataloader_drop_last)
73 device_count = get_accelerator().device_count()
74 num_local_io_workers = self.deepspeed_dataloader_config[DATA_SAMPLING_NUM_WORKERS]
75 else:
76 if local_rank >= 0:
77 if data_sampler is None:
78 data_sampler = DistributedSampler(dataset=dataset,
79 num_replicas=data_parallel_world_size,
80 rank=data_parallel_rank)
81 device_count = 1
82 else:
83 if data_sampler is None:
84 data_sampler = RandomSampler(dataset)
85 device_count = get_accelerator().device_count()
86 batch_size *= device_count
87
88 if num_local_io_workers is None:
89 num_local_io_workers = 2 * device_count
90
91 self.num_local_io_workers = num_local_io_workers
92 self.data_sampler = data_sampler
93 self.dataset = dataset
94 self.collate_fn = collate_fn
95 self.device_count = device_count
96 self.batch_size = batch_size
97 self.pin_memory = pin_memory
98 self.data = None
99 self.dataloader_drop_last = dataloader_drop_last
100 self.post_process_func = None

Callers

nothing calls this directly

Calls 3

get_accelerator — Function · 0.90
device_count — Method · 0.45

Tested by

no test coverage detected