MCPcopy Index your code
hub / github.com/XPixelGroup/DiffBIR / initialize_vision_modules

Method initialize_vision_modules

llava/model/llava_arch.py:49–97  ·  view source on GitHub ↗
(self, model_args, fsdp=None)

Source retrieved from the content-addressed store (hash-verified)

47 return vision_tower
48
def initialize_vision_modules(self, model_args, fsdp=None):
    """Attach (or re-initialize) the vision tower and the multimodal projector.

    Copies the vision-related options from ``model_args`` onto ``self.config``.
    Builds the vision tower and the projector when they are not attached yet.
    When ``model_args.pretrain_mm_mlp_adapter`` is set, loads pretrained
    projector weights from that checkpoint.

    Args:
        model_args: Object exposing ``vision_tower``, ``mm_vision_select_layer``,
            ``mm_vision_select_feature``, ``pretrain_mm_mlp_adapter`` and
            ``mm_patch_merge_type``. It may also expose ``mm_projector_type``,
            which falls back to ``'linear'`` when absent.
        fsdp: Optional sequence. When non-empty, the vision tower is stored
            wrapped in a one-element list instead of as a direct attribute.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint's
            projector weights do not match ``self.mm_projector``.
    """
    vision_tower = model_args.vision_tower
    mm_vision_select_layer = model_args.mm_vision_select_layer
    mm_vision_select_feature = model_args.mm_vision_select_feature
    pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
    mm_patch_merge_type = model_args.mm_patch_merge_type

    self.config.mm_vision_tower = vision_tower

    # Same truth value as `fsdp is not None and len(fsdp) > 0` for sequences.
    # NOTE(review): the tower is kept inside a plain list under FSDP, presumably
    # so nn.Module does not register it as a submodule (lists are not
    # registered) and FSDP leaves it unwrapped -- confirm.
    use_fsdp = bool(fsdp)

    if self.get_vision_tower() is None:
        vision_tower = build_vision_tower(model_args)
        self.vision_tower = [vision_tower] if use_fsdp else vision_tower
    else:
        vision_tower = self.vision_tower[0] if use_fsdp else self.vision_tower
        # Only an already-attached tower is (re)loaded here. A freshly built
        # tower is presumably loaded by build_vision_tower -- TODO confirm.
        vision_tower.load_model()

    self.config.use_mm_proj = True
    self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
    self.config.mm_hidden_size = vision_tower.hidden_size
    self.config.mm_vision_select_layer = mm_vision_select_layer
    self.config.mm_vision_select_feature = mm_vision_select_feature
    self.config.mm_patch_merge_type = mm_patch_merge_type

    if getattr(self, 'mm_projector', None) is None:
        self.mm_projector = build_vision_projector(self.config)

        if 'unpad' in mm_patch_merge_type:
            # Learnable `image_newline` vector of length hidden_size:
            # standard-normal samples scaled by 1/sqrt(hidden_size).
            embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
            self.image_newline = nn.Parameter(
                torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
            )
    else:
        # In case it is frozen by LoRA
        for p in self.mm_projector.parameters():
            p.requires_grad = True

    if pretrain_mm_mlp_adapter is not None:
        # SECURITY NOTE(review): torch.load may unpickle arbitrary objects;
        # only point pretrain_mm_mlp_adapter at trusted checkpoints.
        mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

        def get_w(weights, keyword):
            # Keep only keys containing "<keyword>." and strip everything up to and
            # including its first occurrence. For example,
            # "model.mm_projector.0.weight" becomes "0.weight".
            # Filtering on the dotted prefix (not the bare keyword) skips keys
            # such as "model.mm_projector_type". Those previously raised
            # IndexError on split(...)[1].
            prefix = keyword + '.'
            return {k.split(prefix, 1)[1]: v for k, v in weights.items() if prefix in k}

        self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
98
99
100def unpad_image(tensor, original_size):

Callers 1

train (Function) · 0.80

Calls 4

get_vision_tower (Method) · 0.95
build_vision_tower (Function) · 0.85
build_vision_projector (Function) · 0.85
load_model (Method) · 0.45

Tested by

no test coverage detected