MCPcopy Index your code
hub / github.com/THUDM/GLM / make_block_data

Method make_block_data

blocklm_utils.py:172–268  ·  view source on GitHub ↗
(self, tokens, loss_masks, attention_mask, block_spans, rng, task='bert')

Source from the content-addressed store, hash-verified

170 return tokens, targets, loss_masks, position_ids
171
172 def make_block_data(self, tokens, loss_masks, attention_mask, block_spans, rng, task='bert'):
173 text_length = len(tokens)
174 position_ids = np.ones(len(tokens), dtype=np.long)
175 for start, end in block_spans:
176 position_ids[start + 1: end] = 0
177 position_ids = np.cumsum(position_ids) - 1
178 if self.random_position and position_ids[-1] < self.max_seq_length - 1:
179 position_bias = self.max_seq_length - position_ids[-1]
180 position_bias = rng.randrange(0, position_bias)
181 position_ids = position_ids + position_bias
182 if self.encoder_decoder or not self.shuffle_blocks:
183 block_spans.sort(key=lambda x: x[0])
184 else:
185 rng.shuffle(block_spans)
186 if self.sentinel_token:
187 block_spans = [(start, end, idx) for idx, (start, end) in enumerate(block_spans)]
188 else:
189 block_spans = [(start, end, 0) for start, end in block_spans]
190 target_tokens, target_position_ids, target_block_position_ids, targets = [], [], [], []
191 for start, end, idx in block_spans:
192 sop_token = 'sop' if idx == 0 else f"sop{idx}"
193 target_tokens.append([self.tokenizer.get_command(sop_token).Id])
194 span_tokens = copy.deepcopy(tokens[start: end])
195 if self.block_mask_prob > 0.0 and task == 'bert':
196 for sub_idx in range(len(span_tokens)):
197 if random.random() < self.block_mask_prob:
198 span_tokens[sub_idx] = self.tokenizer.get_command('dBLOCK').Id
199 target_tokens.append(span_tokens)
200 targets.append(tokens[start: end])
201 targets.append([self.tokenizer.get_command('eop').Id])
202 if not self.sentinel_token:
203 target_position_id = position_ids[start: end]
204 target_position_ids.append(target_position_id)
205 target_position_ids.append([target_position_id[0]])
206 else:
207 target_position_ids.append([self.max_seq_length] * (end - start + 1))
208 if self.block_position_encoding:
209 target_block_position_ids.append(np.arange(1, end - start + 2, dtype=np.long))
210 else:
211 target_block_position_ids.append([1] * (end - start + 1))
212 block_spans.sort(key=lambda x: x[0])
213 source_tokens, source_position_ids, local_spans = [], [], []
214 last, current_length = 0, 0
215 for start, end, idx in block_spans:
216 if task == 'generation':
217 mask_id = self.generation_mask
218 elif task == 'gap_sentence':
219 mask_id = self.gap_sentence_mask
220 else:
221 mask_token = 'MASK' if idx == 0 else f'MASK{idx}'
222 mask_id = self.tokenizer.get_command(mask_token).Id
223 local_spans.append((current_length, current_length + start - last))
224 source_tokens.append(tokens[last: start])
225 source_tokens.append([mask_id])
226 source_position_ids.append(position_ids[last: start])
227 source_position_ids.append([position_ids[start]])
228 current_length += start - last + 1
229 last = end

Callers 2

generate_blank_data · Method · 0.95
construct_blocks · Method · 0.95

Calls 4

cumsum · Method · 0.80
append · Method · 0.80
get_command · Method · 0.80
DecodeIds · Method · 0.45

Tested by

no test coverage detected