Process a batch for the network.
(
self,
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy: strategy_base.TextEncodingStrategy,
tokenize_strategy: strategy_base.TokenizeStrategy,
is_train=True,
train_text_encoder=True,
train_unet=True,
)
| 369 | # endregion |
| 370 | |
| 371 | def process_batch( |
| 372 | self, |
| 373 | batch, |
| 374 | text_encoders, |
| 375 | unet, |
| 376 | network, |
| 377 | vae, |
| 378 | noise_scheduler, |
| 379 | vae_dtype, |
| 380 | weight_dtype, |
| 381 | accelerator, |
| 382 | args, |
| 383 | text_encoding_strategy: strategy_base.TextEncodingStrategy, |
| 384 | tokenize_strategy: strategy_base.TokenizeStrategy, |
| 385 | is_train=True, |
| 386 | train_text_encoder=True, |
| 387 | train_unet=True, |
| 388 | ) -> torch.Tensor: |
| 389 | """ |
| 390 | Process a batch for the network |
| 391 | """ |
| 392 | with torch.no_grad(): |
| 393 | if "latents" in batch and batch["latents"] is not None: |
| 394 | latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) |
| 395 | else: |
| 396 | # latentに変換 |
| 397 | if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size: |
| 398 | latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) |
| 399 | else: |
| 400 | chunks = [ |
| 401 | batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size) |
| 402 | ] |
| 403 | list_latents = [] |
| 404 | for chunk in chunks: |
| 405 | with torch.no_grad(): |
| 406 | chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype)) |
| 407 | list_latents.append(chunk) |
| 408 | latents = torch.cat(list_latents, dim=0) |
| 409 | |
| 410 | # NaNが含まれていれば警告を表示し0に置き換える |
| 411 | if torch.any(torch.isnan(latents)): |
| 412 | accelerator.print("NaN found in latents, replacing with zeros") |
| 413 | latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents)) |
| 414 | |
| 415 | latents = self.shift_scale_latents(args, latents) |
| 416 | |
| 417 | # Prepare inpainting masked_latents if batch contains masks |
| 418 | if batch.get("masks") is not None: |
| 419 | masked_latents = self.encode_images_to_latents( |
| 420 | args, vae, batch["masked_images"].to(accelerator.device, dtype=vae_dtype) |
| 421 | ) |
| 422 | batch["masked_latents"] = self.shift_scale_latents(args, masked_latents) |
| 423 | |
| 424 | text_encoder_conds = [] |
| 425 | text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) |
| 426 | if text_encoder_outputs_list is not None: |
| 427 | text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs |
| 428 |
No test coverage detected.