(self, text, normalize: bool = False)
| 177 | return F.normalize(features, dim=-1) if normalize else features |
| 178 | |
def encode_text(self, text, normalize: bool = False):
    """Encode a batch of token sequences into one projected vector each.

    Args:
        text: token ids indexed as [batch_size, n_ctx].
        normalize: when True, the result is passed through ``F.normalize``
            along the last dimension before being returned.

    Returns:
        For every sequence, the final-layer features taken at the position
        of its largest token id, multiplied by ``self.text_projection``.
    """
    dtype = self.transformer.get_cast_dtype()

    # [batch_size, n_ctx, d_model]
    hidden = self.token_embedding(text).to(dtype)
    hidden = hidden + self.positional_embedding.to(dtype)

    # The transformer is fed sequence-first input: NLD -> LND going in,
    # LND -> NLD coming back out.
    hidden = self.transformer(hidden.permute(1, 0, 2), attn_mask=self.attn_mask)
    hidden = self.ln_final(hidden.permute(1, 0, 2))  # [batch_size, n_ctx, transformer.width]

    # Take features from the eot embedding (eot_token is the highest
    # number in each sequence, so argmax over the ids locates it).
    batch_index = torch.arange(hidden.shape[0])
    eot_index = text.argmax(dim=-1)
    features = hidden[batch_index, eot_index] @ self.text_projection

    if normalize:
        return F.normalize(features, dim=-1)
    return features
| 192 | |
| 193 | def forward( |
| 194 | self, |
no test coverage detected