MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / Pipeline

Class Pipeline

examples/mmlu.py:230–332  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

228
229
230class Pipeline:
231
def __init__(self, tokenizer, model, model_name, pad_id, end_id,
             max_attention_window_size, is_enc_dec, hf_model_dir,
             engine_dir):
    """Record the tokenizer, model and generation settings on the instance.

    Every argument except ``hf_model_dir`` is stored under an attribute of
    the same name. ``hf_model_dir`` is only consulted when ``is_enc_dec``
    is truthy, in which case the decoder start token id is read from the
    config returned by ``AutoConfig.from_pretrained(hf_model_dir)``.
    Otherwise ``decoder_start_token_id`` stays ``None``.
    """
    # Model handles and identification.
    self.tokenizer, self.model, self.model_name = tokenizer, model, model_name

    # Special token ids.
    self.pad_id, self.end_id = pad_id, end_id

    self.max_attention_window_size = max_attention_window_size

    # Fixed generation length; __call__ passes it on as the number of
    # new tokens to generate.
    self.output_len = 2

    self.is_enc_dec = is_enc_dec
    self.decoder_start_token_id = None
    self.engine_dir = engine_dir

    # Decoder-only models need nothing further.
    if not self.is_enc_dec:
        return

    # Encoder-decoder models take their decoder start token from the
    # Hugging Face config found in hf_model_dir.
    hf_config = AutoConfig.from_pretrained(hf_model_dir)
    self.decoder_start_token_id = hf_config.decoder_start_token_id
248
249 def __call__(self, prompt):
250 rank = tensorrt_llm.mpi_rank()
251 # Run the model in batch size 1 and beam size 1
252 inputs = self.tokenizer.encode(prompt, return_tensors="pt").squeeze(0)
253 batch_input_ids = [inputs]
254
255 # For multi-choice tasks like MMLU, we don't need to adjust following parameters
256 output_len = self.output_len
257 top_k = 1
258 top_p = 0.0
259
260 input_lengths = [x.size(0) for x in batch_input_ids]
261
262 with torch.no_grad():
263 if isinstance(self.model, nn.Module):
264 # Left padding for HF
265 max_length = max(input_lengths)
266 paddings = [
267 torch.ones(max_length - l, dtype=torch.int32) * self.pad_id
268 for l in input_lengths
269 ]
270 batch_input_ids = [
271 torch.cat([pad, x])
272 for x, pad in zip(batch_input_ids, paddings)
273 ]
274 batch_input_ids = torch.stack(batch_input_ids)
275 batch_input_ids = batch_input_ids.cuda()
276 if self.is_enc_dec:
277 batch_decoder_input_ids = torch.IntTensor(
278 [[self.decoder_start_token_id]]).to('cuda')
279 batch_decoder_input_ids = batch_decoder_input_ids.repeat(
280 (batch_input_ids.shape[0], 1))
281
282 with torch.no_grad():
283 # Use default temperature and top_k
284 outputs = self.model.generate(
285 batch_input_ids,
286 max_new_tokens=output_len,
287 top_k=top_k,

Callers 1

mainFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected