(args)
| 104 | |
| 105 | |
| 106 | def train(args): |
| 107 | args_util.verify_training_args(args) |
| 108 | accelerator_setup.prepare_dataset_args(args, True) |
| 109 | sdxl_train_util.verify_sdxl_training_args(args) |
| 110 | deepspeed_utils.prepare_deepspeed_args(args) |
| 111 | setup_logging(args, reset=True) |
| 112 | |
| 113 | assert ( |
| 114 | not args.weighted_captions or not args.cache_text_encoder_outputs |
| 115 | ), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません" |
| 116 | assert ( |
| 117 | not args.train_text_encoder or not args.cache_text_encoder_outputs |
| 118 | ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" |
| 119 | |
| 120 | if args.block_lr: |
| 121 | block_lrs = [float(lr) for lr in args.block_lr.split(",")] |
| 122 | assert ( |
| 123 | len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR |
| 124 | ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください" |
| 125 | else: |
| 126 | block_lrs = None |
| 127 | |
| 128 | cache_latents = args.cache_latents |
| 129 | use_dreambooth_method = args.in_json is None |
| 130 | |
| 131 | if args.seed is not None: |
| 132 | set_seed(args.seed) # 乱数系列を初期化する |
| 133 | |
| 134 | tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) |
| 135 | strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) |
| 136 | tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] # will be removed in the future |
| 137 | |
| 138 | # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. |
| 139 | if args.cache_latents: |
| 140 | latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( |
| 141 | False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check |
| 142 | ) |
| 143 | strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) |
| 144 | |
| 145 | # データセットを準備する |
| 146 | if args.dataset_class is None: |
| 147 | blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) |
| 148 | if args.dataset_config is not None: |
| 149 | logger.info(f"Load dataset config from {args.dataset_config}") |
| 150 | user_config = config_util.load_user_config(args.dataset_config) |
| 151 | ignored = ["train_data_dir", "in_json"] |
| 152 | if any(getattr(args, attr) is not None for attr in ignored): |
| 153 | logger.warning( |
| 154 | "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( |
| 155 | ", ".join(ignored) |
| 156 | ) |
| 157 | ) |
| 158 | else: |
| 159 | if use_dreambooth_method: |
| 160 | logger.info("Using DreamBooth method.") |
| 161 | user_config = { |
| 162 | "datasets": [ |
| 163 | { |
no test coverage detected