Prepare text-related inputs for T2I: LLM encoding. The Anima model is also needed for preprocessing.
(
args: argparse.Namespace, device: torch.device, anima: anima_models.Anima, shared_models: Optional[Dict] = None
)
| 356 | |
| 357 | |
| 358 | def prepare_text_inputs( |
| 359 | args: argparse.Namespace, device: torch.device, anima: anima_models.Anima, shared_models: Optional[Dict] = None |
| 360 | ) -> Tuple[Dict[str, Any], Dict[str, Any]]: |
| 361 | """Prepare text-related inputs for T2I: LLM encoding. Anima model is also needed for preprocessing""" |
| 362 | |
| 363 | # load text encoder: conds_cache holds cached encodings for prompts without padding |
| 364 | conds_cache = {} |
| 365 | text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device |
| 366 | if shared_models is not None: |
| 367 | text_encoder = shared_models.get("text_encoder") |
| 368 | |
| 369 | if "conds_cache" in shared_models: # Use shared cache if available |
| 370 | conds_cache = shared_models["conds_cache"] |
| 371 | |
| 372 | # text_encoder is on device (batched inference) or CPU (interactive inference) |
| 373 | else: # Load if not in shared_models |
| 374 | text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder |
| 375 | text_encoder = load_text_encoder(args, dtype=text_encoder_dtype, device=text_encoder_device) |
| 376 | text_encoder.eval() |
| 377 | tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() |
| 378 | # Store references so load_target_model can reuse them |
| 379 | |
| 380 | # Store original devices to move back later if they were shared. This does nothing if shared_models is None |
| 381 | text_encoder_original_device = text_encoder.device if text_encoder else None |
| 382 | |
| 383 | # Ensure text_encoder is not None before proceeding |
| 384 | if not text_encoder: |
| 385 | raise ValueError("Text encoder is not loaded properly.") |
| 386 | |
| 387 | # Define a function to move models to device if needed |
| 388 | # This is to avoid moving models if not needed, especially in interactive mode |
| 389 | model_is_moved = False |
| 390 | |
| 391 | def move_models_to_device_if_needed(): |
| 392 | nonlocal model_is_moved |
| 393 | nonlocal shared_models |
| 394 | |
| 395 | if model_is_moved: |
| 396 | return |
| 397 | model_is_moved = True |
| 398 | |
| 399 | logger.info(f"Moving Text Encoder to appropriate device: {text_encoder_device}") |
| 400 | text_encoder.to(text_encoder_device) # If text_encoder_cpu is True, this will be CPU |
| 401 | |
| 402 | logger.info("Encoding prompt with Text Encoder") |
| 403 | |
| 404 | prompt = process_escape(args.prompt) |
| 405 | cache_key = prompt |
| 406 | if cache_key in conds_cache: |
| 407 | embed = conds_cache[cache_key] |
| 408 | else: |
| 409 | move_models_to_device_if_needed() |
| 410 | |
| 411 | tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() |
| 412 | encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() |
| 413 | |
| 414 | with torch.no_grad(): |
| 415 | # embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, prompt) |
no test coverage detected