hub / github.com/deepspeedai/DeepSpeedExamples / get_batch

Function get_batch

Megatron-LM/evaluate_gpt2.py:142–185

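Docstring summary: get_batch subdivides the source data into chunks of length args.seq_length. If source is equal to the example output of the data loading example, then with a seq_length limit of 2 we would get the following two Variables for i = 0:

┌ a g m s ┐ ┌ b h n t ┐
└ b h n t ┘ └ c i o u ┘

The function body takes neither source nor i. It reads one batch from data_iterator, broadcasts it with mpu.broadcast_data, and derives tokens and lm_labels by shifting the text tensor one position along dimension 1, so the chunking described in the docstring does not match what the code does.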
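The docstring's example can be reproduced with plain Python lists. This is only a sketch: `batchify` and `get_chunk` are hypothetical helpers, and it assumes the "data loading example" lays the 24 letters a to x out in 4 columns, which is what the matrices in the docstring suggest. It is not code from this repository.

```python
import string


def batchify(symbols, n_columns):
    """Arrange a flat sequence into rows, filling column by column (assumed layout)."""
    n_rows = len(symbols) // n_columns
    columns = [symbols[c * n_rows:(c + 1) * n_rows] for c in range(n_columns)]
    return [[col[r] for col in columns] for r in range(n_rows)]


def get_chunk(source, i, seq_length):
    """Return (data, target), where target is data shifted one row down (dimension 0)."""
    length = min(seq_length, len(source) - 1 - i)
    data = source[i:i + length]
    target = source[i + 1:i + 1 + length]
    return data, target


source = batchify(list(string.ascii_lowercase[:24]), n_columns=4)
data, target = get_chunk(source, i=0, seq_length=2)

for row in data:
    print(" ".join(row))
for row in target:
    print(" ".join(row))
```

The first matrix is `data` and the second is `target`: each target row is the row that follows the corresponding data row.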

(data_iterator, args, timers)

Source from the content-addressed store, hash-verified

140      return attention_mask, loss_mask, position_ids
141
142  def get_batch(data_iterator, args, timers):
143      ''' get_batch subdivides the source data into chunks of
144      length args.seq_length. If source is equal to the example
145      output of the data loading example, with a seq_length limit
146      of 2, we'd get the following two Variables for i = 0:
147      ┌ a g m s ┐ ┌ b h n t ┐
148      └ b h n t ┘ └ c i o u ┘
149      Note that despite the name of the function, the subdivision of data is not
150      done along the batch dimension (i.e. dimension 1), since that was handled
151      by the data loader. The chunks are along dimension 0, corresponding
152      to the seq_len dimension in the LSTM. A Variable representing an appropriate
153      shard reset mask of the same dimensions is also returned.
154      '''
155      # Items and their type.
156      keys = ['text', 'pad_mask']
157      datatype = torch.int64
158
159      # Broadcast data.
160      timers('data loader').start()
161      if data_iterator is not None:
162          data = next(data_iterator)
163      else:
164          data = None
165      timers('data loader').stop()
166      data_b = mpu.broadcast_data(keys, data, datatype)
167
168      # Unpack.
169      tokens_ = data_b['text'].long()
170      lm_labels = tokens_[:, 1:].contiguous()
171      tokens = tokens_[:, :-1].contiguous()
172      padding_mask = data_b['pad_mask'].byte()
173
174      # Get the masks and position ids.
175      attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
176          tokens,
177          args.eod_token,
178          args.reset_position_ids,
179          args.reset_attention_mask)
180
181      # Convert
182      if args.fp16:
183          attention_mask = attention_mask.half()
184
185      return tokens, lm_labels, attention_mask, position_ids, padding_mask
186
187
188  def forward_step(data_iterator, model, args, timers):
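The "Unpack" step builds the model inputs and the language-model labels from the same tensor by dropping the last and the first position respectively. The sketch below mimics `tokens_[:, :-1]` and `tokens_[:, 1:]` on a plain list of equal-length rows; `split_tokens_and_labels` and the sample values are hypothetical, chosen only to make the shift visible.

```python
def split_tokens_and_labels(batch):
    """Mimic tokens_[:, :-1] and tokens_[:, 1:] for a list of equal-length rows."""
    tokens = [row[:-1] for row in batch]   # every position except the last
    labels = [row[1:] for row in batch]    # every position except the first
    return tokens, labels


# Two sequences of five token ids each (made-up values).
batch = [
    [10, 11, 12, 13, 14],
    [20, 21, 22, 23, 24],
]
tokens, labels = split_tokens_and_labels(batch)

# The label at position t is the token that follows position t in the input row.
for b, row in enumerate(batch):
    for t in range(len(row) - 1):
        assert labels[b][t] == row[t + 1]
```

Both outputs are one position shorter than the input rows, so the rows delivered by the loader need one more token than the sequence length the model consumes.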

Callers 1

forward_step · Function · 0.70

Calls 4

half · Method · 0.80
start · Method · 0.45
stop · Method · 0.45
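The `start` and `stop` entries come from the `timers('data loader').start()` and `.stop()` calls around the data fetch. A minimal sketch of an object that supports that call pattern is shown here; the `Timers` and `_Timer` classes are hypothetical stand-ins written for illustration and are not the implementation used by this repository.

```python
import time


class _Timer:
    """Accumulates wall-clock time between start() and stop() calls."""

    def __init__(self, name):
        self.name = name
        self.elapsed = 0.0
        self._started_at = None

    def start(self):
        assert self._started_at is None, "timer already started"
        self._started_at = time.perf_counter()

    def stop(self):
        assert self._started_at is not None, "timer was not started"
        self.elapsed += time.perf_counter() - self._started_at
        self._started_at = None


class Timers:
    """Name-keyed registry: timers('x') returns the same _Timer every time."""

    def __init__(self):
        self._timers = {}

    def __call__(self, name):
        if name not in self._timers:
            self._timers[name] = _Timer(name)
        return self._timers[name]


timers = Timers()
data_iterator = iter([[1, 2, 3]])  # stand-in for a real data loader iterator

timers('data loader').start()
batch = next(data_iterator)
timers('data loader').stop()
```

Because the registry returns the same timer for the same name, repeated calls to get_batch would accumulate into a single 'data loader' total under this design.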

Tested by

no test coverage detected