```python
def get_batch(data_iterator, args, timers):
    ''' get_batch subdivides the source data into chunks of
    length args.seq_length. If source is equal to the example
    output of the data loading example, with a seq_length limit
    of 2, we'd get the following two Variables for i = 0:
    ┌ a g m s ┐ ┌ b h n t ┐
    └ b h n t ┘ └ c i o u ┘
    Note that despite the name of the function, the subdivision of data is not
    done along the batch dimension (i.e. dimension 1), since that was handled
    by the data loader. The chunks are along dimension 0, corresponding
    to the seq_len dimension in the LSTM. A Variable representing an appropriate
    shard reset mask of the same dimensions is also returned.
    '''
    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack: labels are the tokens shifted left by one position.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
        tokens,
        args.eod_token,
        args.reset_position_ids,
        args.reset_attention_mask)
    # Convert the attention mask to half precision for fp16 training.
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, labels, loss_mask, attention_mask, position_ids


def forward_step(data_iterator, model, args, timers):
```
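The unpack step in `get_batch` slices the same batch twice, `tokens_[:, :-1]` and `tokens_[:, 1:]`, so that the label at each position is the token that follows it. The sketch below reproduces that shift with plain Python lists instead of tensors so it runs without PyTorch; the helper name `shift_for_next_token` and the sample ids are made up for illustration.

```python
def shift_for_next_token(batch):
    """Split each row of token ids into inputs and labels offset by one.

    Hypothetical helper mirroring tokens_[:, :-1] and tokens_[:, 1:].
    """
    tokens = [row[:-1] for row in batch]  # drop the last token of each row
    labels = [row[1:] for row in batch]   # drop the first token of each row
    return tokens, labels


# Illustrative batch of two sequences, four token ids each.
batch = [[10, 11, 12, 13], [20, 21, 22, 23]]
tokens, labels = shift_for_next_token(batch)
print(tokens)  # [[10, 11, 12], [20, 21, 22]]
print(labels)  # [[11, 12, 13], [21, 22, 23]]
```

Each row loses one element, so a loader that yields rows of length n produces inputs and labels of length n - 1.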
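The body of `get_masks_and_position_ids` is not part of this excerpt, so the following is only a guess at what the `eod_token` and `reset_position_ids` arguments suggest: a loss mask that ignores end-of-document tokens, and position ids that optionally restart after each one. The function `sketch_loss_mask_and_positions` is a hypothetical stand-in written with plain lists, and it leaves out the attention mask entirely.

```python
def sketch_loss_mask_and_positions(tokens, eod_token, reset_position_ids):
    """Hypothetical stand-in; not the real get_masks_and_position_ids."""
    loss_mask, position_ids = [], []
    for row in tokens:
        mask_row, pos_row = [], []
        pos = 0
        for tok in row:
            # Assumption: end-of-document positions do not contribute to the loss.
            mask_row.append(0.0 if tok == eod_token else 1.0)
            pos_row.append(pos)
            pos += 1
            # Assumption: with the flag set, counting restarts after each EOD token.
            if reset_position_ids and tok == eod_token:
                pos = 0
        loss_mask.append(mask_row)
        position_ids.append(pos_row)
    return loss_mask, position_ids


EOD = 0  # illustrative end-of-document id
tokens = [[5, 6, EOD, 7, 8]]
loss_mask, position_ids = sketch_loss_mask_and_positions(
    tokens, EOD, reset_position_ids=True)
print(loss_mask)     # [[1.0, 1.0, 0.0, 1.0, 1.0]]
print(position_ids)  # [[0, 1, 2, 0, 1]]
```

With `reset_position_ids=False` the same input would give positions `[[0, 1, 2, 3, 4]]`, that is, a single running count across document boundaries.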