(args)
| 238 | |
| 239 | |
def main(args):
    """Convert a checkpoint into a PRXPipeline directory and check that it reloads.

    Steps, in order:
      1. Fail fast if ``args.checkpoint_path`` does not exist.
      2. Build the transformer config via ``build_config(args.vae_type)`` and
         create ``args.output_path``.
      3. Build the transformer from the checkpoint and write
         ``<output>/transformer/config.json`` (the config as JSON) and
         ``<output>/transformer/diffusion_pytorch_model.safetensors``.
      4. Write the scheduler config (``args.shift``), save the VAE and text
         encoder via the ``download_and_save_*`` helpers, and write the model
         index (``args.vae_type``, ``args.resolution``).
      5. Reload the directory with ``PRXPipeline.from_pretrained`` as a smoke test.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``checkpoint_path``, ``vae_type``, ``output_path``, ``shift`` and
            ``resolution``.

    Returns:
        True if conversion and the reload check succeed; False if the reload
        check (step 5) raises.

    Raises:
        FileNotFoundError: If ``args.checkpoint_path`` does not exist.
        Any exception from steps 2-4 propagates unchanged; only step 5 is
        guarded.
    """
    import traceback

    # Validate inputs
    if not os.path.exists(args.checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")

    config = build_config(args.vae_type)

    # Create output directory
    os.makedirs(args.output_path, exist_ok=True)
    print(f"✓ Output directory: {args.output_path}")

    # Create transformer from checkpoint
    transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)

    # Save transformer
    transformer_path = os.path.join(args.output_path, "transformer")
    os.makedirs(transformer_path, exist_ok=True)

    # Save config
    with open(os.path.join(transformer_path, "config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # Save model weights as safetensors
    state_dict = transformer.state_dict()
    save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
    print(f"✓ Saved transformer to {transformer_path}")

    # The weights are on disk now. Drop our references so this copy can be
    # reclaimed before the verification step loads a second full transformer.
    # state_dict tensors share storage with the module, so both names must go.
    del state_dict, transformer

    # Create scheduler config
    create_scheduler_config(args.output_path, args.shift)

    download_and_save_vae(args.vae_type, args.output_path)
    download_and_save_text_encoder(args.output_path)

    # Create model_index.json
    create_model_index(args.vae_type, args.resolution, args.output_path)

    # Verify the pipeline can be loaded. The attribute access and parameter
    # count are part of the check, so they stay inside the try.
    try:
        pipeline = PRXPipeline.from_pretrained(args.output_path)
        print("Pipeline loaded successfully!")
        print(f"Transformer: {type(pipeline.transformer).__name__}")
        print(f"VAE: {type(pipeline.vae).__name__}")
        print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
        print(f"Scheduler: {type(pipeline.scheduler).__name__}")

        # Display model info
        num_params = sum(p.numel() for p in pipeline.transformer.parameters())
        print(f"✓ Transformer parameters: {num_params:,}")

    except Exception as e:
        # Deliberate top-level boundary: a failed smoke test is reported via
        # the return value instead of crashing. Keep the traceback (on stderr)
        # so the cause can be diagnosed.
        print(f"Pipeline verification failed: {e}")
        traceback.print_exc()
        return False

    print("Conversion completed successfully!")
    print(f"Converted pipeline saved to: {args.output_path}")
    print(f"VAE type: {args.vae_type}")

    return True
no test coverage detected
searching dependent graphs…