MCPcopy
hub / github.com/Robbyant/lingbot-world / T5EncoderModel

Class T5EncoderModel

wan/modules/t5.py:470–511  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

468
469
class T5EncoderModel:
    """Encoder-only UMT5-XXL text encoder paired with its tokenizer.

    Construction builds the encoder via ``umt5_xxl``, loads weights from
    ``checkpoint_path`` and then either passes the model through ``shard_fn``
    or moves it to ``device``. Calling the instance tokenizes a batch of texts
    and returns one encoder output per text, cut to that text's unmasked
    length.
    """

    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=None,
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
    ):
        """Build the encoder, load its checkpoint and create the tokenizer.

        Args:
            text_len: Sequence length forwarded to the tokenizer as
                ``seq_len``.
            dtype: Dtype forwarded to ``umt5_xxl``.
            device: Device forwarded to ``umt5_xxl`` and used as the final
                placement when ``shard_fn`` is not given. ``None`` (the
                default) resolves to ``torch.cuda.current_device()`` at
                construction time.
            checkpoint_path: Path handed to ``torch.load`` for the state dict.
            tokenizer_path: Forwarded to ``HuggingfaceTokenizer`` as ``name``.
            shard_fn: Optional callable invoked as
                ``shard_fn(model, sync_module_states=False)``; its return
                value replaces the model. When omitted, the model is moved to
                ``device`` instead.
        """
        # Resolve the default lazily. A default of
        # ``torch.cuda.current_device()`` in the signature is evaluated once,
        # when the ``def`` statement runs, which touches CUDA at import time
        # (failing on hosts without a usable CUDA runtime) and pins whichever
        # device happened to be current at that moment.
        if device is None:
            device = torch.cuda.current_device()

        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        model = umt5_xxl(
            encoder_only=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device).eval().requires_grad_(False)
        logging.info('loading %s', checkpoint_path)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider ``weights_only=True`` where the
        # installed torch version supports it).
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(state_dict)
        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)

        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path, seq_len=text_len, clean='whitespace')

    def __call__(self, texts, device):
        """Encode ``texts`` and return one trimmed output per input.

        Args:
            texts: Batch of texts passed straight to the tokenizer.
            device: Device the token ids and mask are moved to before the
                model is invoked.

        Returns:
            A list with one entry per row of the model output, each sliced
            along its first dimension to the number of mask entries greater
            than zero for that row.
        """
        ids, mask = self.tokenizer(
            texts, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        # Per-row count of positive mask entries, i.e. the unpadded length.
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)
        return [u[:v] for u, v in zip(context, seq_lens)]

Callers 2

__init__ Method · 0.85
__init__ Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected