MCPcopy Index your code
hub / github.com/huggingface/diffusers / main

Function main

scripts/convert_prx_to_diffusers.py:240–297  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

238
239
def main(args):
    """Convert a checkpoint into a pipeline directory and verify it loads.

    Reads ``checkpoint_path``, ``vae_type``, ``output_path``, ``shift`` and
    ``resolution`` from ``args``.

    The transformer config and weights are written under
    ``<output_path>/transformer``. The scheduler config, VAE, text encoder
    and ``model_index.json`` are delegated to helper functions defined
    elsewhere in this file. Finally the output directory is loaded back
    through ``PRXPipeline.from_pretrained`` as a sanity check.

    Returns:
        ``True`` when the pipeline loads back successfully, ``False`` when
        the verification step raises.

    Raises:
        FileNotFoundError: if ``args.checkpoint_path`` does not exist.
    """
    # Fail fast before touching the filesystem.
    if not os.path.exists(args.checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")

    config = build_config(args.vae_type)

    # Make sure the destination exists.
    os.makedirs(args.output_path, exist_ok=True)
    print(f"✓ Output directory: {args.output_path}")

    # Build the transformer from the checkpoint using the config above.
    model = create_transformer_from_checkpoint(args.checkpoint_path, config)

    # Transformer artifacts live in their own sub-directory.
    transformer_dir = os.path.join(args.output_path, "transformer")
    os.makedirs(transformer_dir, exist_ok=True)

    # Config goes next to the weights as config.json.
    config_file_path = os.path.join(transformer_dir, "config.json")
    with open(config_file_path, "w") as config_file:
        json.dump(config, config_file, indent=2)

    # Weights are stored in safetensors format.
    save_file(
        model.state_dict(),
        os.path.join(transformer_dir, "diffusion_pytorch_model.safetensors"),
    )
    print(f"✓ Saved transformer to {transformer_dir}")

    # Remaining pipeline components are handled by sibling helpers.
    create_scheduler_config(args.output_path, args.shift)
    download_and_save_vae(args.vae_type, args.output_path)
    download_and_save_text_encoder(args.output_path)
    create_model_index(args.vae_type, args.resolution, args.output_path)

    # Sanity check: load the freshly written directory back as a pipeline.
    # Any failure here is reported and turned into a False return value.
    try:
        pipeline = PRXPipeline.from_pretrained(args.output_path)
        print("Pipeline loaded successfully!")
        print(f"Transformer: {type(pipeline.transformer).__name__}")
        print(f"VAE: {type(pipeline.vae).__name__}")
        print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
        print(f"Scheduler: {type(pipeline.scheduler).__name__}")

        # Report the transformer's total parameter count.
        param_count = sum(p.numel() for p in pipeline.transformer.parameters())
        print(f"✓ Transformer parameters: {param_count:,}")
    except Exception as err:
        print(f"Pipeline verification failed: {err}")
        return False
    else:
        print("Conversion completed successfully!")
        print(f"Converted pipeline saved to: {args.output_path}")
        print(f"VAE type: {args.vae_type}")
        return True

Callers 1

Calls 10

build_config — Function · 0.85
create_scheduler_config — Function · 0.85
download_and_save_vae — Function · 0.85
create_model_index — Function · 0.85
exists — Method · 0.80
parameters — Method · 0.80
state_dict — Method · 0.45
from_pretrained — Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…