(self, samples, rng)
| 281 | return data |
| 282 | |
def split_samples(self, samples, rng):
    """Cut every sample into shorter pieces that try to respect sentence ends.

    One target length is drawn for the whole call with
    ``rng.randrange(32, self.max_seq_length - 1)``; every sample then yields
    ``(self.max_seq_length - 1) // target_length`` pieces.  The first element
    of each sample's ``'text'`` and ``'loss_mask'`` is dropped, and every
    piece is re-prefixed with the tokenizer's ``'ENC'`` command id (text) and
    ``0`` (loss mask).

    Args:
        samples: iterable of mappings with ``'text'`` and ``'loss_mask'``
            sequences of equal length.
        rng: object exposing ``randrange(start, stop)``
            (presumably a ``random.Random`` -- confirm against the caller).

    Returns:
        A list of ``{'text': ..., 'loss_mask': ...}`` dicts whose values are
        the results of ``np.concatenate``.
    """
    span = rng.randrange(32, self.max_seq_length - 1)
    pieces_per_sample = (self.max_seq_length - 1) // span
    cls_id = self.tokenizer.get_command('ENC').Id
    eos_id = self.tokenizer.get_command('eos').Id

    def closes_sentence(token):
        # A token is a boundary if the class helper says so or it is eos.
        return self.contains_sentence_end(token) or token == eos_id

    pieces = []
    for sample in samples:
        body_tokens = sample['text'][1:]
        body_masks = sample['loss_mask'][1:]
        for _ in range(pieces_per_sample):
            # Short samples are reused whole; no random draw is made for them.
            piece_tokens, piece_masks = body_tokens, body_masks
            if span < len(body_tokens):
                begin = rng.randrange(0, len(body_tokens) - span)
                # Walk left until the piece starts right after a boundary and
                # not on an eos token (or until the start of the sample).
                while begin > 0:
                    if body_tokens[begin] != eos_id and closes_sentence(body_tokens[begin - 1]):
                        break
                    begin -= 1
                # Walk the end left until the piece finishes on a boundary.
                end = begin + span
                while end > begin and not closes_sentence(body_tokens[end - 1]):
                    end -= 1
                # Trimming left less than half the target: use the full window.
                if end - begin < span // 2:
                    end = begin + span
                window = slice(begin, end)
                piece_tokens, piece_masks = body_tokens[window], body_masks[window]
            pieces.append({
                'text': np.concatenate(([cls_id], piece_tokens)),
                'loss_mask': np.concatenate(([0], piece_masks)),
            })
    return pieces
| 311 | |
| 312 | def construct_blocks(self, samples): |
| 313 | worker_info = torch.utils.data.get_worker_info() |
no test coverage detected