MCPcopy Index your code
hub / github.com/THUDM/GLM / build_input_from_ids

Function build_input_from_ids

tasks/data_utils.py:144–230  ·  view source on GitHub ↗
(text_a_ids, text_b_ids, answer_ids, max_seq_length, tokenizer, args=None, add_cls=True,
                         add_sep=False, add_piece=False, add_eos=True, mask_id=None)

Source from the content-addressed store, hash-verified

142
143
144def build_input_from_ids(text_a_ids, text_b_ids, answer_ids, max_seq_length, tokenizer, args=None, add_cls=True,
145 add_sep=False, add_piece=False, add_eos=True, mask_id=None):
146 if mask_id is None:
147 mask_id = tokenizer.get_command('MASK').Id
148 eos_id = tokenizer.get_command('eos').Id
149 cls_id = tokenizer.get_command('ENC').Id
150 sep_id = tokenizer.get_command('sep').Id
151 ids = []
152 types = []
153 paddings = []
154 # CLS
155 if add_cls:
156 ids.append(cls_id)
157 types.append(0)
158 paddings.append(1)
159 # A
160 len_text_a = len(text_a_ids)
161 ids.extend(text_a_ids)
162 types.extend([0] * len_text_a)
163 paddings.extend([1] * len_text_a)
164 # B
165 if text_b_ids is not None:
166 # SEP
167 if add_sep:
168 ids.append(sep_id)
169 types.append(0)
170 paddings.append(1)
171 len_text_b = len(text_b_ids)
172 ids.extend(text_b_ids)
173 types.extend([1] * len_text_b)
174 paddings.extend([1] * len_text_b)
175 eos_length = 1 if add_eos else 0
176 # Cap the size.
177 if len(ids) >= max_seq_length - eos_length:
178 max_seq_length_m1 = max_seq_length - 1
179 ids = ids[0:max_seq_length_m1]
180 types = types[0:max_seq_length_m1]
181 paddings = paddings[0:max_seq_length_m1]
182 end_type = 0 if text_b_ids is None else 1
183 if add_eos:
184 ids.append(eos_id)
185 types.append(end_type)
186 paddings.append(1)
187 sep = len(ids)
188 target_ids = [0] * len(ids)
189 loss_masks = [0] * len(ids)
190 position_ids = list(range(len(ids)))
191 block_position_ids = [0] * len(ids)
192 # Piece
193 if add_piece or answer_ids is not None:
194 sop_id = tokenizer.get_command('sop').Id
195 assert mask_id in ids
196 mask_position = len(ids) - ids[-1::-1].index(
197 mask_id) - 1 if not args.sentinel_token else args.max_position_embeddings
198 ids.append(sop_id)
199 types.append(end_type)
200 paddings.append(1)
201 position_ids.append(mask_position)

Callers 9

encode (Method) · 0.90
encode (Method) · 0.90
encode (Method) · 0.90
__getitem__ (Method) · 0.90
encode (Method) · 0.90
encode (Method) · 0.90
encode (Method) · 0.90
__getitem__ (Method) · 0.90
__getitem__ (Method) · 0.90

Calls 3

get_command (Method) · 0.80
append (Method) · 0.80
extend (Method) · 0.80

Tested by

no test coverage detected