MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / build_vila_engine

Function build_vila_engine

tensorrt_llm/tools/multimodal_builder.py:620–662  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

618
619
def build_vila_engine(args):
    """Export the VILA vision tower + multimodal projector to ONNX and build a TRT engine.

    VILA is not in the public Hugging Face model zoo, so its model classes are
    imported from a local checkout of the VILA git repository.

    Args:
        args: Parsed builder arguments. This function reads ``vila_path``
            (appended to ``sys.path`` so the ``llava`` package is importable),
            ``model_path``, ``device``, ``output_dir``, ``model_type`` and
            ``max_batch_size``.
    """
    # Note: VILA model is not in public HF model zoo yet. We need to explicitly
    # import from the git repo.
    if args.vila_path not in sys.path:
        # Guard so repeated builds do not keep appending duplicate entries.
        sys.path.append(args.vila_path)
    # Importing llava registers the VILA classes; the names themselves are unused.
    from llava.model import LlavaLlamaConfig, LlavaLlamaModel  # noqa
    from transformers import AutoModel

    # Load the checkpoint exactly once. The same model instance supplies the
    # image processor, the vision tower and the projector; a second
    # from_pretrained() call would only duplicate load time and memory.
    model = AutoModel.from_pretrained(
        args.model_path,
        device_map='auto',
    )
    vision_tower = model.get_vision_tower()

    # Push a dummy image through the processor to obtain an example input
    # with the shape the processor produces; it is used for ONNX tracing and
    # to size the TRT engine input.
    raw_image = Image.new('RGB', [10, 10])
    image = vision_tower.image_processor(images=raw_image,
                                         return_tensors="pt")['pixel_values']
    if isinstance(image, list):
        # The processor may return a list of per-image tensors; take the
        # first and restore a leading batch dimension.
        image = image[0].unsqueeze(0)
    image = image.to(args.device, torch.float16)

    class VilaVisionWrapper(torch.nn.Module):
        """Chain the vision tower and the projector into one exportable module."""

        def __init__(self, tower, projector):
            super().__init__()
            self.tower = tower
            self.projector = projector

        def forward(self, image):
            features = self.tower(image)
            return self.projector(features)

    wrapper = VilaVisionWrapper(vision_tower.to(args.device),
                                model.mm_projector.to(args.device))
    export_onnx(wrapper, image, f'{args.output_dir}/onnx')
    build_trt_engine(
        args.model_type,
        # Per-image input dims, excluding the batch dim (original note: [3, H, W]).
        [image.shape[1], image.shape[2], image.shape[3]],
        f'{args.output_dir}/onnx',
        args.output_dir,
        args.max_batch_size)
663
664
665def build_nougat_engine(args):

Callers 1

build · Method · 0.85

Calls 7

VilaVisionWrapper · Class · 0.85
export_onnx · Function · 0.85
build_trt_engine · Function · 0.85
unsqueeze · Method · 0.80
append · Method · 0.45
from_pretrained · Method · 0.45
to · Method · 0.45

Tested by

no test coverage detected