github.com/kohya-ss/sd-scripts

Function train

lumina_train.py:64–913
(args)


def train(args):
    args_util.verify_training_args(args)
    accelerator_setup.prepare_dataset_args(args, True)
    # sdxl_train_util.verify_sdxl_training_args(args)
    deepspeed_utils.prepare_deepspeed_args(args)
    setup_logging(args, reset=True)

    # temporary: backward compatibility for deprecated options. remove in the future
    if not args.skip_cache_check:
        args.skip_cache_check = args.skip_latents_validity_check

    # assert (
    #     not args.weighted_captions
    # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
    if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
        logger.warning(
            "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
        )
        args.cache_text_encoder_outputs = True

    if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
        logger.warning(
            "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
        )
        args.gradient_checkpointing = True

    # assert (
    #     args.blocks_to_swap is None or args.blocks_to_swap == 0
    # ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"

    cache_latents = args.cache_latents
    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random number sequence

    # prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization.
    if args.cache_latents:
        latents_caching_strategy = strategy_lumina.LuminaLatentsCachingStrategy(
            args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
        )
        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(
            ConfigSanitizer(True, True, args.masked_loss, True)
        )
        if args.dataset_config is not None:
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                logger.warning(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )

Callers 1

lumina_train.py · File · 0.70

Calls 15

generate · Method · 0.95
tokenize · Method · 0.95
encode_tokens · Method · 0.95
add · Method · 0.95
setup_logging · Function · 0.90
BlueprintGenerator · Class · 0.90
ConfigSanitizer · Class · 0.90
clean_memory_on_device · Function · 0.90
apply_masked_loss · Function · 0.90
to · Method · 0.80
requires_grad_ · Method · 0.80

Tested by

no test coverage detected