MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / build_neva_engine

Function build_neva_engine

tensorrt_llm/tools/multimodal_builder.py:772–848  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

770
771
772def build_neva_engine(args):
773 # extract NeMo checkpoint
774 with tarfile.open(args.model_path) as tar:
775 nemo_config = yaml.safe_load(tar.extractfile("./model_config.yaml"))
776 try:
777 # trained without TP
778 mp0_weights = torch.load(tar.extractfile("./model_weights.ckpt"),
779 map_location=args.device)
780 except KeyError:
781 # trained with TP
782 mp0_weights = torch.load(
783 tar.extractfile("./mp_rank_00/model_weights.ckpt"),
784 map_location=args.device)
785
786 vision_config = nemo_config["mm_cfg"]["vision_encoder"]
787
788 class VisionEncoderWrapper(torch.nn.Module):
789
790 def __init__(self, encoder, connector):
791 super().__init__()
792 self.encoder = encoder
793 self.connector = connector
794
795 def forward(self, images):
796 vision_x = self.encoder(pixel_values=images,
797 output_hidden_states=True)
798 vision_x = vision_x.hidden_states[-2]
799 vision_x = vision_x[:, 1:]
800 vision_x = self.connector(vision_x)
801 return vision_x
802
803 vision_path = vision_config["from_pretrained"]
804 joined_path = os.path.join(os.path.dirname(args.model_path),
805 os.path.basename(vision_path))
806 if os.path.isdir(joined_path):
807 vision_path = joined_path
808 encoder = AutoModel.from_pretrained(vision_path,
809 dtype=torch.bfloat16,
810 trust_remote_code=True)
811 vision_encoder = encoder.vision_model
812 hf_config = encoder.config
813 dtype = hf_config.torch_dtype
814
815 # connector
816 assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu"
817 vision_connector = torch.nn.Sequential(
818 torch.nn.Linear(vision_config["hidden_size"],
819 nemo_config["hidden_size"],
820 bias=True), torch.nn.GELU(),
821 torch.nn.Linear(nemo_config["hidden_size"],
822 nemo_config["hidden_size"],
823 bias=True)).to(dtype=dtype)
824
825 key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
826 for layer in range(0, 3, 2):
827 vision_connector[layer].load_state_dict({
828 'weight':
829 mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),

Callers 1

build · Method · 0.85

Calls 7

export_onnx · Function · 0.85
build_trt_engine · Function · 0.85
load · Method · 0.45
from_pretrained · Method · 0.45
to · Method · 0.45
empty · Method · 0.45

Tested by

no test coverage detected