MCPcopy Index your code
hub / github.com/THUDM/GLM / split_samples

Method split_samples

blocklm_utils.py:283–310  ·  view source on GitHub ↗
(self, samples, rng)

Source from the content-addressed store, hash-verified

281 return data
282
def split_samples(self, samples, rng):
    """Cut every sample into randomly placed windows of one shared random length.

    A single target length is drawn with
    ``rng.randrange(32, self.max_seq_length - 1)`` and reused for all samples;
    ``(self.max_seq_length - 1) // target_length`` windows are emitted per
    sample. The first element of each sample's ``'text'`` and ``'loss_mask'``
    is dropped and replaced by the ``'ENC'`` command id (with loss mask 0).

    Samples whose remaining tokens do not exceed the target length are emitted
    whole, once per window. Otherwise a random start is drawn and moved left
    until it no longer sits on the ``'eos'`` id and the token before it ends a
    sentence (or is ``'eos'``), or until it reaches 0. The end is then moved
    left to the closest sentence end; if that leaves fewer than half the
    target length, the full target-length window is used instead.

    Args:
        samples: iterable of dicts holding sliceable ``'text'`` and
            ``'loss_mask'`` sequences.
        rng: random source exposing ``randrange(start, stop)``.

    Returns:
        A list of dicts with ``'text'`` and ``'loss_mask'`` numpy arrays.
    """
    body_length = self.max_seq_length - 1
    target_length = rng.randrange(32, body_length)
    num_splits = body_length // target_length

    get_command = self.tokenizer.get_command
    cls_id = get_command('ENC').Id
    eos_id = get_command('eos').Id

    def closes_sentence(token):
        # A token closes a sentence if it contains sentence-final punctuation
        # or is the eos marker itself.
        return self.contains_sentence_end(token) or token == eos_id

    new_samples = []
    for sample in samples:
        tokens = sample['text'][1:]
        loss_masks = sample['loss_mask'][1:]
        for _ in range(num_splits):
            if target_length >= len(tokens):
                # Too short to crop: reuse the whole sample.
                piece_tokens, piece_masks = tokens, loss_masks
            else:
                start = rng.randrange(0, len(tokens) - target_length)
                # Walk left until the window opens right after a sentence
                # boundary and does not begin on eos.
                while start > 0:
                    begins_on_eos = tokens[start] == eos_id
                    if not begins_on_eos and closes_sentence(tokens[start - 1]):
                        break
                    start -= 1

                full_end = start + target_length
                end = full_end
                # Walk left until the window closes on a sentence boundary.
                while end > start and not closes_sentence(tokens[end - 1]):
                    end -= 1
                # Fall back to the unaligned window when alignment shrank it
                # below half the target length.
                if end - start < target_length // 2:
                    end = full_end

                piece_tokens = tokens[start: end]
                piece_masks = loss_masks[start: end]

            new_samples.append({
                'text': np.concatenate(([cls_id], piece_tokens)),
                'loss_mask': np.concatenate(([0], piece_masks)),
            })
    return new_samples
311
312 def construct_blocks(self, samples):
313 worker_info = torch.utils.data.get_worker_info()

Callers 1

construct_blocksMethod · 0.95

Calls 3

contains_sentence_endMethod · 0.95
get_commandMethod · 0.80
appendMethod · 0.80

Tested by

no test coverage detected