(self, args, tokenizer, max_seq_length, bert_prob=1.0, gap_sentence_prob=0.0, gpt_infill_prob=0.5,
gpt_min_ratio=0.5, bert_ratio=0.15, gap_sentence_ratio=0.15, average_block_length=3,
max_block_length=40, block_mask_prob=0.0, context_mask_ratio=0.0, context_mask_range=3,
short_seq_prob=0.0, single_span_prob=0.0, block_position_encoding=True, encoder_decoder=False,
shuffle_blocks=True, sentinel_token=False, task_mask=False, random_position=False, masked_lm=False)
| 29 | |
| 30 | class ConstructBlockStrategy: |
| 31 | def __init__(self, args, tokenizer, max_seq_length, bert_prob=1.0, gap_sentence_prob=0.0, gpt_infill_prob=0.5, |
| 32 | gpt_min_ratio=0.5, bert_ratio=0.15, gap_sentence_ratio=0.15, average_block_length=3, |
| 33 | max_block_length=40, block_mask_prob=0.0, context_mask_ratio=0.0, context_mask_range=3, |
| 34 | short_seq_prob=0.0, single_span_prob=0.0, block_position_encoding=True, encoder_decoder=False, |
| 35 | shuffle_blocks=True, sentinel_token=False, task_mask=False, random_position=False, masked_lm=False): |
| 36 | self.eod_token = args.eod_token |
| 37 | self.tokenizer = tokenizer |
| 38 | self.count = 0 |
| 39 | self.max_seq_length = max_seq_length |
| 40 | self.rank = mpu.get_data_parallel_rank() |
| 41 | self.world_size = mpu.get_data_parallel_world_size() |
| 42 | # self.rank = 0 |
| 43 | # self.world_size = 1 |
| 44 | assert 0.0 <= bert_prob <= 1.0 |
| 45 | self.bert_prob = bert_prob |
| 46 | self.gap_sentence_prob = gap_sentence_prob |
| 47 | self.gpt_prob = 1 - bert_prob - gap_sentence_prob |
| 48 | assert self.gpt_prob >= -1e-10 |
| 49 | self.infill_prob = gpt_infill_prob |
| 50 | self.gpt_min_ratio = gpt_min_ratio |
| 51 | self.bert_ratio = bert_ratio |
| 52 | self.gap_sentence_ratio = gap_sentence_ratio |
| 53 | self.block_length_distribution = [poisson.pmf(i, average_block_length) for i in range(1, max_block_length)] |
| 54 | self.block_mask_prob = block_mask_prob |
| 55 | self.context_mask_ratio = context_mask_ratio |
| 56 | self.context_mask_range = context_mask_range |
| 57 | self.short_seq_prob = short_seq_prob |
| 58 | self.single_span_prob = single_span_prob |
| 59 | self.block_position_encoding = block_position_encoding |
| 60 | self.encoder_decoder = encoder_decoder |
| 61 | self.shuffle_blocks = shuffle_blocks |
| 62 | self.sentinel_token = sentinel_token |
| 63 | self.generation_mask = 'gMASK' if task_mask else 'MASK' |
| 64 | self.generation_mask = self.tokenizer.get_command(self.generation_mask).Id |
| 65 | self.gap_sentence_mask = 'sMASK' if task_mask else 'MASK' |
| 66 | self.gap_sentence_mask = self.tokenizer.get_command(self.gap_sentence_mask).Id |
| 67 | self.random_position = random_position |
| 68 | self.masked_lm = masked_lm |
| 69 | print_rank_0( |
| 70 | f"BERT prob {self.bert_prob}, gap sent prob {self.gap_sentence_prob}, GPT prob {self.gpt_prob}, infill prob {self.infill_prob}") |
| 71 | print_rank_0( |
| 72 | f"generation min ratio {self.gpt_min_ratio}, block ratio {self.bert_ratio}, gap sent ratio {self.gap_sentence_ratio}") |
| 73 | print_rank_0(f"block length distribution {self.block_length_distribution}") |
| 74 | print_rank_0(f"block mask prob {self.block_mask_prob}, context mask ratio {self.context_mask_ratio}") |
| 75 | |
| 76 | def contains_sentence_end(self, tok): |
| 77 | tok = self.tokenizer.IdToToken(tok) |
nothing calls this directly
no test coverage detected