(self, samples)
| 310 | return new_samples |
| 311 | |
| 312 | def construct_blocks(self, samples): |
| 313 | worker_info = torch.utils.data.get_worker_info() |
| 314 | if worker_info is not None: |
| 315 | worker_id, num_workers = worker_info.id, worker_info.num_workers |
| 316 | else: |
| 317 | worker_id, num_workers = 0, 1 |
| 318 | rng = random.Random((self.count * num_workers + worker_id) * self.world_size + self.rank) |
| 319 | self.count += 1 |
| 320 | token_batch, target_batch, loss_mask_batch, position_id_batch = [], [], [], [] |
| 321 | source_batch, target_batch = [], [] |
| 322 | if rng.random() < self.short_seq_prob: |
| 323 | samples = self.split_samples(samples, rng) |
| 324 | rand = rng.random() |
| 325 | single_span = rand < self.single_span_prob |
| 326 | rand = 0.0 if single_span else rng.random() |
| 327 | attention_mask = [] |
| 328 | if rand < self.bert_prob: |
| 329 | mode = 'bert' |
| 330 | for sample in samples: |
| 331 | if single_span: |
| 332 | masked_lengths = [rng.choices(range(1, len(self.block_length_distribution) + 1), |
| 333 | weights=self.block_length_distribution)[0]] |
| 334 | masked_count = masked_lengths[0] |
| 335 | else: |
| 336 | masked_lengths, masked_count = [], 0 |
| 337 | while masked_count < int(self.bert_ratio * len(sample['text'])): |
| 338 | block_length = rng.choices(range(1, len(self.block_length_distribution) + 1), |
| 339 | weights=self.block_length_distribution)[0] |
| 340 | masked_lengths.append(block_length) |
| 341 | masked_count += block_length |
| 342 | if self.masked_lm: |
| 343 | sep = len(sample['text']) |
| 344 | else: |
| 345 | sep = len(sample['text']) - masked_count + len(masked_lengths) |
| 346 | data = self.generate_blank_data(sample, masked_lengths, sep, rng, task='bert') |
| 347 | if data is not None: |
| 348 | if self.encoder_decoder: |
| 349 | source_tokens, target_tokens, loss_masks = data |
| 350 | source_batch.append(source_tokens) |
| 351 | target_batch.append(target_tokens) |
| 352 | loss_mask_batch.append(loss_masks) |
| 353 | else: |
| 354 | tokens, targets, loss_masks, position_ids = data |
| 355 | token_batch.append(tokens) |
| 356 | target_batch.append(targets) |
| 357 | loss_mask_batch.append(loss_masks) |
| 358 | position_id_batch.append(position_ids) |
| 359 | attention_mask.append(sep) |
| 360 | |
| 361 | elif rand < self.bert_prob + self.gap_sentence_prob: |
| 362 | mode = 'sentence' |
| 363 | for sample in samples: |
| 364 | tokens, loss_masks = sample['text'], sample['loss_mask'] |
| 365 | sentence_spans = [] |
| 366 | last_index = 1 if tokens[0] == self.tokenizer.get_command('ENC').Id else 0 |
| 367 | for i in range(len(tokens)): |
| 368 | if self.contains_sentence_end(tokens[i]): |
| 369 | if last_index < i + 1: |
nothing calls this directly
no test coverage detected