MCPcopy
hub / github.com/MeiGen-AI/InfiniteTalk / T5EncoderModel

Class T5EncoderModel

wan/modules/t5.py:477–535  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

475
476
477class T5EncoderModel:
478
479 def __init__(
480 self,
481 text_len,
482 dtype=torch.bfloat16,
483 device=torch.cuda.current_device(),
484 checkpoint_path=None,
485 tokenizer_path=None,
486 shard_fn=None,
487 quant=None,
488 quant_dir=None
489 ):
490 assert quant is None or quant in ("int8", "fp8")
491 self.text_len = text_len
492 self.dtype = dtype
493 self.device = device
494 self.checkpoint_path = checkpoint_path
495 self.tokenizer_path = tokenizer_path
496
497 # init model
498 logging.info(f'loading {checkpoint_path}')
499 if quant is not None:
500 with torch.device('meta'):
501 model = umt5_xxl(
502 encoder_only=True,
503 return_tokenizer=False,
504 dtype=dtype,
505 device=torch.device('meta'))
506 logging.info(f'Loading quantized T5 from {os.path.join(quant_dir, f"t5_{quant}.safetensors")}')
507 model_state_dict = load_file(os.path.join(quant_dir, f"t5_{quant}.safetensors"))
508 with open(os.path.join(quant_dir, f"t5_map_{quant}.json"), "r") as f:
509 quantization_map = json.load(f)
510 requantize(model, model_state_dict, quantization_map, device='cpu')
511 else:
512 model = umt5_xxl(
513 encoder_only=True,
514 return_tokenizer=False,
515 dtype=dtype,
516 device=device).eval().requires_grad_(False)
517 model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
518 self.model = model
519 self.model.eval().requires_grad_(False)
520 if shard_fn is not None:
521 self.model = shard_fn(self.model, sync_module_states=False)
522 else:
523 self.model.to(self.device)
524 # init tokenizer
525 self.tokenizer = HuggingfaceTokenizer(
526 name=tokenizer_path, seq_len=text_len, clean='whitespace')
527
528 def __call__(self, texts, device):
529 ids, mask = self.tokenizer(
530 texts, return_mask=True, add_special_tokens=True)
531 ids = ids.to(device)
532 mask = mask.to(device)
533 seq_lens = mask.gt(0).sum(dim=1).long()
534 context = self.model(ids, mask)

Callers 6

__init__Method · 0.85
__init__Method · 0.85
__init__Method · 0.85
__init__Method · 0.85
mp_workerMethod · 0.85
__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected