MCPcopy
hub / github.com/NVlabs/RADIO / CLIPWrapper

Class CLIPWrapper

examples/common/model_loader.py:68–126  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

66
67
class CLIPWrapper(nn.Module):
    """Wrap a CLIP-style model so its vision tower yields ``RadioOutput`` results.

    The wrapped model is stored as ``self.inner``. Calling the wrapper runs
    ``self.inner.visual`` and packages the result via :meth:`_wrap_output`;
    :meth:`encode_image` and :meth:`encode_text` give CLIP-like helpers on top.

    NOTE(review): the attribute names used here (``visual``, ``proj``,
    ``output_tokens``, ``patch_size``, ``encode_text``) look like the
    open_clip model layout -- confirm against the models passed by callers.
    """

    def __init__(self, clip_model: nn.Module, tokenizer, adaptor_name: str, clip_mode: bool = False):
        """Store the model and, outside clip mode, neutralise its visual projection.

        Args:
            clip_model: Model to wrap. If it has a ``visual`` attribute, that
                sub-model is used as the vision encoder.
            tokenizer: Stored unchanged as ``self.tokenizer``.
            adaptor_name: When truthy, ``forward`` returns a dict keyed by
                ``'backbone'`` and this name instead of a bare ``RadioOutput``.
            clip_mode: When ``False`` (default), ``visual.proj`` is replaced by
                an identity matrix so the projection becomes a no-op.
        """
        super().__init__()
        self.inner = clip_model
        has_visual = hasattr(clip_model, 'visual')
        if has_visual:
            # Ask the vision tower to also return per-token outputs
            # (presumably patch features -- see the tuple handling in forward()).
            clip_model.visual.output_tokens = True
        self.tokenizer = tokenizer
        self.adaptor_name = adaptor_name

        if not clip_mode and has_visual:
            visual = clip_model.visual
            # ``proj`` may exist but be None; reading ``.shape`` on it would
            # raise, and with no projection there is nothing to neutralise.
            proj = getattr(visual, 'proj', None)
            if proj is not None:
                identity = torch.eye(proj.shape[0], dtype=proj.dtype, device=proj.device)
                visual.proj = nn.Parameter(identity)

    @property
    def patch_size(self):
        """First element of the vision encoder's ``patch_size``."""
        return self.inner.visual.patch_size[0]

    @property
    def vision_encoder(self):
        """The wrapped model's ``visual`` sub-model."""
        return self.inner.visual

    def forward(self, *args, **kwargs):
        """Run the vision encoder and wrap its output.

        All arguments are forwarded to ``self.inner.visual``. A tuple/list
        result is unpacked as ``(token, features)``; any other result is
        treated as the token alone, with ``features`` set to ``None``.
        """
        enc = self.inner.visual(*args, **kwargs)

        if isinstance(enc, (tuple, list)):
            token, features = enc
        else:
            token, features = enc, None

        return self._wrap_output(token, features)

    def _wrap_output(self, token, features):
        """Build a ``RadioOutput``; wrap it in a dict when an adaptor name is set.

        With a truthy ``adaptor_name`` the same ``RadioOutput`` object is
        returned under both ``'backbone'`` and ``adaptor_name``.
        """
        op = RadioOutput(token, features)

        if self.adaptor_name:
            return {
                'backbone': op,
                self.adaptor_name: op,
            }
        return op

    def encode_image(self, image, normalize: bool = False):
        """Return the summary token for ``image``, optionally L2-normalised.

        Args:
            image: Input forwarded to :meth:`forward`.
            normalize: When true, normalise the token along its last dim.
        """
        out = self(image)

        if isinstance(out, dict):
            # forward() returns a dict when adaptor_name is set. Unpacking the
            # dict itself would yield its *keys*; every entry holds the same
            # output object, so the 'backbone' entry is used.
            out = out['backbone']

        token, _ = out

        if normalize:
            token = F.normalize(token, dim=-1)

        return token

    def encode_text(self, text, normalize: bool = False, **kwargs):
        """Encode ``text`` with the wrapped model, optionally L2-normalised.

        ``normalize`` is first passed through to ``inner.encode_text``. If that
        raises ``TypeError`` (e.g. the keyword is not accepted), the text is
        encoded without it and normalised here. Extra ``**kwargs`` are accepted
        for call compatibility and ignored.
        """
        try:
            return self.inner.encode_text(text, normalize=normalize)
        except TypeError:
            ret = self.inner.encode_text(text)
            if normalize:
                ret = F.normalize(ret, dim=-1)
            # Return the fallback result instead of falling through to None.
            return ret

Callers 1

load_modelFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…