| 896 | os.remove(old_ckpt_file) |
| 897 | |
| 898 | def train(self, args): |
| 899 | session_id = random.randint(0, 2**32) |
| 900 | training_started_at = time.time() |
| 901 | args_util.verify_training_args(args) |
| 902 | accelerator_setup.prepare_dataset_args(args, True) |
| 903 | deepspeed_utils.prepare_deepspeed_args(args) |
| 904 | setup_logging(args, reset=True) |
| 905 | |
| 906 | cache_latents = args.cache_latents |
| 907 | use_dreambooth_method = args.in_json is None |
| 908 | use_user_config = args.dataset_config is not None |
| 909 | |
| 910 | if args.seed is None: |
| 911 | args.seed = random.randint(0, 2**32) |
| 912 | set_seed(args.seed) |
| 913 | |
| 914 | tokenize_strategy = self.get_tokenize_strategy(args) |
| 915 | strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) |
| 916 | tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored |
| 917 | |
| 918 | # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. |
| 919 | latents_caching_strategy = self.get_latents_caching_strategy(args) |
| 920 | strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) |
| 921 | |
| 922 | # データセットを準備する |
| 923 | if args.dataset_class is None: |
| 924 | blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) |
| 925 | if use_user_config: |
| 926 | logger.info(f"Loading dataset config from {args.dataset_config}") |
| 927 | user_config = config_util.load_user_config(args.dataset_config) |
| 928 | ignored = ["train_data_dir", "reg_data_dir", "in_json"] |
| 929 | if any(getattr(args, attr) is not None for attr in ignored): |
| 930 | logger.warning( |
| 931 | "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( |
| 932 | ", ".join(ignored) |
| 933 | ) |
| 934 | ) |
| 935 | else: |
| 936 | if use_dreambooth_method: |
| 937 | logger.info("Using DreamBooth method.") |
| 938 | user_config = { |
| 939 | "datasets": [ |
| 940 | { |
| 941 | "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( |
| 942 | args.train_data_dir, args.reg_data_dir |
| 943 | ) |
| 944 | } |
| 945 | ] |
| 946 | } |
| 947 | else: |
| 948 | logger.info("Training with captions.") |
| 949 | user_config = { |
| 950 | "datasets": [ |
| 951 | { |
| 952 | "subsets": [ |
| 953 | { |
| 954 | "image_dir": args.train_data_dir, |
| 955 | "metadata_file": args.in_json, |