MCPcopy
hub / github.com/kohya-ss/sd-scripts / process_batch

Method process_batch

train_network.py:371–488  ·  view source on GitHub ↗

Process a batch for the network

(
        self,
        batch,
        text_encoders,
        unet,
        network,
        vae,
        noise_scheduler,
        vae_dtype,
        weight_dtype,
        accelerator,
        args,
        text_encoding_strategy: strategy_base.TextEncodingStrategy,
        tokenize_strategy: strategy_base.TokenizeStrategy,
        is_train=True,
        train_text_encoder=True,
        train_unet=True,
    )

Source from the content-addressed store, hash-verified

369 # endregion
370
371 def process_batch(
372 self,
373 batch,
374 text_encoders,
375 unet,
376 network,
377 vae,
378 noise_scheduler,
379 vae_dtype,
380 weight_dtype,
381 accelerator,
382 args,
383 text_encoding_strategy: strategy_base.TextEncodingStrategy,
384 tokenize_strategy: strategy_base.TokenizeStrategy,
385 is_train=True,
386 train_text_encoder=True,
387 train_unet=True,
388 ) -> torch.Tensor:
389 """
390 Process a batch for the network
391 """
392 with torch.no_grad():
393 if "latents" in batch and batch["latents"] is not None:
394 latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
395 else:
396 # latentに変換
397 if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
398 latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
399 else:
400 chunks = [
401 batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
402 ]
403 list_latents = []
404 for chunk in chunks:
405 with torch.no_grad():
406 chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
407 list_latents.append(chunk)
408 latents = torch.cat(list_latents, dim=0)
409
410 # NaNが含まれていれば警告を表示し0に置き換える
411 if torch.any(torch.isnan(latents)):
412 accelerator.print("NaN found in latents, replacing with zeros")
413 latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))
414
415 latents = self.shift_scale_latents(args, latents)
416
417 # Prepare inpainting masked_latents if batch contains masks
418 if batch.get("masks") is not None:
419 masked_latents = self.encode_images_to_latents(
420 args, vae, batch["masked_images"].to(accelerator.device, dtype=vae_dtype)
421 )
422 batch["masked_latents"] = self.shift_scale_latents(args, masked_latents)
423
424 text_encoder_conds = []
425 text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
426 if text_encoder_outputs_list is not None:
427 text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
428

Callers 2

_run_validation_loop · Method · 0.95
train · Method · 0.95

Calls 11 (7 shown)

shift_scale_latents · Method · 0.95
post_process_loss · Method · 0.95
apply_masked_loss · Function · 0.90
to · Method · 0.80
get · Method · 0.80
tokenize_with_weights · Method · 0.45
encode_tokens · Method · 0.45

Tested by

no test coverage detected