(args)
| 770 | |
| 771 | |
| 772 | def build_neva_engine(args): |
| 773 | # extract NeMo checkpoint |
| 774 | with tarfile.open(args.model_path) as tar: |
| 775 | nemo_config = yaml.safe_load(tar.extractfile("./model_config.yaml")) |
| 776 | try: |
| 777 | # trained without TP |
| 778 | mp0_weights = torch.load(tar.extractfile("./model_weights.ckpt"), |
| 779 | map_location=args.device) |
| 780 | except KeyError: |
| 781 | # trained with TP |
| 782 | mp0_weights = torch.load( |
| 783 | tar.extractfile("./mp_rank_00/model_weights.ckpt"), |
| 784 | map_location=args.device) |
| 785 | |
| 786 | vision_config = nemo_config["mm_cfg"]["vision_encoder"] |
| 787 | |
class VisionEncoderWrapper(torch.nn.Module):
    """Chain a vision encoder and a connector module into one forward pass.

    The encoder is called with ``pixel_values`` and
    ``output_hidden_states=True``; its second-to-last hidden state is
    sliced to drop index 0 along dimension 1 and then passed through the
    connector.
    """

    def __init__(self, encoder, connector):
        """Store the two sub-modules.

        Args:
            encoder: Callable module accepting ``pixel_values`` and
                ``output_hidden_states`` keyword arguments and returning an
                object with a ``hidden_states`` sequence.
            connector: Callable module applied to the selected features.
        """
        super().__init__()
        self.encoder = encoder
        self.connector = connector

    def forward(self, images):
        """Return the connector output for ``images``.

        Args:
            images: Input forwarded to the encoder as ``pixel_values``.

        Returns:
            Whatever ``self.connector`` returns for the selected features.
        """
        encoder_out = self.encoder(pixel_values=images,
                                   output_hidden_states=True)
        # Use the penultimate hidden state rather than the final one.
        penultimate = encoder_out.hidden_states[-2]
        # Skip position 0 along dim 1 (presumably the class token of a
        # ViT-style encoder -- TODO confirm against the encoder in use).
        patch_features = penultimate[:, 1:]
        return self.connector(patch_features)
| 802 | |
| 803 | vision_path = vision_config["from_pretrained"] |
| 804 | joined_path = os.path.join(os.path.dirname(args.model_path), |
| 805 | os.path.basename(vision_path)) |
| 806 | if os.path.isdir(joined_path): |
| 807 | vision_path = joined_path |
| 808 | encoder = AutoModel.from_pretrained(vision_path, |
| 809 | dtype=torch.bfloat16, |
| 810 | trust_remote_code=True) |
| 811 | vision_encoder = encoder.vision_model |
| 812 | hf_config = encoder.config |
| 813 | dtype = hf_config.torch_dtype |
| 814 | |
| 815 | # connector |
| 816 | assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu" |
| 817 | vision_connector = torch.nn.Sequential( |
| 818 | torch.nn.Linear(vision_config["hidden_size"], |
| 819 | nemo_config["hidden_size"], |
| 820 | bias=True), torch.nn.GELU(), |
| 821 | torch.nn.Linear(nemo_config["hidden_size"], |
| 822 | nemo_config["hidden_size"], |
| 823 | bias=True)).to(dtype=dtype) |
| 824 | |
| 825 | key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" |
| 826 | for layer in range(0, 3, 2): |
| 827 | vision_connector[layer].load_state_dict({ |
| 828 | 'weight': |
| 829 | mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype), |
no test coverage detected