MCPcopy Index your code
hub / github.com/deepspeedai/DeepSpeedExamples / GPT2Dataset

Class GPT2Dataset

Megatron-LM/gpt2_data_loader.py:82–196  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

80
81
82class GPT2Dataset(Dataset):
83
84 def __init__(self, data_path, sizes_filename, seq_length,
85 initial_seed, max_epochs=100):
86 # Input parameters.
87 self.data_path = data_path
88 self.sizes_filename = sizes_filename
89 self.seq_length = seq_length
90 self.initial_seed = initial_seed
91 self.max_epochs = max_epochs
92 # Lock for building the dataset.
93 self.lock = Lock()
94
95 # Shard stuff.
96 # Dictionary from shard nameto its size (number of element).
97 self.master_shard_size_dict = None
98 # Dictionary from shard name to modified size so it is
99 # divisible by self.seq_length.
100 self.shard_size_dict = None
101 # Long array (self.max_epochs * num-shards) populated
102 # randomly with shard names.
103 self.shards_name = None
104 # Start index of the data for a shard.
105 self.shards_start_index = None
106 self.build_shard_mappings_()
107 self.data_length = self.shards_start_index[-1]
108
109 # Data.
110 self.shards_data = [None]*self.shards_name.size
111 self.shards_sample_index = [None]*self.shards_name.size
112
113 def __len__(self):
114 return self.data_length
115
116 def __getitem__(self, idx):
117 # Find which shard we need.
118 shard_index = np.searchsorted(self.shards_start_index,
119 idx, side='right') - 1
120 # data index in the shard.
121 data_idx = idx - self.shards_start_index[shard_index]
122 # Load the shard if it is not in memory.
123 #self.lock.acquire()
124 if self.shards_data[shard_index] is None:
125 print('global rank {} is building data for shard index {} ...'.
126 format(torch.distributed.get_rank(), shard_index))
127 self.build_dataset_(shard_index)
128 #assert self.shards_data[shard_index] is not None
129 #self.lock.release()
130 # Start index.
131 start_index = self.shards_sample_index[shard_index][data_idx]
132 # Add one for label shift.
133 end_index = start_index + self.seq_length + 1
134 data = self.shards_data[shard_index][start_index:end_index]
135 return {'text': np.array(data, dtype=np.int64)}
136
137 def build_dataset_(self, shard_index):
138 # Garbage collect so we don't use a lot of memory.
139 # Leave the last one in case other threads have not catche up yet.

Callers 1

make_data_loader_ · Function · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected