hub / github.com/kohya-ss/sd-scripts / prepare_text_inputs

Function prepare_text_inputs

anima_minimal_inference.py:358–469

Prepares the text-related inputs for text-to-image (T2I) generation by encoding the prompt with the LLM text encoder. The Anima model is also required for preprocessing.

(
    args: argparse.Namespace, device: torch.device, anima: anima_models.Anima, shared_models: Optional[Dict] = None
)
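
The optional `shared_models` argument lets a caller hand in an already loaded text encoder and a prompt cache. Going by the source listing on this page, the function looks for the keys `"text_encoder"` and `"conds_cache"` and loads a fresh encoder only when `shared_models` is `None`. The sketch below illustrates that resolution step in isolation. `resolve_text_encoder` and `load_fn` are hypothetical names introduced for illustration and are not part of sd-scripts; `load_fn` stands in for the real `load_text_encoder(args, dtype=..., device=...)` call.

```python
from typing import Any, Callable, Dict, Optional, Tuple


def resolve_text_encoder(
    shared_models: Optional[Dict[str, Any]],
    load_fn: Callable[[], Any],  # hypothetical stand-in for load_text_encoder
) -> Tuple[Any, Dict[str, Any]]:
    """Return (text_encoder, conds_cache), preferring entries in shared_models."""
    conds_cache: Dict[str, Any] = {}
    if shared_models is not None:
        # Reuse the caller's encoder; .get() yields None if the key is absent.
        text_encoder = shared_models.get("text_encoder")
        # Reuse the caller's cache object so new entries persist across calls.
        if "conds_cache" in shared_models:
            conds_cache = shared_models["conds_cache"]
    else:
        # Nothing was shared: load an encoder for this call only.
        text_encoder = load_fn()

    # Mirrors the guard in the listing: a missing encoder is an error.
    if text_encoder is None:
        raise ValueError("Text encoder is not loaded properly.")
    return text_encoder, conds_cache
```

Note that passing a `shared_models` dict without a `"text_encoder"` entry does not trigger a load in this sketch; it raises `ValueError`, which matches the control flow visible in the listing.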

Source from the content-addressed store, hash-verified

def prepare_text_inputs(
    args: argparse.Namespace, device: torch.device, anima: anima_models.Anima, shared_models: Optional[Dict] = None
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Prepare text-related inputs for T2I: LLM encoding. Anima model is also needed for preprocessing"""

    # load text encoder: conds_cache holds cached encodings for prompts without padding
    conds_cache = {}
    text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device
    if shared_models is not None:
        text_encoder = shared_models.get("text_encoder")

        if "conds_cache" in shared_models:  # Use shared cache if available
            conds_cache = shared_models["conds_cache"]

        # text_encoder is on device (batched inference) or CPU (interactive inference)
    else:  # Load if not in shared_models
        text_encoder_dtype = torch.bfloat16  # Default dtype for Text Encoder
        text_encoder = load_text_encoder(args, dtype=text_encoder_dtype, device=text_encoder_device)
        text_encoder.eval()
        tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
        # Store references so load_target_model can reuse them

    # Store original devices to move back later if they were shared. This does nothing if shared_models is None
    text_encoder_original_device = text_encoder.device if text_encoder else None

    # Ensure text_encoder is not None before proceeding
    if not text_encoder:
        raise ValueError("Text encoder is not loaded properly.")

    # Define a function to move models to device if needed
    # This is to avoid moving models if not needed, especially in interactive mode
    model_is_moved = False

    def move_models_to_device_if_needed():
        nonlocal model_is_moved
        nonlocal shared_models

        if model_is_moved:
            return
        model_is_moved = True

        logger.info(f"Moving Text Encoder to appropriate device: {text_encoder_device}")
        text_encoder.to(text_encoder_device)  # If text_encoder_cpu is True, this will be CPU

    logger.info("Encoding prompt with Text Encoder")

    prompt = process_escape(args.prompt)
    cache_key = prompt
    if cache_key in conds_cache:
        embed = conds_cache[cache_key]
    else:
        move_models_to_device_if_needed()

        tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
        encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

        with torch.no_grad():
            # embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, prompt)
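
The listing checks `conds_cache` before touching the encoder and moves the encoder to its device only on a cache miss, with a `nonlocal` flag ensuring the move happens at most once per call. A minimal, dependency-free sketch of that pattern follows. `FakeEncoder` and `encode_with_cache` are hypothetical names used only for illustration, the "embedding" is a toy list of code points, and writing the result back into the cache is an assumption, since the listing above is cut off before that point.

```python
from typing import Dict, List


class FakeEncoder:
    """Hypothetical stand-in that records device moves and encode calls."""

    def __init__(self) -> None:
        self.device = "cpu"
        self.moves: List[str] = []
        self.encode_calls = 0

    def to(self, device: str) -> "FakeEncoder":
        self.moves.append(device)
        self.device = device
        return self

    def encode(self, prompt: str) -> List[int]:
        self.encode_calls += 1
        return [ord(ch) for ch in prompt]  # toy "embedding", not a real one


def encode_with_cache(
    prompt: str,
    encoder: FakeEncoder,
    conds_cache: Dict[str, List[int]],
    target_device: str,
) -> List[int]:
    model_is_moved = False

    def move_models_to_device_if_needed() -> None:
        nonlocal model_is_moved
        if model_is_moved:
            return  # already moved during this call
        model_is_moved = True
        encoder.to(target_device)

    # Cache hit: skip both the device move and the encoding.
    if prompt in conds_cache:
        return conds_cache[prompt]

    # Cache miss: move lazily, encode, then remember the result.
    move_models_to_device_if_needed()
    embed = encoder.encode(prompt)
    conds_cache[prompt] = embed  # assumption: not visible in the truncated listing
    return embed
```

Calling `encode_with_cache` twice with the same prompt and the same cache dict moves and runs the fake encoder only on the first call; the second call returns the cached object. This is the behavior that makes sharing `conds_cache` through `shared_models` worthwhile in an interactive session with repeated prompts.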

Callers 2

generate · Function · 0.70
process_batch_prompts · Function · 0.70

Calls 11

clean_memory_on_device · Function · 0.90
load_text_encoder · Function · 0.85
process_escape · Function · 0.85
get · Method · 0.80
to · Method · 0.80
device · Method · 0.45
get_strategy · Method · 0.45
tokenize · Method · 0.45
encode_tokens · Method · 0.45

Tested by

no test coverage detected