(args)
| 48 | |
| 49 | |
| 50 | def train(args): |
| 51 | args_util.verify_training_args(args) |
| 52 | accelerator_setup.prepare_dataset_args(args, True) |
| 53 | deepspeed_utils.prepare_deepspeed_args(args) |
| 54 | setup_logging(args, reset=True) |
| 55 | |
| 56 | flux_train_utils.log_timestep_sampling_info(args) |
| 57 | |
| 58 | # backward compatibility |
| 59 | if not args.skip_cache_check: |
| 60 | args.skip_cache_check = args.skip_latents_validity_check |
| 61 | |
| 62 | if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: |
| 63 | logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled") |
| 64 | args.cache_text_encoder_outputs = True |
| 65 | |
| 66 | if args.cpu_offload_checkpointing and not args.gradient_checkpointing: |
| 67 | logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") |
| 68 | args.gradient_checkpointing = True |
| 69 | |
| 70 | if args.unsloth_offload_checkpointing: |
| 71 | if not args.gradient_checkpointing: |
| 72 | logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") |
| 73 | args.gradient_checkpointing = True |
| 74 | assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing" |
| 75 | |
| 76 | assert ( |
| 77 | args.blocks_to_swap is None or args.blocks_to_swap == 0 |
| 78 | ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing" |
| 79 | |
| 80 | assert ( |
| 81 | args.blocks_to_swap is None or args.blocks_to_swap == 0 |
| 82 | ) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing" |
| 83 | |
| 84 | cache_latents = args.cache_latents |
| 85 | use_dreambooth_method = args.in_json is None |
| 86 | |
| 87 | if args.seed is not None: |
| 88 | set_seed(args.seed) |
| 89 | |
| 90 | # prepare caching strategy: must be set before preparing dataset |
| 91 | if args.cache_latents: |
| 92 | latents_caching_strategy = strategy_anima.AnimaLatentsCachingStrategy( |
| 93 | args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check |
| 94 | ) |
| 95 | strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) |
| 96 | |
| 97 | # prepare dataset |
| 98 | if args.dataset_class is None: |
| 99 | blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) |
| 100 | if args.dataset_config is not None: |
| 101 | logger.info(f"Load dataset config from {args.dataset_config}") |
| 102 | user_config = config_util.load_user_config(args.dataset_config) |
| 103 | ignored = ["train_data_dir", "in_json"] |
| 104 | if any(getattr(args, attr) is not None for attr in ignored): |
| 105 | logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored))) |
| 106 | else: |
| 107 | if use_dreambooth_method: |
no test coverage detected