MCPcopy
hub / github.com/kohya-ss/sd-scripts / train

Function train

anima_train.py:50–716  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

48
49
50def train(args):
51 args_util.verify_training_args(args)
52 accelerator_setup.prepare_dataset_args(args, True)
53 deepspeed_utils.prepare_deepspeed_args(args)
54 setup_logging(args, reset=True)
55
56 flux_train_utils.log_timestep_sampling_info(args)
57
58 # backward compatibility
59 if not args.skip_cache_check:
60 args.skip_cache_check = args.skip_latents_validity_check
61
62 if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
63 logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
64 args.cache_text_encoder_outputs = True
65
66 if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
67 logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
68 args.gradient_checkpointing = True
69
70 if args.unsloth_offload_checkpointing:
71 if not args.gradient_checkpointing:
72 logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
73 args.gradient_checkpointing = True
74 assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
75
76 assert (
77 args.blocks_to_swap is None or args.blocks_to_swap == 0
78 ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
79
80 assert (
81 args.blocks_to_swap is None or args.blocks_to_swap == 0
82 ) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"
83
84 cache_latents = args.cache_latents
85 use_dreambooth_method = args.in_json is None
86
87 if args.seed is not None:
88 set_seed(args.seed)
89
90 # prepare caching strategy: must be set before preparing dataset
91 if args.cache_latents:
92 latents_caching_strategy = strategy_anima.AnimaLatentsCachingStrategy(
93 args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
94 )
95 strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
96
97 # prepare dataset
98 if args.dataset_class is None:
99 blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
100 if args.dataset_config is not None:
101 logger.info(f"Load dataset config from {args.dataset_config}")
102 user_config = config_util.load_user_config(args.dataset_config)
103 ignored = ["train_data_dir", "in_json"]
104 if any(getattr(args, attr) is not None for attr in ignored):
105 logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
106 else:
107 if use_dreambooth_method:

Callers 1

anima_train.py — File · 0.70

Calls 15

generate — Method · 0.95
tokenize — Method · 0.95
encode_tokens — Method · 0.95
add — Method · 0.95
setup_logging — Function · 0.90
BlueprintGenerator — Class · 0.90
ConfigSanitizer — Class · 0.90
clean_memory_on_device — Function · 0.90
apply_masked_loss — Function · 0.90
to — Method · 0.80

Tested by

no test coverage detected