MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / build_phi_engine

Function build_phi_engine

tensorrt_llm/tools/multimodal_builder.py:971–1025  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

969
970
def build_phi_engine(args):
    """Write the runner config for the Phi-3 vision encoder without building a TRT engine.

    No TensorRT engine is produced for this model: a warning is logged
    explaining that the runner falls back to the PyTorch vision encoder, and
    only the ``config.json`` needed by the model runner is written.

    Args:
        args: Builder arguments. Only ``args.output_dir`` is read; it is
            concatenated with ``"/config.json"`` to form the output path.

    Returns:
        None.
    """
    logger.warning(
        "Skipping TRT engine build for Phi-3 vision encoder. MultimodalModelRunner will use PyTorch vision encoder. Flash/SDPA attention in CLIP encoder is not compatible with torch.onnx.export and eager attention is unstable in PyTorch."
    )

    # Dump config.json needed by model runner
    config_args = {
        "builder_config": {
            "precision": torch_dtype_to_str(torch.float16),
            "model_type": "phi-3-vision",
        }
    }
    to_json_file(config_args, args.output_dir + "/config.json")

    # NOTE: the ONNX-export / TensorRT-build path that used to follow an
    # unconditional ``return`` here could never execute and has been removed.
    # Restore it from version control if the export limitation described in
    # the warning above is lifted.
1026
1027
1028def build_phi4mm_engine(args):

Callers 1

buildMethod · 0.85

Calls 10

torch_dtype_to_strFunction · 0.90
to_json_fileFunction · 0.90
CLIPVisionModelClass · 0.85
Phi3VisionWrapperClass · 0.85
export_onnxFunction · 0.85
build_trt_engineFunction · 0.85
flattenMethod · 0.80
warningMethod · 0.45
from_pretrainedMethod · 0.45
toMethod · 0.45

Tested by

no test coverage detected