MCPcopy
hub / github.com/kohya-ss/sd-scripts / train

Method train

train_textual_inversion.py:193–788  ·  view source on GitHub ↗
(self, args)

Source from the content-addressed store, hash-verified

191 return [emb]
192
193 def train(self, args):
194 if args.output_name is None:
195 args.output_name = args.token_string
196 use_template = args.use_object_template or args.use_style_template
197
198 args_util.verify_training_args(args)
199 accelerator_setup.prepare_dataset_args(args, True)
200 setup_logging(args, reset=True)
201
202 cache_latents = args.cache_latents
203
204 if args.seed is not None:
205 set_seed(args.seed)
206
207 tokenize_strategy = self.get_tokenize_strategy(args)
208 strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
209 tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
210
211 # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
212 latents_caching_strategy = self.get_latents_caching_strategy(args)
213 strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
214
215 # acceleratorを準備する
216 logger.info("prepare accelerator")
217 accelerator = accelerator_setup.prepare_accelerator(args)
218
219 # mixed precisionに対応した型を用意しておき適宜castする
220 weight_dtype, save_dtype = accelerator_setup.prepare_dtype(args)
221 vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
222
223 # モデルを読み込む
224 model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
225
226 # Convert the init_word to token_id
227 init_token_ids_list = []
228 if args.init_word is not None:
229 for i, tokenizer in enumerate(tokenizers):
230 init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
231 if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
232 accelerator.print(
233 f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / "
234 + f"初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: tokenizer {i+1}, length {len(init_token_ids)}"
235 )
236 init_token_ids_list.append(init_token_ids)
237 else:
238 init_token_ids_list = [None] * len(tokenizers)
239
240 # tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token
241 # token_stringが hoge の場合、"hoge", "hoge1", "hoge2", ... が追加される
242 # add new word to tokenizer, count is num_vectors_per_token
243 # if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
244
245 self.assert_token_string(args.token_string, tokenizers)
246
247 token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
248 token_ids_list = []
249 token_embeds_list = []
250 for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):

Callers 1

Calls 15

get_tokenize_strategy · Method · 0.95
get_tokenizers · Method · 0.95
load_target_model · Method · 0.95
assert_token_string · Method · 0.95
load_weights · Method · 0.95
generate · Method · 0.95
assert_extra_args · Method · 0.95
sample_images · Method · 0.95
call_unet · Method · 0.95

Tested by

no test coverage detected