Download and save T5Gemma text encoder and tokenizer.
(output_path: str)
| 188 | |
| 189 | |
def download_and_save_text_encoder(
    output_path: str,
    *,
    model_id: str = "google/t5gemma-2b-2b-ul2",
    max_sequence_length: int = 256,
):
    """Download and save T5Gemma text encoder and tokenizer.

    Loads the full T5Gemma encoder-decoder checkpoint, keeps only its encoder
    and writes it to ``<output_path>/text_encoder``. The matching tokenizer is
    written to ``<output_path>/tokenizer`` with its ``model_max_length`` set to
    *max_sequence_length*.

    Args:
        output_path: Root output directory. The ``text_encoder`` and
            ``tokenizer`` subdirectories are created if they do not exist.
        model_id: Hugging Face Hub id (or local path) of the T5Gemma
            checkpoint to take the encoder and tokenizer from.
        max_sequence_length: Value stored as the saved tokenizer's
            ``model_max_length``.
    """
    # Local imports: transformers is only needed for this conversion step.
    from transformers import GemmaTokenizerFast
    from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel

    text_encoder_path = os.path.join(output_path, "text_encoder")
    tokenizer_path = os.path.join(output_path, "tokenizer")
    os.makedirs(text_encoder_path, exist_ok=True)
    os.makedirs(tokenizer_path, exist_ok=True)

    print(f"Downloading T5Gemma model from {model_id}...")
    t5gemma_model = T5GemmaModel.from_pretrained(model_id)

    # Extract and save only the encoder; the decoder half is not persisted.
    t5gemma_encoder = t5gemma_model.encoder
    t5gemma_encoder.save_pretrained(text_encoder_path)
    print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")
    # Drop both references so the full encoder-decoder weights can be
    # garbage-collected before the tokenizer is downloaded.
    del t5gemma_encoder, t5gemma_model

    print(f"Downloading tokenizer from {model_id}...")
    tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
    # NOTE(review): presumably the maximum prompt length used by the
    # downstream pipeline — confirm against the model's text-encoding code.
    tokenizer.model_max_length = max_sequence_length
    tokenizer.save_pretrained(tokenizer_path)
    print(f"✓ Saved tokenizer to {tokenizer_path}")
| 213 | |
| 214 | |
| 215 | def create_model_index(vae_type: str, default_image_size: int, output_path: str): |
no test coverage detected
searching dependent graphs…