| 114 | return spans |
| 115 | |
def sample_span_in_document(self, tokens, masked_lengths, rng):
    """Sample mask spans document-by-document inside a packed token sequence.

    ``tokens`` is split into documents at every position equal to
    ``self.eod_token``. If a document's first token is the tokenizer's
    ``'ENC'`` command id, that token is excluded from the document. The
    document that runs to the end of ``tokens`` also has its length reduced
    by one (when positive).

    Documents are visited in ascending order of length:

    * Every document except the last (longest) one receives a greedy prefix
      of the remaining ``masked_lengths``. The prefix's total does not exceed
      ``int(length * self.bert_ratio)``.
    * The last document receives every remaining length that fits in its
      full length, with one extra position charged per span already accepted.

    Placement inside each document is delegated to ``self.sample_spans``.

    Args:
        tokens: token ids. Assumed to be a numpy array so that
            ``tokens == self.eod_token`` compares element-wise (TODO confirm).
        masked_lengths: list of span lengths to place. NOTE: it is shuffled
            in place via ``rng.shuffle``.
        rng: random source exposing ``shuffle``; also forwarded to
            ``self.sample_spans``.

    Returns:
        The concatenation of the span lists returned by ``self.sample_spans``.
        If some lengths could not be placed, diagnostic information is printed.
    """
    rng.shuffle(masked_lengths)
    num_tokens = len(tokens)
    # Hoisted out of the loop: the ENC command id is loop-invariant.
    enc_id = self.tokenizer.get_command('ENC').Id
    # -1 acts as a virtual separator in front of the first document.
    indices = [-1] + np.where(tokens == self.eod_token)[0].tolist()

    # Collect (offset, length) per document. Walking the separators backwards
    # keeps ``last_index`` equal to the separator that closes the document.
    documents = []
    last_index = num_tokens
    for index in reversed(indices):
        start_index = index
        if start_index + 1 < num_tokens and tokens[start_index + 1] == enc_id:
            start_index += 1
        length = last_index - start_index - 1
        if last_index == num_tokens and length > 0:
            length -= 1
        documents.append((start_index + 1, length))
        last_index = index
    documents.sort(key=lambda doc: doc[1])

    def fit_count(start, budget, gap):
        """Return how many lengths from ``start`` fit greedily into ``budget``.

        ``gap`` extra positions are charged for each span already accepted.
        """
        total, count = 0, 0
        while (start + count < len(masked_lengths)
               and masked_lengths[start + count] + total + gap * count <= budget):
            total += masked_lengths[start + count]
            count += 1
        return count

    mask_spans = []
    mask_index = 0
    last_doc = len(documents) - 1
    for i, (offset, length) in enumerate(documents):
        if i == last_doc:
            # The longest document absorbs every remaining span that fits,
            # reserving one position between consecutive spans.
            count = fit_count(mask_index, length, gap=1)
        else:
            count = fit_count(mask_index, int(length * self.bert_ratio), gap=0)
        if count > 0:
            mask_spans += self.sample_spans(masked_lengths[mask_index:mask_index + count],
                                            length, rng, offset=offset)
        if i == last_doc and mask_index + count < len(masked_lengths):
            # Report whenever at least one span length was left unplaced.
            print(length, masked_lengths[mask_index:], masked_lengths[:mask_index], indices)
        mask_index += count
    return mask_spans
| 159 | |
| 160 | def make_masked_data(self, tokens, loss_masks, attention_mask, block_spans, rng, task='bert'): |
| 161 | position_ids = np.arange(len(tokens), dtype=np.long) |