(args)
| 58 | |
| 59 | |
| 60 | def train(args): |
| 61 | args_util.verify_training_args(args) |
| 62 | accelerator_setup.prepare_dataset_args(args, False) |
| 63 | deepspeed_utils.prepare_deepspeed_args(args) |
| 64 | setup_logging(args, reset=True) |
| 65 | |
| 66 | cache_latents = args.cache_latents |
| 67 | |
| 68 | if args.seed is not None: |
| 69 | set_seed(args.seed) # 乱数系列を初期化する |
| 70 | |
| 71 | tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) |
| 72 | strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) |
| 73 | |
| 74 | # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. |
| 75 | latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( |
| 76 | False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check |
| 77 | ) |
| 78 | strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) |
| 79 | |
| 80 | # データセットを準備する |
| 81 | if args.dataset_class is None: |
| 82 | blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True)) |
| 83 | if args.dataset_config is not None: |
| 84 | logger.info(f"Load dataset config from {args.dataset_config}") |
| 85 | user_config = config_util.load_user_config(args.dataset_config) |
| 86 | ignored = ["train_data_dir", "reg_data_dir"] |
| 87 | if any(getattr(args, attr) is not None for attr in ignored): |
| 88 | logger.warning( |
| 89 | "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( |
| 90 | ", ".join(ignored) |
| 91 | ) |
| 92 | ) |
| 93 | else: |
| 94 | user_config = { |
| 95 | "datasets": [ |
| 96 | {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} |
| 97 | ] |
| 98 | } |
| 99 | |
| 100 | blueprint = blueprint_generator.generate(user_config, args) |
| 101 | train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) |
| 102 | else: |
| 103 | train_dataset_group = dataset_util.load_arbitrary_dataset(args) |
| 104 | val_dataset_group = None |
| 105 | |
| 106 | current_epoch = Value("i", 0) |
| 107 | current_step = Value("i", 0) |
| 108 | ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None |
| 109 | collator = dataset_util.collator_class(current_epoch, current_step, ds_for_collator) |
| 110 | |
| 111 | if args.no_token_padding: |
| 112 | train_dataset_group.disable_token_padding() |
| 113 | |
| 114 | train_dataset_group.verify_bucket_reso_steps(64) |
| 115 | |
| 116 | if args.debug_dataset: |
| 117 | dataset_util.debug_dataset(train_dataset_group) |
no test coverage detected