| 106 | |
| 107 | |
def parse_input(tokenizer,
                input_text=None,
                prompt_template=None,
                add_special_tokens=True,
                max_input_length=923,
                pad_id=None,
                num_prepend_vtokens=None,
                model_name=None,
                model_version=None):
    """Tokenize a batch of prompts into a list of int32 tensors.

    Each prompt is optionally formatted through ``prompt_template``, then
    encoded with ``tokenizer.encode`` (truncated to ``max_input_length``
    tokens). Optional virtual-token ids are prepended and, for GLM models,
    ``tokenizer.sop_token_id`` is appended.

    Args:
        tokenizer: Object providing ``encode(text, add_special_tokens=...,
            truncation=..., max_length=...)`` returning a list of ids. When
            ``num_prepend_vtokens`` is given it must also expose
            ``vocab_size`` and ``special_tokens_map``; for GLM models it must
            expose ``sop_token_id``.
        input_text: A single prompt string or an iterable of prompt strings.
        prompt_template: Optional ``str.format`` template; each prompt is
            substituted via the ``input_text`` keyword.
        add_special_tokens: Forwarded to ``tokenizer.encode``.
        max_input_length: Maximum tokens kept per prompt. Truncation happens
            inside ``tokenizer.encode``, before virtual tokens or the GLM
            ``sop`` token are added.
        pad_id: Accepted for signature compatibility; not used here.
        num_prepend_vtokens: Optional sequence with one count per prompt; that
            many consecutive ids, starting at ``vocab_size`` minus the number
            of additional special tokens, are prepended to the prompt.
        model_name: Optional model name, used only to detect GLM models.
        model_version: Optional model version; ``'glm'`` together with a
            ``model_name`` containing ``'GLM'`` appends ``sop_token_id``.

    Returns:
        list[torch.Tensor]: One 1-D ``torch.int32`` tensor per prompt.

    Raises:
        ValueError: If ``input_text`` is None, or ``num_prepend_vtokens`` does
            not have exactly one entry per prompt.
    """
    if input_text is None:
        raise ValueError(
            "input_text must be a string or a sequence of strings, got None")
    if isinstance(input_text, str):
        # A bare string would otherwise be iterated character by character.
        input_text = [input_text]

    batch_input_ids = []
    for curr_text in input_text:
        if prompt_template is not None:
            curr_text = prompt_template.format(input_text=curr_text)
        input_ids = tokenizer.encode(curr_text,
                                     add_special_tokens=add_special_tokens,
                                     truncation=True,
                                     max_length=max_input_length)
        batch_input_ids.append(input_ids)

    if num_prepend_vtokens:
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(num_prepend_vtokens) != len(batch_input_ids):
            raise ValueError(
                f"num_prepend_vtokens has {len(num_prepend_vtokens)} entries "
                f"but there are {len(batch_input_ids)} prompts")
        # Virtual-token ids start right after the vocabulary size minus the
        # tokenizer's additional special tokens.
        additional_special = tokenizer.special_tokens_map.get(
            'additional_special_tokens', [])
        base_vocab_size = tokenizer.vocab_size - len(additional_special)
        batch_input_ids = [
            list(range(base_vocab_size, base_vocab_size + length)) + ids
            for length, ids in zip(num_prepend_vtokens, batch_input_ids)
        ]

    # Guard against the default model_name=None, which cannot be searched
    # with `in`.
    is_glm = (model_name is not None and 'GLM' in model_name
              and model_version == 'glm')
    if is_glm:
        for ids in batch_input_ids:
            ids.append(tokenizer.sop_token_id)

    return [torch.tensor(ids, dtype=torch.int32) for ids in batch_input_ids]
| 147 | |
| 148 | |
| 149 | def main(args): |