MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / parse_input

Function parse_input

examples/run.py:122–195  ·  view source on GitHub ↗
(tokenizer,
                input_text=None,
                prompt_template=None,
                input_file=None,
                add_special_tokens=True,
                max_input_length=923,
                pad_id=None,
                num_prepend_vtokens=[],
                model_name=None,
                model_version=None)

Source from the content-addressed store, hash-verified

120
121
122def parse_input(tokenizer,
123 input_text=None,
124 prompt_template=None,
125 input_file=None,
126 add_special_tokens=True,
127 max_input_length=923,
128 pad_id=None,
129 num_prepend_vtokens=[],
130 model_name=None,
131 model_version=None):
132 if pad_id is None:
133 pad_id = tokenizer.pad_token_id
134
135 batch_input_ids = []
136 if input_file is None:
137 if 'whisper' in model_name.lower():
138 batch_input_ids.append(tokenizer.prefix_tokens)
139 else:
140 for curr_text in input_text:
141 if prompt_template is not None:
142 curr_text = prompt_template.format(input_text=curr_text)
143 input_ids = tokenizer.encode(
144 curr_text,
145 add_special_tokens=add_special_tokens,
146 truncation=True,
147 max_length=max_input_length)
148 batch_input_ids.append(input_ids)
149 else:
150 if input_file.endswith('.csv'):
151 with open(input_file, 'r') as csv_file:
152 csv_reader = csv.reader(csv_file, delimiter=',')
153 for line in csv_reader:
154 input_ids = np.array(line, dtype='int32')
155 batch_input_ids.append(input_ids[-max_input_length:])
156 elif input_file.endswith('.npy'):
157 inputs = np.load(input_file)
158 for row in inputs:
159 input_ids = row[row != pad_id]
160 batch_input_ids.append(input_ids[-max_input_length:])
161
162 elif input_file.endswith('.txt'):
163 with open(input_file, 'r', encoding='utf-8',
164 errors='replace') as txt_file:
165 input_text = txt_file.readlines()
166 batch_input_ids = tokenizer(
167 input_text,
168 add_special_tokens=add_special_tokens,
169 truncation=True,
170 max_length=max_input_length)["input_ids"]
171 else:
172 print('Input file format not supported.')
173 raise SystemExit
174
175 if num_prepend_vtokens:
176 assert len(num_prepend_vtokens) == len(batch_input_ids)
177 base_vocab_size = tokenizer.vocab_size
178 for i, length in enumerate(num_prepend_vtokens):
179 batch_input_ids[i] = list(

Callers 1

main · Function · 0.70

Calls 4

append · Method · 0.45
encode · Method · 0.45
load · Method · 0.45
debug · Method · 0.45

Tested by

no test coverage detected