Args: tokenize_strategy: TokenizeStrategy. models: list of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]; if text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required. tokens: list of tokens for text_encoder1 and text_encoder2.
(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
)
| 169 | return hidden_states1, hidden_states2, pool2 |
| 170 | |
def encode_tokens(
    self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
    """
    Encode a pair of SDXL token tensors with both text encoders.

    Args:
        tokenize_strategy: TokenizeStrategy. Only its ``tokenizer1`` and
            ``tokenizer2`` attributes are read here; it is treated as an
            SdxlTokenizeStrategy via a type comment.
        models: either [text_encoder1, text_encoder2] or
            [text_encoder1, text_encoder2, unwrapped text_encoder2].
            The unwrapped text_encoder2 is required when text_encoder2 is
            wrapped by accelerate. With only two entries, None is passed in
            its place.
        tokens: [tokens for text_encoder1, tokens for text_encoder2].

    Returns:
        [hidden_states1, hidden_states2, pool2]: the three values produced by
        ``self._get_hidden_states_sdxl``, repacked as a list.
    """
    # Any length other than two is treated as the three-entry form; the
    # unpacking raises ValueError if it does not actually hold three models.
    if len(models) != 2:
        encoder_1, encoder_2, encoder_2_unwrapped = models
    else:
        encoder_1, encoder_2 = models
        encoder_2_unwrapped = None

    ids_1, ids_2 = tokens

    sdxl_strategy = tokenize_strategy  # type: SdxlTokenizeStrategy
    tok_1 = sdxl_strategy.tokenizer1
    tok_2 = sdxl_strategy.tokenizer2

    hidden_1, hidden_2, pooled_2 = self._get_hidden_states_sdxl(
        ids_1,
        ids_2,
        tok_1,
        tok_2,
        encoder_1,
        encoder_2,
        encoder_2_unwrapped,
    )
    return [hidden_1, hidden_2, pooled_2]
| 194 | |
| 195 | def encode_tokens_with_weights( |
| 196 | self, |
no test coverage detected