MCPcopy Index your code
hub / github.com/THUDM/GLM / construct_blocks

Method construct_blocks

blocklm_utils.py:312–457  ·  view source on GitHub ↗
(self, samples)

Source from the content-addressed store, hash-verified

310 return new_samples
311
312 def construct_blocks(self, samples):
313 worker_info = torch.utils.data.get_worker_info()
314 if worker_info is not None:
315 worker_id, num_workers = worker_info.id, worker_info.num_workers
316 else:
317 worker_id, num_workers = 0, 1
318 rng = random.Random((self.count * num_workers + worker_id) * self.world_size + self.rank)
319 self.count += 1
320 token_batch, target_batch, loss_mask_batch, position_id_batch = [], [], [], []
321 source_batch, target_batch = [], []
322 if rng.random() < self.short_seq_prob:
323 samples = self.split_samples(samples, rng)
324 rand = rng.random()
325 single_span = rand < self.single_span_prob
326 rand = 0.0 if single_span else rng.random()
327 attention_mask = []
328 if rand < self.bert_prob:
329 mode = 'bert'
330 for sample in samples:
331 if single_span:
332 masked_lengths = [rng.choices(range(1, len(self.block_length_distribution) + 1),
333 weights=self.block_length_distribution)[0]]
334 masked_count = masked_lengths[0]
335 else:
336 masked_lengths, masked_count = [], 0
337 while masked_count < int(self.bert_ratio * len(sample['text'])):
338 block_length = rng.choices(range(1, len(self.block_length_distribution) + 1),
339 weights=self.block_length_distribution)[0]
340 masked_lengths.append(block_length)
341 masked_count += block_length
342 if self.masked_lm:
343 sep = len(sample['text'])
344 else:
345 sep = len(sample['text']) - masked_count + len(masked_lengths)
346 data = self.generate_blank_data(sample, masked_lengths, sep, rng, task='bert')
347 if data is not None:
348 if self.encoder_decoder:
349 source_tokens, target_tokens, loss_masks = data
350 source_batch.append(source_tokens)
351 target_batch.append(target_tokens)
352 loss_mask_batch.append(loss_masks)
353 else:
354 tokens, targets, loss_masks, position_ids = data
355 token_batch.append(tokens)
356 target_batch.append(targets)
357 loss_mask_batch.append(loss_masks)
358 position_id_batch.append(position_ids)
359 attention_mask.append(sep)
360
361 elif rand < self.bert_prob + self.gap_sentence_prob:
362 mode = 'sentence'
363 for sample in samples:
364 tokens, loss_masks = sample['text'], sample['loss_mask']
365 sentence_spans = []
366 last_index = 1 if tokens[0] == self.tokenizer.get_command('ENC').Id else 0
367 for i in range(len(tokens)):
368 if self.contains_sentence_end(tokens[i]):
369 if last_index < i + 1:

Callers

nothing calls this directly

Calls 9

split_samples — Method · 0.95
generate_blank_data — Method · 0.95
contains_sentence_end — Method · 0.95
make_block_data — Method · 0.95
pad_batch — Method · 0.95
index_in_list — Function · 0.85
append — Method · 0.80
get_command — Method · 0.80
DecodeIds — Method · 0.45

Tested by

no test coverage detected