| 637 | image = image.to(args.device, torch.float16) |
| 638 | |
class VilaVisionWrapper(torch.nn.Module):
    """Compose a vision tower and a projector into one ``torch.nn.Module``.

    Calling the wrapper runs ``image`` through ``tower`` and feeds whatever
    the tower returns straight into ``projector``; the projector's output is
    the wrapper's output.
    """

    def __init__(self, tower, projector):
        """Register ``tower`` and ``projector`` as child modules.

        Args:
            tower: callable module applied first to the input image.
            projector: callable module applied to the tower's output.
        """
        super().__init__()
        # Attribute names and assignment order are kept fixed: they determine
        # the submodule names (and hence state-dict key prefixes) and the
        # order reported by ``named_children()``.
        self.tower = tower
        self.projector = projector

    def forward(self, image):
        """Return ``projector(tower(image))``."""
        return self.projector(self.tower(image))
| 649 | |
| 650 | model = AutoModel.from_pretrained( |
| 651 | args.model_path, |