MCPcopy Index your code
hub / github.com/huggingface/diffusers / log_validation

Function log_validation

examples/cogvideo/train_cogvideox_lora.py:715–783  ·  view source on GitHub ↗
(
    pipe,
    args,
    accelerator,
    pipeline_args,
    epoch,
    is_final_validation: bool = False,
)

Source from the content-addressed store, hash-verified

713
714
715def log_validation(
716 pipe,
717 args,
718 accelerator,
719 pipeline_args,
720 epoch,
721 is_final_validation: bool = False,
722):
723 logger.info(
724 f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
725 )
726 # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
727 scheduler_args = {}
728
729 if "variance_type" in pipe.scheduler.config:
730 variance_type = pipe.scheduler.config.variance_type
731
732 if variance_type in ["learned", "learned_range"]:
733 variance_type = "fixed_small"
734
735 scheduler_args["variance_type"] = variance_type
736
737 pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
738 pipe = pipe.to(accelerator.device)
739 # pipe.set_progress_bar_config(disable=True)
740
741 # run inference
742 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
743
744 videos = []
745 for _ in range(args.num_validation_videos):
746 pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
747 pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
748
749 image_np = VaeImageProcessor.pt_to_numpy(pt_images)
750 image_pil = VaeImageProcessor.numpy_to_pil(image_np)
751
752 videos.append(image_pil)
753
754 for tracker in accelerator.trackers:
755 phase_name = "test" if is_final_validation else "validation"
756 if tracker.name == "wandb":
757 video_filenames = []
758 for i, video in enumerate(videos):
759 prompt = (
760 pipeline_args["prompt"][:25]
761 .replace(" ", "_")
762 .replace(" ", "_")
763 .replace("'", "_")
764 .replace('"', "_")
765 .replace("/", "_")
766 )
767 filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
768 export_to_video(video, filename, fps=8)
769 video_filenames.append(filename)
770
771 tracker.log(
772 {

Callers 1

main (Function) · 0.70

Calls 8

export_to_video (Function) · 0.90
free_memory (Function) · 0.90
pipe (Function) · 0.85
info (Method) · 0.80
from_config (Method) · 0.45
to (Method) · 0.45
pt_to_numpy (Method) · 0.45
numpy_to_pil (Method) · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…