MCPcopy
hub / github.com/kohya-ss/sd-scripts / train

Method train

train_network.py:898–1863  ·  view source on GitHub ↗
(self, args)

Source from the content-addressed store, hash-verified

896 os.remove(old_ckpt_file)
897
898 def train(self, args):
899 session_id = random.randint(0, 2**32)
900 training_started_at = time.time()
901 args_util.verify_training_args(args)
902 accelerator_setup.prepare_dataset_args(args, True)
903 deepspeed_utils.prepare_deepspeed_args(args)
904 setup_logging(args, reset=True)
905
906 cache_latents = args.cache_latents
907 use_dreambooth_method = args.in_json is None
908 use_user_config = args.dataset_config is not None
909
910 if args.seed is None:
911 args.seed = random.randint(0, 2**32)
912 set_seed(args.seed)
913
914 tokenize_strategy = self.get_tokenize_strategy(args)
915 strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
916 tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
917
918 # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
919 latents_caching_strategy = self.get_latents_caching_strategy(args)
920 strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
921
922 # データセットを準備する
923 if args.dataset_class is None:
924 blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
925 if use_user_config:
926 logger.info(f"Loading dataset config from {args.dataset_config}")
927 user_config = config_util.load_user_config(args.dataset_config)
928 ignored = ["train_data_dir", "reg_data_dir", "in_json"]
929 if any(getattr(args, attr) is not None for attr in ignored):
930 logger.warning(
931 "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
932 ", ".join(ignored)
933 )
934 )
935 else:
936 if use_dreambooth_method:
937 logger.info("Using DreamBooth method.")
938 user_config = {
939 "datasets": [
940 {
941 "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
942 args.train_data_dir, args.reg_data_dir
943 )
944 }
945 ]
946 }
947 else:
948 logger.info("Training with captions.")
949 user_config = {
950 "datasets": [
951 {
952 "subsets": [
953 {
954 "image_dir": args.train_data_dir,
955 "metadata_file": args.in_json,

Callers 15

train · Function · 0.45
sample_image_inference · Function · 0.45
train · Function · 0.45
main · Function · 0.45
train · Function · 0.45
train · Function · 0.45
train · Function · 0.45
main · Function · 0.45
train · Function · 0.45
train · Function · 0.45
train · Function · 0.45

Calls 15

get_tokenize_strategy · Method · 0.95
get_tokenizers · Method · 0.95
generate · Method · 0.95
assert_extra_args · Method · 0.95
cast_vae · Method · 0.95
load_target_model · Method · 0.95
load_unet_lazily · Method · 0.95
post_process_network · Method · 0.95

Tested by 1

main · Function · 0.36