| 470 | |
| 471 | |
class T5EncoderModel:
    """Frozen text encoder (built by ``umt5_xxl``) bundled with its tokenizer.

    Construction builds the encoder in eval mode with gradients disabled,
    loads its weights from ``checkpoint_path``, optionally hands the model to
    ``shard_fn``, and creates a ``HuggingfaceTokenizer``. Calling the instance
    tokenizes a batch of texts and returns one encoder output per text,
    trimmed to that text's unmasked length.
    """

    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=None,
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
    ):
        """Build the encoder, load its checkpoint and create the tokenizer.

        Args:
            text_len: Passed to the tokenizer as ``seq_len``.
            dtype: Dtype forwarded to ``umt5_xxl``.
            device: Device forwarded to ``umt5_xxl`` and used for the final
                ``.to(...)`` when no ``shard_fn`` is given. ``None`` (the
                default) means ``torch.cuda.current_device()`` evaluated at
                construction time.
            checkpoint_path: Path handed to ``torch.load``. Must be provided
                in practice; the ``None`` default is passed straight through.
            tokenizer_path: Passed to ``HuggingfaceTokenizer`` as ``name``.
            shard_fn: Optional callable invoked as
                ``shard_fn(model, sync_module_states=False)``; its return
                value replaces the model. When given, the model is not moved
                with ``.to(device)`` here.
        """
        # Resolve the default device lazily. A default written in the
        # signature is evaluated once, when the class body executes at import
        # time: that would initialize CUDA on import (failing on hosts without
        # CUDA) and ignore any later torch.cuda.set_device() call.
        if device is None:
            device = torch.cuda.current_device()

        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        model = umt5_xxl(
            encoder_only=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device).eval().requires_grad_(False)
        logging.info(f'loading {checkpoint_path}')
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (or consider weights_only=True if the
        # supported torch versions and checkpoint format allow it).
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)

        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path, seq_len=text_len, clean='whitespace')

    def __call__(self, texts, device):
        """Encode ``texts`` and return one trimmed output tensor per text.

        Args:
            texts: Input forwarded unchanged to the tokenizer.
            device: Device the token ids and mask are moved to before the
                forward pass. Note this is the argument, not ``self.device``.

        Returns:
            A list with one entry per row of the model output; each entry is
            sliced along its first dimension to the number of positive mask
            entries in the matching mask row.
        """
        ids, mask = self.tokenizer(
            texts, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        # Per-row count of unmasked positions (mask > 0), summed over dim 1.
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)
        # Drop the trailing positions of each row that the mask marks as
        # padding, so every returned tensor has its own length.
        return [u[:v] for u, v in zip(context, seq_lens)]