MCPcopy Index your code
hub / github.com/Wan-Video/Wan2.2 / T5EncoderModel

Class T5EncoderModel

wan/modules/t5.py:472–513  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

470
471
class T5EncoderModel:
    """Callable text encoder wrapping an encoder-only ``umt5_xxl`` model and
    a ``HuggingfaceTokenizer``.

    ``encoder(texts, device)`` tokenizes ``texts``, runs the encoder and
    returns one tensor per input row, truncated to that row's mask length.
    """

    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=None,
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
    ):
        """Build the encoder, load its weights and create the tokenizer.

        Args:
            text_len: Forwarded to the tokenizer as ``seq_len``.
            dtype: Forwarded to ``umt5_xxl`` as ``dtype``.
            device: Device the model is built on and, when ``shard_fn`` is
                None, moved to. ``None`` (the default) means
                ``torch.cuda.current_device()``, resolved when the instance
                is constructed.
            checkpoint_path: Path of the state dict loaded into the model
                (loaded onto CPU first). Required.
            tokenizer_path: Forwarded to ``HuggingfaceTokenizer`` as ``name``.
            shard_fn: Optional callable invoked as
                ``shard_fn(model, sync_module_states=False)``; its return
                value replaces the model, and the model is then NOT moved
                to ``device`` by this constructor.

        Raises:
            ValueError: If ``checkpoint_path`` is None.
        """
        if checkpoint_path is None:
            # Fail fast instead of building the (very large) model and only
            # then failing inside ``torch.load(None)``.
            raise ValueError('checkpoint_path is required to load T5 weights')
        if device is None:
            # Resolved here rather than in the signature: a signature default
            # of ``torch.cuda.current_device()`` is evaluated once at import
            # time, which crashes on CUDA-less hosts and ignores any later
            # ``torch.cuda.set_device`` call.
            device = torch.cuda.current_device()

        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model: inference only, so eval mode with gradients disabled.
        model = umt5_xxl(
            encoder_only=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device).eval().requires_grad_(False)
        logging.info('loading %s', checkpoint_path)
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints. Consider ``weights_only=True`` where the installed
        # torch version supports it.
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)

        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path, seq_len=text_len, clean='whitespace')

    def __call__(self, texts, device):
        """Encode ``texts`` and return per-row outputs with padding removed.

        Args:
            texts: Passed straight to the tokenizer (presumably a str or a
                list of str; confirm against ``HuggingfaceTokenizer``).
            device: Device the token ids and mask are moved to before the
                forward pass (independent of ``self.device``).

        Returns:
            A list with one tensor per row of the model output, each sliced
            along its first dimension to the number of positive entries in
            that row's mask.
        """
        ids, mask = self.tokenizer(
            texts, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        # Count of mask > 0 positions per row, i.e. presumably the number of
        # real (non-padding) tokens.
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)
        # NOTE(review): assumes context is (batch, seq, dim); confirm.
        return [u[:v] for u, v in zip(context, seq_lens)]

Callers 5

__init__Method · 0.85
__init__Method · 0.85
__init__Method · 0.85
__init__Method · 0.85
__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected