(args)
| 969 | |
| 970 | |
def build_phi_engine(args):
    """Write the runner config for the Phi-3 vision encoder without building a TRT engine.

    No TensorRT engine is built for this encoder: flash/SDPA attention in its
    CLIP encoder is not compatible with ``torch.onnx.export``, and eager
    attention is unstable in PyTorch. ``MultimodalModelRunner`` therefore uses
    the PyTorch vision encoder. This function only logs a warning saying so
    and writes the ``config.json`` the model runner needs.

    Args:
        args: Builder arguments. Only ``args.output_dir`` is read here.

    Note:
        An ONNX-export / TRT-build path used to follow an unconditional early
        ``return`` in this function. That path was unreachable and has been
        removed. Recover it from version control if the attention
        incompatibility is ever resolved.
    """
    logger.warning(
        "Skipping TRT engine build for Phi-3 vision encoder. MultimodalModelRunner will use PyTorch vision encoder. Flash/SDPA attention in CLIP encoder is not compatible with torch.onnx.export and eager attention is unstable in PyTorch."
    )

    # Dump the config.json needed by the model runner. It is still required
    # even though no engine file is produced alongside it.
    config_args = {
        "builder_config": {
            "precision": torch_dtype_to_str(torch.float16),
            "model_type": "phi-3-vision",
        }
    }
    to_json_file(config_args, args.output_dir + "/config.json")
| 1026 | |
| 1027 | |
| 1028 | def build_phi4mm_engine(args): |
no test coverage detected