(args)
| 432 | |
| 433 | |
def build_pix2struct_engine(args):
    """Export the Pix2Struct vision encoder to ONNX, then build a TensorRT engine.

    A blank 10x10 RGB image is run through the checkpoint's ``AutoProcessor``
    to obtain a sample ``flattened_patches`` tensor, cast to float16 on
    ``args.device``. That sample is used for two things:

    * it is the example input for ``export_onnx`` when tracing a small wrapper
      around ``model.encoder`` (embeddings -> encoder -> layernorm);
    * its patch count and hidden size (``shape[1]`` and ``shape[2]``) are the
      fixed input sizes handed to ``build_trt_engine``.

    Only the batch axis of the ONNX input is declared dynamic.

    Args:
        args: Parsed command-line namespace. This function reads
            ``model_path``, ``device``, ``output_dir``, ``model_type`` and
            ``max_batch_size`` from it.

    Returns:
        None. The ONNX directory ``{args.output_dir}/onnx`` is passed to both
        ``export_onnx`` and ``build_trt_engine``, and ``args.output_dir`` is
        passed to ``build_trt_engine`` as the engine directory.
    """
    half = torch.float16

    processor = AutoProcessor.from_pretrained(args.model_path)
    blank = Image.new('RGB', [10, 10])  # placeholder; only its processed shape matters
    sample = processor(text="dummy", images=blank, return_tensors="pt")
    patches = sample['flattened_patches'].to(args.device, half)

    class pix2structVisionWrapper(torch.nn.Module):
        """Expose the Pix2Struct encoder as a single patches -> features module."""

        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def forward(self, image):
            # A patch row whose absolute values sum to zero is treated as
            # padding and masked out of attention.
            patch_mask = image.abs().sum(dim=-1) != 0
            embedded = self.encoder.embeddings(image)
            encoded = self.encoder.encoder(embedded, attention_mask=patch_mask)
            # Element 0 of the encoder output is passed through the final
            # layernorm and returned.
            return self.encoder.layernorm(encoded[0])

    hf_model = Pix2StructForConditionalGeneration.from_pretrained(
        args.model_path, dtype=half)
    # Only the encoder is moved to the target device and wrapped for export.
    vision = pix2structVisionWrapper(hf_model.encoder.to(args.device))

    # Input layout: (batch, num_patches, hidden_dim); the mask computed inside
    # the wrapper is (batch, num_patches). The patch count varies with image
    # size but stays within a narrow range, so a fixed patch count plus an
    # attention mask is used instead of a dynamic patch axis. Only the batch
    # axis is dynamic.
    onnx_dir = f'{args.output_dir}/onnx'
    export_onnx(vision, (patches, ),
                onnx_dir,
                input_names=['input'],
                dynamic_axes={'input': {0: 'batch'}})

    num_patches, hidden_dim = patches.shape[1:3]
    # NOTE(review): the engine is requested in bfloat16, although the ONNX
    # graph above was traced in float16. This is kept unchanged (possibly to
    # avoid fp16 overflow in this encoder). Confirm against build_trt_engine
    # before changing it.
    build_trt_engine(
        args.model_type,
        [num_patches, hidden_dim],
        onnx_dir,
        args.output_dir,
        args.max_batch_size,
        dtype=torch.bfloat16)
| 478 | |
| 479 | |
| 480 | def build_llava_engine(args): |
no test coverage detected