(self, args)
| 191 | return [emb] |
| 192 | |
| 193 | def train(self, args): |
| 194 | if args.output_name is None: |
| 195 | args.output_name = args.token_string |
| 196 | use_template = args.use_object_template or args.use_style_template |
| 197 | |
| 198 | args_util.verify_training_args(args) |
| 199 | accelerator_setup.prepare_dataset_args(args, True) |
| 200 | setup_logging(args, reset=True) |
| 201 | |
| 202 | cache_latents = args.cache_latents |
| 203 | |
| 204 | if args.seed is not None: |
| 205 | set_seed(args.seed) |
| 206 | |
| 207 | tokenize_strategy = self.get_tokenize_strategy(args) |
| 208 | strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) |
| 209 | tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored |
| 210 | |
| 211 | # Prepare the latents caching strategy. This must be set before preparing the dataset, because the dataset may use this strategy during initialization. |
| 212 | latents_caching_strategy = self.get_latents_caching_strategy(args) |
| 213 | strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) |
| 214 | |
| 215 | # acceleratorを準備する |
| 216 | logger.info("prepare accelerator") |
| 217 | accelerator = accelerator_setup.prepare_accelerator(args) |
| 218 | |
| 219 | # mixed precisionに対応した型を用意しておき適宜castする |
| 220 | weight_dtype, save_dtype = accelerator_setup.prepare_dtype(args) |
| 221 | vae_dtype = torch.float32 if args.no_half_vae else weight_dtype |
| 222 | |
| 223 | # モデルを読み込む |
| 224 | model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator) |
| 225 | |
| 226 | # Convert the init_word to token_id |
| 227 | init_token_ids_list = [] |
| 228 | if args.init_word is not None: |
| 229 | for i, tokenizer in enumerate(tokenizers): |
| 230 | init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) |
| 231 | if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: |
| 232 | accelerator.print( |
| 233 | f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / " |
| 234 | + f"初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: tokenizer {i+1}, length {len(init_token_ids)}" |
| 235 | ) |
| 236 | init_token_ids_list.append(init_token_ids) |
| 237 | else: |
| 238 | init_token_ids_list = [None] * len(tokenizers) |
| 239 | |
| 240 | # tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token |
| 241 | # token_stringが hoge の場合、"hoge", "hoge1", "hoge2", ... が追加される |
| 242 | # Add new words to the tokenizer; the number of words added is num_vectors_per_token. |
| 243 | # If token_string is "hoge", then "hoge", "hoge1", "hoge2", ... are added. |
| 244 | |
| 245 | self.assert_token_string(args.token_string, tokenizers) |
| 246 | |
| 247 | token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] |
| 248 | token_ids_list = [] |
| 249 | token_embeds_list = [] |
| 250 | for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)): |
no test coverage detected