MCPcopy
hub / github.com/zai-org/CogVideo / main

Function main

finetune/train_cogvideox_image_to_video_lora.py:1019–1680  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

1017
1018
1019def main(args):
1020 if args.report_to == "wandb" and args.hub_token is not None:
1021 raise ValueError(
1022 "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
1023 " Please use `huggingface-cli login` to authenticate with the Hub."
1024 )
1025
1026 if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
1027 # due to pytorch#99272, MPS does not yet support bfloat16.
1028 raise ValueError(
1029 "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
1030 )
1031
1032 logging_dir = Path(args.output_dir, args.logging_dir)
1033
1034 accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
1035 ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
1036 init_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
1037 accelerator = Accelerator(
1038 gradient_accumulation_steps=args.gradient_accumulation_steps,
1039 mixed_precision=args.mixed_precision,
1040 log_with=args.report_to,
1041 project_config=accelerator_project_config,
1042 kwargs_handlers=[ddp_kwargs, init_kwargs],
1043 )
1044
1045 # Disable AMP for MPS.
1046 if torch.backends.mps.is_available():
1047 accelerator.native_amp = False
1048
1049 if args.report_to == "wandb":
1050 if not is_wandb_available():
1051 raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
1052
1053 # Make one log on every process with the configuration for debugging.
1054 logging.basicConfig(
1055 format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
1056 datefmt="%m/%d/%Y %H:%M:%S",
1057 level=logging.INFO,
1058 )
1059 logger.info(accelerator.state, main_process_only=False)
1060 if accelerator.is_local_main_process:
1061 transformers.utils.logging.set_verbosity_warning()
1062 diffusers.utils.logging.set_verbosity_info()
1063 else:
1064 transformers.utils.logging.set_verbosity_error()
1065 diffusers.utils.logging.set_verbosity_error()
1066
1067 # If passed along, set the training seed now.
1068 if args.seed is not None:
1069 set_seed(args.seed)
1070
1071 # Handle the repository creation
1072 if accelerator.is_main_process:
1073 if args.output_dir is not None:
1074 os.makedirs(args.output_dir, exist_ok=True)
1075
1076 if args.push_to_hub:

Calls 13

from_pretrained — Method · 0.80
parameters — Method · 0.80
get_optimizer — Function · 0.70
VideoDataset — Class · 0.70
encode_video — Function · 0.70
unwrap_model — Function · 0.70
log_validation — Function · 0.70
save_model_card — Function · 0.70
train — Method · 0.45
backward — Method · 0.45
backwardMethod · 0.45

Tested by

no test coverage detected