| 120 | |
| 121 | |
| 122 | def parse_input(tokenizer, |
| 123 | input_text=None, |
| 124 | prompt_template=None, |
| 125 | input_file=None, |
| 126 | add_special_tokens=True, |
| 127 | max_input_length=923, |
| 128 | pad_id=None, |
| 129 | num_prepend_vtokens=[], |
| 130 | model_name=None, |
| 131 | model_version=None): |
| 132 | if pad_id is None: |
| 133 | pad_id = tokenizer.pad_token_id |
| 134 | |
| 135 | batch_input_ids = [] |
| 136 | if input_file is None: |
| 137 | if 'whisper' in model_name.lower(): |
| 138 | batch_input_ids.append(tokenizer.prefix_tokens) |
| 139 | else: |
| 140 | for curr_text in input_text: |
| 141 | if prompt_template is not None: |
| 142 | curr_text = prompt_template.format(input_text=curr_text) |
| 143 | input_ids = tokenizer.encode( |
| 144 | curr_text, |
| 145 | add_special_tokens=add_special_tokens, |
| 146 | truncation=True, |
| 147 | max_length=max_input_length) |
| 148 | batch_input_ids.append(input_ids) |
| 149 | else: |
| 150 | if input_file.endswith('.csv'): |
| 151 | with open(input_file, 'r') as csv_file: |
| 152 | csv_reader = csv.reader(csv_file, delimiter=',') |
| 153 | for line in csv_reader: |
| 154 | input_ids = np.array(line, dtype='int32') |
| 155 | batch_input_ids.append(input_ids[-max_input_length:]) |
| 156 | elif input_file.endswith('.npy'): |
| 157 | inputs = np.load(input_file) |
| 158 | for row in inputs: |
| 159 | input_ids = row[row != pad_id] |
| 160 | batch_input_ids.append(input_ids[-max_input_length:]) |
| 161 | |
| 162 | elif input_file.endswith('.txt'): |
| 163 | with open(input_file, 'r', encoding='utf-8', |
| 164 | errors='replace') as txt_file: |
| 165 | input_text = txt_file.readlines() |
| 166 | batch_input_ids = tokenizer( |
| 167 | input_text, |
| 168 | add_special_tokens=add_special_tokens, |
| 169 | truncation=True, |
| 170 | max_length=max_input_length)["input_ids"] |
| 171 | else: |
| 172 | print('Input file format not supported.') |
| 173 | raise SystemExit |
| 174 | |
| 175 | if num_prepend_vtokens: |
| 176 | assert len(num_prepend_vtokens) == len(batch_input_ids) |
| 177 | base_vocab_size = tokenizer.vocab_size |
| 178 | for i, length in enumerate(num_prepend_vtokens): |
| 179 | batch_input_ids[i] = list( |