(self, text: Union[str, List[str]])
| 34 | self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g |
| 35 | |
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
    """Tokenize one prompt or a batch of prompts with all three tokenizers.

    A bare string is wrapped into a single-element list so each tokenizer
    always receives a list of prompts.

    Args:
        text: A single prompt, or a list of prompts.

    Returns:
        A six-element list laid out as
        ``[clip_l ids, clip_g ids, t5xxl ids,
        clip_l mask, clip_g mask, t5xxl mask]``,
        where the ids come from each tokenizer's ``"input_ids"`` entry and
        the masks from its ``"attention_mask"`` entry.
    """
    prompts = text if not isinstance(text, str) else [text]

    # Options common to every tokenizer call: pad/truncate to the requested
    # length and ask for PyTorch tensors.
    shared_kwargs = dict(padding="max_length", truncation=True, return_tensors="pt")

    # Both CLIP tokenizers use a fixed length of 77; the T5 length is
    # configurable through ``self.t5xxl_max_length``.
    encodings = [
        self.clip_l(prompts, max_length=77, **shared_kwargs),
        self.clip_g(prompts, max_length=77, **shared_kwargs),
        self.t5xxl(prompts, max_length=self.t5xxl_max_length, **shared_kwargs),
    ]

    attention_masks = [encoded["attention_mask"] for encoded in encodings]
    token_ids = [encoded["input_ids"] for encoded in encodings]

    # Token ids first, attention masks second, each in (l, g, t5) order.
    return token_ids + attention_masks
| 51 | |
| 52 | |
| 53 | class Sd3TextEncodingStrategy(TextEncodingStrategy): |
no outgoing calls
no test coverage detected