| 475 | |
| 476 | |
| 477 | class T5EncoderModel: |
| 478 | |
def __init__(
    self,
    text_len,
    dtype=torch.bfloat16,
    device=None,
    checkpoint_path=None,
    tokenizer_path=None,
    shard_fn=None,
    quant=None,
    quant_dir=None,
):
    """Build the UMT5-XXL text encoder and its tokenizer.

    Args:
        text_len: Forwarded to the tokenizer as ``seq_len``.
        dtype: Forwarded to ``umt5_xxl`` as the model dtype.
        device: Device the un-quantized model is created on and, when
            ``shard_fn`` is not given, the device the model is moved to.
            ``None`` (the default) resolves to
            ``torch.cuda.current_device()`` when the constructor runs,
            not when the module is imported.
        checkpoint_path: State-dict file read with ``torch.load`` when
            ``quant`` is ``None``. Not read on the quantized path.
        tokenizer_path: Forwarded to ``HuggingfaceTokenizer`` as ``name``.
        shard_fn: Optional callable invoked as
            ``shard_fn(model, sync_module_states=False)``; its return
            value replaces the model. When omitted the model is moved
            to ``device`` instead.
        quant: ``None`` for the regular checkpoint, or ``"int8"`` /
            ``"fp8"`` to load pre-quantized weights from ``quant_dir``.
        quant_dir: Directory holding ``t5_{quant}.safetensors`` and
            ``t5_map_{quant}.json``. Required when ``quant`` is set.

    Raises:
        ValueError: If ``quant`` is not one of the supported values, if
            ``quant`` is set without ``quant_dir``, or if ``quant`` is
            ``None`` and no ``checkpoint_path`` is given.
    """
    # Validate with real exceptions: `assert` disappears under `python -O`.
    if quant is not None and quant not in ("int8", "fp8"):
        raise ValueError(
            f"quant must be None, 'int8' or 'fp8', got {quant!r}")
    if quant is not None and quant_dir is None:
        raise ValueError("quant_dir is required when quant is set")
    if quant is None and checkpoint_path is None:
        raise ValueError("checkpoint_path is required when quant is None")

    # Resolve the default lazily. As a default-argument expression this
    # call ran once at definition time, which touched CUDA on import and
    # pinned whichever device happened to be current at that moment.
    if device is None:
        device = torch.cuda.current_device()

    self.text_len = text_len
    self.dtype = dtype
    self.device = device
    self.checkpoint_path = checkpoint_path
    self.tokenizer_path = tokenizer_path

    # init model
    if quant is not None:
        # Build the module skeleton on the meta device (no parameter
        # storage is allocated there), then hand it to `requantize`
        # together with the saved tensors and quantization map.
        # NOTE(review): `requantize` is presumably optimum-quanto's helper
        # that materialises the weights on `device='cpu'` -- confirm.
        with torch.device('meta'):
            model = umt5_xxl(
                encoder_only=True,
                return_tokenizer=False,
                dtype=dtype,
                device=torch.device('meta'))
        weights_path = os.path.join(quant_dir, f"t5_{quant}.safetensors")
        map_path = os.path.join(quant_dir, f"t5_map_{quant}.json")
        logging.info(f'Loading quantized T5 from {weights_path}')
        model_state_dict = load_file(weights_path)
        with open(map_path, "r") as f:
            quantization_map = json.load(f)
        requantize(model, model_state_dict, quantization_map, device='cpu')
    else:
        logging.info(f'loading {checkpoint_path}')
        model = umt5_xxl(
            encoder_only=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device)
        # SECURITY: torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider `weights_only=True` if the
        # minimum supported PyTorch version allows it).
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

    # Inference only: eval mode and no gradients, for both load paths.
    self.model = model
    self.model.eval().requires_grad_(False)
    if shard_fn is not None:
        self.model = shard_fn(self.model, sync_module_states=False)
    else:
        self.model.to(self.device)

    # init tokenizer
    self.tokenizer = HuggingfaceTokenizer(
        name=tokenizer_path, seq_len=text_len, clean='whitespace')
| 527 | |
| 528 | def __call__(self, texts, device): |
| 529 | ids, mask = self.tokenizer( |
| 530 | texts, return_mask=True, add_special_tokens=True) |
| 531 | ids = ids.to(device) |
| 532 | mask = mask.to(device) |
| 533 | seq_lens = mask.gt(0).sum(dim=1).long() |
| 534 | context = self.model(ids, mask) |