MCPcopy
hub / github.com/kohya-ss/sd-scripts / train

Function train

sdxl_train.py:106–908  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

104
105
106def train(args):
107 args_util.verify_training_args(args)
108 accelerator_setup.prepare_dataset_args(args, True)
109 sdxl_train_util.verify_sdxl_training_args(args)
110 deepspeed_utils.prepare_deepspeed_args(args)
111 setup_logging(args, reset=True)
112
113 assert (
114 not args.weighted_captions or not args.cache_text_encoder_outputs
115 ), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません"
116 assert (
117 not args.train_text_encoder or not args.cache_text_encoder_outputs
118 ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
119
120 if args.block_lr:
121 block_lrs = [float(lr) for lr in args.block_lr.split(",")]
122 assert (
123 len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR
124 ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください"
125 else:
126 block_lrs = None
127
128 cache_latents = args.cache_latents
129 use_dreambooth_method = args.in_json is None
130
131 if args.seed is not None:
132 set_seed(args.seed) # 乱数系列を初期化する
133
134 tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
135 strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
136 tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] # will be removed in the future
137
138 # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
139 if args.cache_latents:
140 latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
141 False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
142 )
143 strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
144
145 # データセットを準備する
146 if args.dataset_class is None:
147 blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
148 if args.dataset_config is not None:
149 logger.info(f"Load dataset config from {args.dataset_config}")
150 user_config = config_util.load_user_config(args.dataset_config)
151 ignored = ["train_data_dir", "in_json"]
152 if any(getattr(args, attr) is not None for attr in ignored):
153 logger.warning(
154 "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
155 ", ".join(ignored)
156 )
157 )
158 else:
159 if use_dreambooth_method:
160 logger.info("Using DreamBooth method.")
161 user_config = {
162 "datasets": [
163 {

Callers 1

sdxl_train.py (File) · 0.70

Calls 15

generate (Method) · 0.95
tokenize_with_weights (Method) · 0.95
encode_tokens (Method) · 0.95
add (Method) · 0.95
setup_logging (Function) · 0.90
BlueprintGenerator (Class) · 0.90
ConfigSanitizer (Class) · 0.90
clean_memory_on_device (Function) · 0.90
apply_masked_loss (Function) · 0.90
apply_snr_weight (Function) · 0.90

Tested by

no test coverage detected