```python
import torch

import mpu  # Megatron-LM model-parallel utilities; the import path differs between Megatron-LM versions


def get_batch(data_iterator, timers):
    ''' get_batch subdivides the source data into chunks of
    length args.seq_length. If source is equal to the example
    output of the data loading example, with a seq_length limit
    of 2, we'd get the following two Variables for i = 0:
    ┌ a g m s ┐ ┌ b h n t ┐
    └ b h n t ┘ └ c i o u ┘
    Note that despite the name of the function, the subdivision of data is not
    done along the batch dimension (i.e. dimension 1), since that was handled
    by the data loader. The chunks are along dimension 0, corresponding
    to the seq_len dimension in the LSTM. A Variable representing an appropriate
    shard reset mask of the same dimensions is also returned.
    '''
    # Keys to broadcast; every item is sent as an int64 tensor.
    keys = ['text', 'types', 'is_random', 'mask', 'mask_labels', 'pad_mask']
    datatype = torch.int64

    # Only ranks that own a data iterator read the next batch (timed under
    # 'data loader'); mpu.broadcast_data then shares that batch with the
    # other ranks of the model-parallel group.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack and cast: the loss mask becomes float, the padding mask uint8.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    next_sentence = data_b['is_random'].long()
    loss_mask = data_b['mask'].float()
    lm_labels = data_b['mask_labels'].long()
    padding_mask = data_b['pad_mask'].byte()

    return tokens, types, next_sentence, loss_mask, lm_labels, padding_mask


def forward_step(data_iterator, model, args, timers):
    ...  # body not included in this excerpt
```
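The chunking that the `get_batch` docstring describes can be reproduced on plain Python lists. This is a minimal sketch that assumes the data loading example lays the alphabet out in four columns, one per batch entry, which matches the letters in the docstring's diagram. `batchify_columns` and `chunk_along_time` are hypothetical helpers, not part of the code above; the body of `get_batch` itself receives ready-made batches from `data_iterator` instead.

```python
from typing import List, Tuple

# A grid of tokens: rows are time steps (dimension 0), columns are batch entries (dimension 1).
Grid = List[List[str]]


def batchify_columns(seq: str, batch_size: int) -> Grid:
    """Hypothetical helper: lay seq out so that column j holds the j-th
    contiguous slice of seq; tokens beyond n_rows * batch_size are dropped."""
    n_rows = len(seq) // batch_size
    return [[seq[j * n_rows + r] for j in range(batch_size)]
            for r in range(n_rows)]


def chunk_along_time(source: Grid, i: int, seq_length: int) -> Tuple[Grid, Grid]:
    """Hypothetical helper: take up to seq_length rows starting at row i as
    the input, and the same rows shifted down by one as the target."""
    seq_len = min(seq_length, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len]
    return data, target


source = batchify_columns("abcdefghijklmnopqrstuvwxyz", 4)
data, target = chunk_along_time(source, 0, 2)
for d_row, t_row in zip(data, target):
    print(" ".join(d_row), "|", " ".join(t_row))
# a g m s | b h n t
# b h n t | c i o u
```

The target is the input shifted by one row along dimension 0, so each position is paired with its next token, while the columns (dimension 1) are never split, in line with the docstring's remark that batching was already handled by the data loader.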
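The body of `get_batch` follows a fetch-then-broadcast pattern: a rank that holds a `data_iterator` reads the next batch, ranks without one pass `None`, and after `mpu.broadcast_data` every rank unpacks the same items. The single-process sketch below imitates that flow with plain lists. `fetch`, `simulated_broadcast` and `unpack` are hypothetical names, and the "broadcast" here is only a copy of rank 0's batch, not a real collective operation.

```python
from typing import Dict, Iterator, List, Optional

# Same key list as get_batch; values are plain int lists instead of int64 tensors.
KEYS = ['text', 'types', 'is_random', 'mask', 'mask_labels', 'pad_mask']
Batch = Dict[str, List[int]]


def fetch(data_iterator: Optional[Iterator[Batch]]) -> Optional[Batch]:
    # Mirrors get_batch: only a rank that owns an iterator reads from it.
    return next(data_iterator) if data_iterator is not None else None


def simulated_broadcast(per_rank: List[Optional[Batch]], keys: List[str]) -> List[Batch]:
    # Hypothetical stand-in for a group broadcast: rank 0 is the source and
    # every rank ends up with its own copy of the listed keys.
    source = per_rank[0]
    if source is None:
        raise ValueError("rank 0 must supply the batch")
    return [{k: list(source[k]) for k in keys} for _ in per_rank]


def unpack(data_b: Batch):
    # Same order as get_batch's return value; only the loss mask is
    # converted to floats, mirroring the .float() call in get_batch.
    tokens = [int(x) for x in data_b['text']]
    types = [int(x) for x in data_b['types']]
    next_sentence = [int(x) for x in data_b['is_random']]
    loss_mask = [float(x) for x in data_b['mask']]
    lm_labels = [int(x) for x in data_b['mask_labels']]
    padding_mask = [int(x) for x in data_b['pad_mask']]
    return tokens, types, next_sentence, loss_mask, lm_labels, padding_mask


# Two simulated ranks: rank 0 owns an iterator, rank 1 does not.
batch = {k: [i, i + 1] for i, k in enumerate(KEYS)}
it = iter([batch])
per_rank = [fetch(it), fetch(None)]
received = simulated_broadcast(per_rank, KEYS)
out = [unpack(b) for b in received]
```

Because every rank unpacks an identical copy, all of them return the same tuple even though only one of them touched the iterator; that is why `get_batch` can accept `data_iterator=None` on the other ranks.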