Create model_index.json for the pipeline.
(vae_type: str, default_image_size: int, output_path: str)
| 213 | |
| 214 | |
def create_model_index(vae_type: str, default_image_size: int, output_path: str):
    """Create model_index.json for the pipeline.

    Writes ``model_index.json`` into ``output_path``. The file describes the
    PRXPipeline components (scheduler, text encoder, tokenizer, transformer,
    VAE) and the default sample size.

    Args:
        vae_type: VAE family of the checkpoint. ``"flux"`` maps to
            ``AutoencoderKL`` and ``"dc-ae"`` maps to ``AutoencoderDC``.
        default_image_size: Value stored as ``default_sample_size``.
        output_path: Directory the JSON file is written into. It must already
            exist, because this function does not create it. Its final path
            component is recorded as ``_name_or_path``.

    Raises:
        ValueError: If ``vae_type`` is not one of the supported values.
    """
    vae_classes = {"flux": "AutoencoderKL", "dc-ae": "AutoencoderDC"}
    try:
        vae_class = vae_classes[vae_type]
    except KeyError:
        # Previously any unknown value silently fell through to AutoencoderDC,
        # producing a model index that does not match the converted weights.
        raise ValueError(
            f"Unsupported vae_type {vae_type!r}; expected one of {sorted(vae_classes)}"
        ) from None

    # normpath strips a trailing separator. Without it, basename("out/model/")
    # returns "" and the recorded name would be empty.
    model_name = os.path.basename(os.path.normpath(output_path))

    model_index = {
        "_class_name": "PRXPipeline",
        "_diffusers_version": "0.31.0.dev0",
        "_name_or_path": model_name,
        "default_sample_size": default_image_size,
        "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
        "text_encoder": ["prx", "T5GemmaEncoder"],
        "tokenizer": ["transformers", "GemmaTokenizerFast"],
        "transformer": ["diffusers", "PRXTransformer2DModel"],
        "vae": ["diffusers", vae_class],
    }

    model_index_path = os.path.join(output_path, "model_index.json")
    with open(model_index_path, "w", encoding="utf-8") as f:
        json.dump(model_index, f, indent=2)
| 238 | |
| 239 | |
| 240 | def main(args): |
no outgoing calls
no test coverage detected
searching dependent graphs…