MCPcopy Index your code
hub / github.com/THUDM/GLM / sample_span_in_document

Method sample_span_in_document

blocklm_utils.py:116–158  ·  view source on GitHub ↗
(self, tokens, masked_lengths, rng)

Source from the content-addressed store, hash-verified

114 return spans
115
def sample_span_in_document(self, tokens, masked_lengths, rng):
    """Distribute mask span lengths over the documents packed in ``tokens``.

    ``tokens`` is split into documents at every position equal to
    ``self.eod_token``. The documents are ordered by ascending length; every
    document except the last one in that order receives span lengths whose
    sum does not exceed ``int(length * self.bert_ratio)``, and the last
    (longest) document receives as many of the remaining lengths as fit.
    The actual span positions are chosen by ``self.sample_spans``.

    Args:
        tokens: Token ids. An ndarray or any sequence ``np.asarray`` accepts;
            it must support ``len()`` and integer indexing.
        masked_lengths: Mutable list of span lengths. NOTE: shuffled in
            place with ``rng`` before allocation.
        rng: Random source providing ``shuffle``; also forwarded to
            ``self.sample_spans``.

    Returns:
        The concatenation of the lists returned by ``self.sample_spans`` for
        each document that received at least one span length.
    """
    rng.shuffle(masked_lengths)
    num_lengths = len(masked_lengths)
    num_tokens = len(tokens)
    # Loop-invariant, so look it up once instead of once per document.
    enc_id = self.tokenizer.get_command('ENC').Id

    def count_fitting(first, budget, gap):
        """Count consecutive lengths from ``first`` whose total fits ``budget``.

        ``gap`` extra positions are charged for every length already taken.
        """
        taken, count = 0, 0
        while (first + count < num_lengths
               and masked_lengths[first + count] + taken + gap * count <= budget):
            taken += masked_lengths[first + count]
            count += 1
        return count

    # np.asarray makes the comparison element-wise even when ``tokens`` is a
    # plain list; a list compared to a scalar would just yield ``False`` and
    # every document boundary would be missed.
    eod_positions = np.where(np.asarray(tokens) == self.eod_token)[0].tolist()
    # -1 acts as a virtual boundary in front of the first document.
    indices = [-1] + eod_positions

    # Walk the boundaries back to front, collecting (offset, length) pairs.
    documents = []
    last_index = num_tokens
    for index in reversed(indices):
        start_index = index
        # Skip an ENC command token that directly follows the boundary.
        if start_index + 1 < num_tokens and tokens[start_index + 1] == enc_id:
            start_index += 1
        length = last_index - start_index - 1
        # The document that runs to the end of ``tokens`` gives up one
        # position (presumably so the final token is never masked -- confirm).
        if last_index == num_tokens and length > 0:
            length -= 1
        documents.append((start_index + 1, length))
        last_index = index
    documents.sort(key=lambda doc: doc[1])

    mask_spans = []
    mask_index = 0
    final_position = len(documents) - 1
    for position, (offset, length) in enumerate(documents):
        is_final = position == final_position
        if is_final:
            # The longest document takes whatever still fits, charging one
            # extra position per span already taken (presumably room between
            # neighbouring spans -- confirm against sample_spans).
            count = count_fitting(mask_index, length, gap=1)
        else:
            count = count_fitting(mask_index, int(length * self.bert_ratio), gap=0)
        if count > 0:
            mask_spans += self.sample_spans(
                masked_lengths[mask_index:mask_index + count], length, rng, offset=offset)
        if is_final:
            # Diagnostic output when more than one span length stays unallocated.
            if mask_index + count < num_lengths - 1:
                print(length, masked_lengths[mask_index:], masked_lengths[:mask_index], indices)
        else:
            mask_index += count
    return mask_spans
159
160 def make_masked_data(self, tokens, loss_masks, attention_mask, block_spans, rng, task='bert'):
161 position_ids = np.arange(len(tokens), dtype=np.long)

Callers 2

generate_blank_data · Method · 0.95
main · Function · 0.95

Calls 3

sample_spans · Method · 0.95
get_command · Method · 0.80
append · Method · 0.80

Tested by 1

main · Function · 0.76