| 50 | return vision_tower |
| 51 | |
def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
                              pretrain_mm_mlp_adapter=None, fsdp=None):
    """Attach the CLIP vision tower and the multimodal projector to the model.

    Records the vision settings on ``self.config``, loads the vision tower
    (or reuses the one already attached), freezes its parameters, creates the
    ``mm_projector`` linear layer if it does not exist yet, and optionally
    loads pretrained projector weights.

    Args:
        vision_tower: Name or path handed to
            ``CLIPImageProcessor.from_pretrained`` and, when no tower is
            attached yet, to ``CLIPVisionModel.from_pretrained``.
        mm_vision_select_layer: Value stored as
            ``config.mm_vision_select_layer``.
        pretrain_mm_mlp_adapter: Optional path to a checkpoint readable by
            ``torch.load`` that holds the projector weights.
        fsdp: Optional FSDP setting; when it is non-empty the tower is
            stored wrapped in a one-element list.

    Returns:
        A dict with the keys ``image_processor``, ``image_token_len`` (number
        of image patches) and ``vision_config``.
    """
    self.config.mm_vision_tower = vision_tower

    image_processor = CLIPImageProcessor.from_pretrained(vision_tower)

    if not hasattr(self, 'vision_tower'):
        vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
    else:
        # This method stores the tower either bare or wrapped in a list (see
        # the fsdp branch below), so a repeated call must accept both forms.
        # Indexing unconditionally raised TypeError for the bare form.
        existing_tower = self.vision_tower
        try:
            vision_tower = existing_tower[0]
        except TypeError:
            vision_tower = existing_tower
    # The vision tower is kept frozen.
    vision_tower.requires_grad_(False)

    if fsdp is not None and len(fsdp) > 0:
        # NOTE(review): presumably the list keeps the tower from being
        # registered as a submodule so FSDP does not wrap it -- confirm.
        self.vision_tower = [vision_tower]
    else:
        self.vision_tower = vision_tower

    vision_config = vision_tower.config
    # Patches per image side (integer division), squared.
    num_patches = (vision_config.image_size // vision_config.patch_size) ** 2

    self.config.use_mm_proj = True
    self.config.mm_hidden_size = vision_config.hidden_size
    self.config.mm_vision_select_layer = mm_vision_select_layer

    if not hasattr(self, 'mm_projector'):
        # Linear map from the vision hidden size to the model hidden size.
        self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)

    if pretrain_mm_mlp_adapter is not None:
        # NOTE: torch.load unpickles the file; only load adapter checkpoints
        # that come from a trusted source.
        mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
        # Keep only the last dotted component of each key (e.g. 'weight',
        # 'bias') so the names match the bare nn.Linear state dict.
        self.mm_projector.load_state_dict(
            {k.split('.')[-1]: v for k, v in mm_projector_weights.items()}
        )

    return dict(
        image_processor=image_processor,
        image_token_len=num_patches,
        vision_config=vision_config,
    )
| 88 | |
| 89 | def forward( |
| 90 | self, |