MCPcopy
hub / github.com/kohya-ss/sd-scripts / train

Function train

sd3_train.py:65–991  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

63
64
65def train(args):
66 args_util.verify_training_args(args)
67 accelerator_setup.prepare_dataset_args(args, True)
68 # sdxl_train_util.verify_sdxl_training_args(args)
69 deepspeed_utils.prepare_deepspeed_args(args)
70 setup_logging(args, reset=True)
71
72 # temporary: backward compatibility for deprecated options. remove in the future
73 if not args.skip_cache_check:
74 args.skip_cache_check = args.skip_latents_validity_check
75
76 # assert (
77 # not args.weighted_captions
78 # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
79 # assert (
80 # not args.train_text_encoder or not args.cache_text_encoder_outputs
81 # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
82 if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
83 logger.warning(
84 "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
85 )
86 args.cache_text_encoder_outputs = True
87
88 assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), (
89 "when training text encoder, text encoder outputs must not be cached (except for T5XXL)"
90 + " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)"
91 )
92
93 if args.use_t5xxl_cache_only and not args.cache_text_encoder_outputs:
94 logger.warning(
95 "use_t5xxl_cache_only is enabled, so cache_text_encoder_outputs is automatically enabled."
96 + " / use_t5xxl_cache_onlyが有効なため、cache_text_encoder_outputsも自動的に有効になります"
97 )
98 args.cache_text_encoder_outputs = True
99
100 if args.train_t5xxl:
101 assert (
102 args.train_text_encoder
103 ), "when training T5XXL, text encoder (CLIP-L/G) must be trained / T5XXLを学習するときはtext encoder (CLIP-L/G)も学習する必要があります"
104 assert (
105 not args.cache_text_encoder_outputs
106 ), "when training T5XXL, t5xxl output must not be cached / T5XXLを学習するときはt5xxlの出力をキャッシュできません"
107
108 cache_latents = args.cache_latents
109 use_dreambooth_method = args.in_json is None
110
111 if args.seed is not None:
112 set_seed(args.seed) # 乱数系列を初期化する
113
114 # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
115 if args.cache_latents:
116 latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy(
117 args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
118 )
119 strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
120
121 # データセットを準備する
122 if args.dataset_class is None:

Callers 1

sd3_train.py · File · 0.70

Calls 15

generate · Method · 0.95
tokenize · Method · 0.95
encode_tokens · Method · 0.95
concat_encodings · Method · 0.95
add · Method · 0.95
setup_logging · Function · 0.90
BlueprintGenerator · Class · 0.90
ConfigSanitizer · Class · 0.90
match_mixed_precision · Function · 0.90
load_safetensors · Function · 0.90
clean_memory_on_device · Function · 0.90

Tested by

no test coverage detected