(
pipe,
args,
accelerator,
pipeline_args,
epoch,
is_final_validation: bool = False,
)
| 713 | |
| 714 | |
| 715 | def log_validation( |
| 716 | pipe, |
| 717 | args, |
| 718 | accelerator, |
| 719 | pipeline_args, |
| 720 | epoch, |
| 721 | is_final_validation: bool = False, |
| 722 | ): |
| 723 | logger.info( |
| 724 | f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." |
| 725 | ) |
| 726 | # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it |
| 727 | scheduler_args = {} |
| 728 | |
| 729 | if "variance_type" in pipe.scheduler.config: |
| 730 | variance_type = pipe.scheduler.config.variance_type |
| 731 | |
| 732 | if variance_type in ["learned", "learned_range"]: |
| 733 | variance_type = "fixed_small" |
| 734 | |
| 735 | scheduler_args["variance_type"] = variance_type |
| 736 | |
| 737 | pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args) |
| 738 | pipe = pipe.to(accelerator.device) |
| 739 | # pipe.set_progress_bar_config(disable=True) |
| 740 | |
| 741 | # run inference |
| 742 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None |
| 743 | |
| 744 | videos = [] |
| 745 | for _ in range(args.num_validation_videos): |
| 746 | pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0] |
| 747 | pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])]) |
| 748 | |
| 749 | image_np = VaeImageProcessor.pt_to_numpy(pt_images) |
| 750 | image_pil = VaeImageProcessor.numpy_to_pil(image_np) |
| 751 | |
| 752 | videos.append(image_pil) |
| 753 | |
| 754 | for tracker in accelerator.trackers: |
| 755 | phase_name = "test" if is_final_validation else "validation" |
| 756 | if tracker.name == "wandb": |
| 757 | video_filenames = [] |
| 758 | for i, video in enumerate(videos): |
| 759 | prompt = ( |
| 760 | pipeline_args["prompt"][:25] |
| 761 | .replace(" ", "_") |
| 762 | .replace(" ", "_") |
| 763 | .replace("'", "_") |
| 764 | .replace('"', "_") |
| 765 | .replace("/", "_") |
| 766 | ) |
| 767 | filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") |
| 768 | export_to_video(video, filename, fps=8) |
| 769 | video_filenames.append(filename) |
| 770 | |
| 771 | tracker.log( |
| 772 | { |
no test coverage detected
searching dependent graphs…