MCPcopy
hub / github.com/kohya-ss/sd-scripts / train

Function train

train_db.py:60–533  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

58
59
60def train(args):
61 args_util.verify_training_args(args)
62 accelerator_setup.prepare_dataset_args(args, False)
63 deepspeed_utils.prepare_deepspeed_args(args)
64 setup_logging(args, reset=True)
65
66 cache_latents = args.cache_latents
67
68 if args.seed is not None:
69 set_seed(args.seed) # 乱数系列を初期化する
70
71 tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
72 strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
73
74 # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
75 latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
76 False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
77 )
78 strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
79
80 # データセットを準備する
81 if args.dataset_class is None:
82 blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True))
83 if args.dataset_config is not None:
84 logger.info(f"Load dataset config from {args.dataset_config}")
85 user_config = config_util.load_user_config(args.dataset_config)
86 ignored = ["train_data_dir", "reg_data_dir"]
87 if any(getattr(args, attr) is not None for attr in ignored):
88 logger.warning(
89 "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
90 ", ".join(ignored)
91 )
92 )
93 else:
94 user_config = {
95 "datasets": [
96 {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
97 ]
98 }
99
100 blueprint = blueprint_generator.generate(user_config, args)
101 train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
102 else:
103 train_dataset_group = dataset_util.load_arbitrary_dataset(args)
104 val_dataset_group = None
105
106 current_epoch = Value("i", 0)
107 current_step = Value("i", 0)
108 ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
109 collator = dataset_util.collator_class(current_epoch, current_step, ds_for_collator)
110
111 if args.no_token_padding:
112 train_dataset_group.disable_token_padding()
113
114 train_dataset_group.verify_bucket_reso_steps(64)
115
116 if args.debug_dataset:
117 dataset_util.debug_dataset(train_dataset_group)

Callers 1

train_db.py (File) · 0.70

Calls 15

generate (Method) · 0.95
tokenize_with_weights (Method) · 0.95
encode_tokens (Method) · 0.95
add (Method) · 0.95
setup_logging (Function) · 0.90
BlueprintGenerator (Class) · 0.90
ConfigSanitizer (Class) · 0.90
clean_memory_on_device (Function) · 0.90
apply_masked_loss (Function) · 0.90
apply_snr_weight (Function) · 0.90

Tested by

no test coverage detected