Common enhancement logic for both T2V and I2V modes.
(
processor: Gemma3Processor,
model: Gemma3ForConditionalGeneration,
messages: list,
image: Optional[Image.Image] = None,
max_new_tokens: int = 512,
seed: int = 42,
)
| 400 | |
| 401 | |
def _enhance(
    processor: Gemma3Processor,
    model: Gemma3ForConditionalGeneration,
    messages: list,
    image: Optional[Image.Image] = None,
    max_new_tokens: int = 512,
    seed: int = 42,
) -> str:
    """Generate an enhanced prompt from chat messages (shared by T2V and I2V).

    Renders ``messages`` with the tokenizer's chat template, builds model
    inputs from that text (plus ``image`` when given), pads them via
    ``_pad_inputs_for_attention_alignment`` and samples a completion with
    ``model.generate`` (``do_sample=True``, ``temperature=0.7``).

    Args:
        processor: Processor exposing ``tokenizer`` and callable on
            ``text``/``images`` to build the model inputs.
        model: Generation model; its ``device`` and ``dtype`` decide where
            the inputs are moved and how autocast is configured.
        messages: Chat messages handed to ``apply_chat_template``.
        image: Optional image forwarded to the processor; ``None`` for
            text-only enhancement.
        max_new_tokens: Upper bound on the number of generated tokens.
        seed: Seed set with ``torch.manual_seed`` inside a forked RNG scope,
            so the caller's RNG state is restored on exit.

    Returns:
        The decoded completion only (prompt tokens are sliced off), with
        special tokens skipped.

    Raises:
        ValueError: If ``processor`` is ``None``.
    """
    if processor is None:
        raise ValueError("Processor not loaded - enhancement not available")

    tokenizer = processor.tokenizer
    device = model.device

    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    model_inputs = processor(
        text=text,
        images=image,
        return_tensors="pt",
    ).to(device)

    # Fall back to token id 0 when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = 0
    model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id)

    # fork_rng always saves/restores the CPU RNG; ``devices`` is interpreted
    # as CUDA devices (its default device_type). Passing a CPU device makes it
    # query torch.cuda, which fails on hosts without CUDA, so only list the
    # model's device when it actually is a CUDA device.
    fork_devices = [device] if device.type == "cuda" else []

    with (
        torch.inference_mode(),
        torch.random.fork_rng(devices=fork_devices),
        torch.autocast(device_type=device.type, dtype=model.dtype),
    ):
        # Seed inside the forked scope so sampling is reproducible without
        # disturbing the caller's global RNG state.
        torch.manual_seed(seed)
        outputs = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
        )
        # generate() returns prompt + completion; keep only the new tokens.
        prompt_length = len(model_inputs.input_ids[0])
        generated_ids = outputs[0][prompt_length:]
        enhanced_prompt = tokenizer.decode(
            generated_ids, skip_special_tokens=True
        )

    return enhanced_prompt
| 449 | |
| 450 | |
| 451 | def enhance_t2v( |
no test coverage detected