(args)
| 618 | |
| 619 | |
def build_vila_engine(args):
    """Export the VILA vision tower and projector to ONNX, then build an engine.

    The function loads the checkpoint once, runs the tower's image processor
    on a dummy image to obtain an example input, wraps the vision tower and
    ``mm_projector`` in a single module, exports that module with
    ``export_onnx`` and hands the result to ``build_trt_engine``.

    Args:
        args: Namespace-like object. The attributes read here are
            ``vila_path`` (added to ``sys.path`` so ``llava`` is importable),
            ``model_path``, ``device``, ``output_dir``, ``model_type`` and
            ``max_batch_size``.
    """
    # Note: VILA model is not in public HF model zoo yet. We need to explicitly import from the git repo
    if args.vila_path not in sys.path:
        # Avoid piling up duplicate entries when this is called repeatedly.
        sys.path.append(args.vila_path)
    # The imported names are unused on purpose (hence the noqa); presumably
    # the import registers the VILA classes so AutoModel can resolve them.
    from llava.model import LlavaLlamaConfig, LlavaLlamaModel  # noqa
    from transformers import AutoModel

    # Load the checkpoint a single time; the same instance provides the image
    # processor, the vision tower and the projector.
    model = AutoModel.from_pretrained(
        args.model_path,
        device_map='auto',
    )
    vision_tower = model.get_vision_tower()
    image_processor = vision_tower.image_processor

    # Build an example input for the export from a dummy image.
    raw_image = Image.new('RGB', [10, 10])  # dummy image
    image = image_processor(images=raw_image,
                            return_tensors="pt")['pixel_values']
    if isinstance(image, list):
        # Some processors return a list of per-image tensors; take the first
        # one and restore the leading batch dimension.
        image = image[0].unsqueeze(0)
    image = image.to(args.device, torch.float16)

    class VilaVisionWrapper(torch.nn.Module):
        """Chain the vision tower and the projector into one module."""

        def __init__(self, tower, projector):
            super().__init__()
            self.tower = tower
            self.projector = projector

        def forward(self, image):
            features = self.tower(image)
            return self.projector(features)

    wrapper = VilaVisionWrapper(vision_tower.to(args.device),
                                model.mm_projector.to(args.device))

    onnx_dir = f'{args.output_dir}/onnx'
    export_onnx(wrapper, image, onnx_dir)
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]
        onnx_dir,
        args.output_dir,
        args.max_batch_size)
| 663 | |
| 664 | |
| 665 | def build_nougat_engine(args): |
no test coverage detected