(self, tokens, loss_masks, attention_mask, block_spans, rng, task='bert')
| 170 | return tokens, targets, loss_masks, position_ids |
| 171 | |
| 172 | def make_block_data(self, tokens, loss_masks, attention_mask, block_spans, rng, task='bert'): |
| 173 | text_length = len(tokens) |
| 174 | position_ids = np.ones(len(tokens), dtype=np.long) |
| 175 | for start, end in block_spans: |
| 176 | position_ids[start + 1: end] = 0 |
| 177 | position_ids = np.cumsum(position_ids) - 1 |
| 178 | if self.random_position and position_ids[-1] < self.max_seq_length - 1: |
| 179 | position_bias = self.max_seq_length - position_ids[-1] |
| 180 | position_bias = rng.randrange(0, position_bias) |
| 181 | position_ids = position_ids + position_bias |
| 182 | if self.encoder_decoder or not self.shuffle_blocks: |
| 183 | block_spans.sort(key=lambda x: x[0]) |
| 184 | else: |
| 185 | rng.shuffle(block_spans) |
| 186 | if self.sentinel_token: |
| 187 | block_spans = [(start, end, idx) for idx, (start, end) in enumerate(block_spans)] |
| 188 | else: |
| 189 | block_spans = [(start, end, 0) for start, end in block_spans] |
| 190 | target_tokens, target_position_ids, target_block_position_ids, targets = [], [], [], [] |
| 191 | for start, end, idx in block_spans: |
| 192 | sop_token = 'sop' if idx == 0 else f"sop{idx}" |
| 193 | target_tokens.append([self.tokenizer.get_command(sop_token).Id]) |
| 194 | span_tokens = copy.deepcopy(tokens[start: end]) |
| 195 | if self.block_mask_prob > 0.0 and task == 'bert': |
| 196 | for sub_idx in range(len(span_tokens)): |
| 197 | if random.random() < self.block_mask_prob: |
| 198 | span_tokens[sub_idx] = self.tokenizer.get_command('dBLOCK').Id |
| 199 | target_tokens.append(span_tokens) |
| 200 | targets.append(tokens[start: end]) |
| 201 | targets.append([self.tokenizer.get_command('eop').Id]) |
| 202 | if not self.sentinel_token: |
| 203 | target_position_id = position_ids[start: end] |
| 204 | target_position_ids.append(target_position_id) |
| 205 | target_position_ids.append([target_position_id[0]]) |
| 206 | else: |
| 207 | target_position_ids.append([self.max_seq_length] * (end - start + 1)) |
| 208 | if self.block_position_encoding: |
| 209 | target_block_position_ids.append(np.arange(1, end - start + 2, dtype=np.long)) |
| 210 | else: |
| 211 | target_block_position_ids.append([1] * (end - start + 1)) |
| 212 | block_spans.sort(key=lambda x: x[0]) |
| 213 | source_tokens, source_position_ids, local_spans = [], [], [] |
| 214 | last, current_length = 0, 0 |
| 215 | for start, end, idx in block_spans: |
| 216 | if task == 'generation': |
| 217 | mask_id = self.generation_mask |
| 218 | elif task == 'gap_sentence': |
| 219 | mask_id = self.gap_sentence_mask |
| 220 | else: |
| 221 | mask_token = 'MASK' if idx == 0 else f'MASK{idx}' |
| 222 | mask_id = self.tokenizer.get_command(mask_token).Id |
| 223 | local_spans.append((current_length, current_length + start - last)) |
| 224 | source_tokens.append(tokens[last: start]) |
| 225 | source_tokens.append([mask_id]) |
| 226 | source_position_ids.append(position_ids[last: start]) |
| 227 | source_position_ids.append([position_ids[start]]) |
| 228 | current_length += start - last + 1 |
| 229 | last = end |
no test coverage detected