MCPcopy
hub / github.com/XPixelGroup/DiffBIR / encode_text

Method encode_text

diffbir/model/open_clip/model.py:179–191  ·  view source on GitHub ↗
(self, text, normalize: bool = False)

Source from the content-addressed store, hash-verified

177 return F.normalize(features, dim=-1) if normalize else features
178
def encode_text(self, text, normalize: bool = False):
    """Encode a batch of tokenized text into projected text features.

    The tokens are embedded, offset by the positional embedding, passed
    through the transformer and the final layer norm, and then one position
    per sequence is selected and multiplied by ``self.text_projection``.

    Args:
        text: Token tensor; per the shape comments below, indexed as
            ``[batch_size, n_ctx]``. The position picked for each sequence
            is ``text.argmax(dim=-1)``, i.e. this relies on the eot token
            having the highest value in each sequence.
        normalize: When true, the result is passed through
            ``F.normalize(..., dim=-1)`` before being returned.

    Returns:
        The projected features for the selected position of every sequence,
        normalized along the last dimension if ``normalize`` is set.
    """
    dtype = self.transformer.get_cast_dtype()

    # [batch_size, n_ctx, d_model]
    hidden = self.token_embedding(text).to(dtype)
    hidden = hidden + self.positional_embedding.to(dtype)

    # The transformer is called on LND layout, so swap the first two axes
    # on the way in and restore them on the way out.
    hidden = hidden.permute(1, 0, 2)  # NLD -> LND
    hidden = self.transformer(hidden, attn_mask=self.attn_mask)
    hidden = hidden.permute(1, 0, 2)  # LND -> NLD

    # [batch_size, n_ctx, transformer.width]
    hidden = self.ln_final(hidden)

    # Take features from the eot embedding (eot_token is the highest number
    # in each sequence): one (row, column) pair per batch element.
    batch_index = torch.arange(hidden.shape[0])
    eot_index = text.argmax(dim=-1)
    features = hidden[batch_index, eot_index] @ self.text_projection

    if normalize:
        return F.normalize(features, dim=-1)
    return features
192
193 def forward(
194 self,

Callers 3

forwardMethod · 0.95

Calls 1

get_cast_dtypeMethod · 0.80

Tested by

no test coverage detected