MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / build_pix2struct_engine

Function build_pix2struct_engine

tensorrt_llm/tools/multimodal_builder.py:434–477  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

432
433
def build_pix2struct_engine(args):
    """Export the Pix2Struct vision encoder to ONNX and build a TensorRT engine.

    A dummy 10x10 RGB image is run through the model's processor to obtain a
    sample ``flattened_patches`` tensor. That tensor is used both as the
    example input for the ONNX export and as the source of the fixed
    per-sample input shape handed to ``build_trt_engine``.

    Args:
        args: Options object. The attributes read here are ``model_path``,
            ``device``, ``output_dir``, ``model_type`` and ``max_batch_size``.
    """

    class pix2structVisionWrapper(torch.nn.Module):
        """Expose the encoder as a module that takes only the patch tensor."""

        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def forward(self, image):
            # The mask is derived from the input itself: a position is kept
            # only when the absolute values along its last axis do not sum
            # to zero.
            attention_mask = (image.abs().sum(dim=-1) != 0)
            vision_x = self.encoder.embeddings(image)
            img_features = self.encoder.encoder(vision_x,
                                                attention_mask=attention_mask)
            # Only the first element of the encoder result is normalised
            # and returned.
            img_features = self.encoder.layernorm(img_features[0])
            return img_features

    export_dtype = torch.float16
    onnx_dir = f'{args.output_dir}/onnx'

    # Produce a sample input tensor from a placeholder image.
    hf_processor = AutoProcessor.from_pretrained(args.model_path)
    placeholder = Image.new('RGB', [10, 10])
    processed = hf_processor(text="dummy",
                             images=placeholder,
                             return_tensors="pt")
    image = processed['flattened_patches'].to(args.device, export_dtype)

    hf_model = Pix2StructForConditionalGeneration.from_pretrained(
        args.model_path, dtype=export_dtype)
    vision_module = pix2structVisionWrapper(hf_model.encoder.to(args.device))

    # Input shape: batch size, number of patches, hidden dimension.
    # Attention mask shape: batch size, number of patches.
    # The number of image patches can vary with the image size, but it
    # typically falls within a relatively narrow range. To improve
    # performance, the patch axis is not exported as a dynamic axis; a fixed
    # number of patches is used together with an attention mask, and only the
    # batch axis is left dynamic.
    export_onnx(vision_module, (image, ),
                onnx_dir,
                input_names=['input'],
                dynamic_axes={'input': {
                    0: 'batch'
                }})

    num_patches, hidden_dim = image.shape[1], image.shape[2]
    # NOTE(review): the model and sample input above are float16, while the
    # engine build is requested with bfloat16 - confirm this is intentional.
    build_trt_engine(args.model_type, [num_patches, hidden_dim],
                     onnx_dir,
                     args.output_dir,
                     args.max_batch_size,
                     dtype=torch.bfloat16)
478
479
480def build_llava_engine(args):

Callers 1

buildMethod · 0.85

Calls 5

export_onnxFunction · 0.85
build_trt_engineFunction · 0.85
from_pretrainedMethod · 0.45
toMethod · 0.45

Tested by

no test coverage detected