MCPcopy Index your code
hub / github.com/THUDM/GLM / __init__

Method __init__

blocklm_utils.py:31–74  ·  view source on GitHub ↗
(self, args, tokenizer, max_seq_length, bert_prob=1.0, gap_sentence_prob=0.0, gpt_infill_prob=0.5,
                 gpt_min_ratio=0.5, bert_ratio=0.15, gap_sentence_ratio=0.15, average_block_length=3,
                 max_block_length=40, block_mask_prob=0.0, context_mask_ratio=0.0, context_mask_range=3,
                 short_seq_prob=0.0, single_span_prob=0.0, block_position_encoding=True, encoder_decoder=False,
                 shuffle_blocks=True, sentinel_token=False, task_mask=False, random_position=False, masked_lm=False)

Source from the content-addressed store, hash-verified

29
30class ConstructBlockStrategy:
def __init__(self, args, tokenizer, max_seq_length, bert_prob=1.0, gap_sentence_prob=0.0, gpt_infill_prob=0.5,
             gpt_min_ratio=0.5, bert_ratio=0.15, gap_sentence_ratio=0.15, average_block_length=3,
             max_block_length=40, block_mask_prob=0.0, context_mask_ratio=0.0, context_mask_range=3,
             short_seq_prob=0.0, single_span_prob=0.0, block_position_encoding=True, encoder_decoder=False,
             shuffle_blocks=True, sentinel_token=False, task_mask=False, random_position=False, masked_lm=False):
    """Validate and store the configuration for constructing blank-infilling samples.

    Only ``bert_prob`` and ``gap_sentence_prob`` are supplied by the caller;
    ``self.gpt_prob`` is derived as the remainder ``1 - bert_prob - gap_sentence_prob``.

    Args:
        args: Configuration object; only ``args.eod_token`` is read here.
        tokenizer: Must provide ``get_command(name).Id`` for ``'MASK'``, and for
            ``'gMASK'`` / ``'sMASK'`` when ``task_mask`` is true.
        max_seq_length: Stored as ``self.max_seq_length``.
        bert_prob: Probability in [0, 1]; stored as ``self.bert_prob``.
        gap_sentence_prob: Probability in [0, 1]; stored as ``self.gap_sentence_prob``.
        gpt_infill_prob: Stored as ``self.infill_prob``.
        average_block_length: Mean passed to ``poisson.pmf`` when building
            ``self.block_length_distribution``.
        max_block_length: Exclusive upper bound of the lengths covered by
            ``self.block_length_distribution``.
        task_mask: If true, use the ``'gMASK'`` / ``'sMASK'`` commands for the
            generation / gap-sentence mask ids instead of ``'MASK'``.
        The remaining keyword arguments are stored unchanged as same-named attributes.

    Raises:
        ValueError: If ``bert_prob`` or ``gap_sentence_prob`` lies outside [0, 1],
            or if their sum exceeds 1 by more than 1e-10.
    """
    # Explicit checks instead of ``assert`` so they still run under ``python -O``.
    if not 0.0 <= bert_prob <= 1.0:
        raise ValueError(f"bert_prob must be in [0, 1], got {bert_prob}")
    if not 0.0 <= gap_sentence_prob <= 1.0:
        raise ValueError(f"gap_sentence_prob must be in [0, 1], got {gap_sentence_prob}")
    gpt_prob = 1 - bert_prob - gap_sentence_prob
    # Small tolerance: float rounding can push an exact "sums to 1" split slightly below zero.
    if gpt_prob < -1e-10:
        raise ValueError(
            f"bert_prob + gap_sentence_prob must not exceed 1, got {bert_prob} + {gap_sentence_prob}")

    self.eod_token = args.eod_token
    self.tokenizer = tokenizer
    self.count = 0
    self.max_seq_length = max_seq_length
    self.rank = mpu.get_data_parallel_rank()
    self.world_size = mpu.get_data_parallel_world_size()

    self.bert_prob = bert_prob
    self.gap_sentence_prob = gap_sentence_prob
    self.gpt_prob = gpt_prob
    self.infill_prob = gpt_infill_prob
    self.gpt_min_ratio = gpt_min_ratio
    self.bert_ratio = bert_ratio
    self.gap_sentence_ratio = gap_sentence_ratio
    # pmf values from ``poisson`` (presumably scipy.stats.poisson) for lengths
    # 1 .. max_block_length - 1. They are not renormalized, so they sum to < 1.
    # NOTE(review): range() excludes max_block_length itself -- confirm that is intended.
    self.block_length_distribution = [poisson.pmf(i, average_block_length) for i in range(1, max_block_length)]
    self.block_mask_prob = block_mask_prob
    self.context_mask_ratio = context_mask_ratio
    self.context_mask_range = context_mask_range
    self.short_seq_prob = short_seq_prob
    self.single_span_prob = single_span_prob
    self.block_position_encoding = block_position_encoding
    self.encoder_decoder = encoder_decoder
    self.shuffle_blocks = shuffle_blocks
    self.sentinel_token = sentinel_token

    # With task_mask each objective gets its own mask command; otherwise both share 'MASK'.
    generation_command = 'gMASK' if task_mask else 'MASK'
    gap_sentence_command = 'sMASK' if task_mask else 'MASK'
    self.generation_mask = self.tokenizer.get_command(generation_command).Id
    self.gap_sentence_mask = self.tokenizer.get_command(gap_sentence_command).Id

    self.random_position = random_position
    self.masked_lm = masked_lm
    print_rank_0(
        f"BERT prob {self.bert_prob}, gap sent prob {self.gap_sentence_prob}, GPT prob {self.gpt_prob}, infill prob {self.infill_prob}")
    print_rank_0(
        f"generation min ratio {self.gpt_min_ratio}, block ratio {self.bert_ratio}, gap sent ratio {self.gap_sentence_ratio}")
    print_rank_0(f"block length distribution {self.block_length_distribution}")
    print_rank_0(f"block mask prob {self.block_mask_prob}, context mask ratio {self.context_mask_ratio}")
75
76 def contains_sentence_end(self, tok):
77 tok = self.tokenizer.IdToToken(tok)

Callers

nothing calls this directly

Calls 2

print_rank_0 — Function · 0.90
get_command — Method · 0.80

Tested by

no test coverage detected