| 13 | |
| 14 | |
| 15 | class T5Embedder: |
# Checkpoint names the constructor recognises; for these it downloads the
# tokenizer and weight files from the matching `DeepFloyd/<name>` hub repo.
available_models = ['t5-v1_1-xxl', 't5-v1_1-xl', 'flan-t5-xl']
# Matches a run of one or more of the listed symbols, brackets, slashes,
# pipe or asterisk. Written as a single raw string so that no piece relies on
# invalid string escapes such as '\)' or '\/' (a SyntaxWarning on Python 3.12+);
# the compiled pattern text is unchanged.
bad_punct_regex = re.compile(r'[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]{1,}')
| 18 | |
def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None,
             hf_token=None, use_text_preprocessing=True, t5_model_kwargs=None, torch_dtype=None,
             use_offload_folder=None, model_max_length=120):
    """Resolve where the T5 files live, then load the tokenizer and encoder.

    Args:
        device: Anything accepted by ``torch.device``; stored as ``self.device``.
        dir_or_name: One of ``available_models`` or a path/identifier handed
            directly to ``T5EncoderModel.from_pretrained``.
        local_cache: When true, nothing is downloaded; both tokenizer and
            model are read from ``<cache_dir>/<dir_or_name>``.
        cache_dir: Root directory for downloaded files; defaults to
            ``~/.cache/IF_`` (user-expanded).
        hf_token: Token forwarded to ``hf_hub_download``.
        use_text_preprocessing: Stored on the instance as-is.
        t5_model_kwargs: Keyword arguments for ``T5EncoderModel.from_pretrained``.
            When ``None``, defaults are built from ``torch_dtype`` and ``device``.
        torch_dtype: Model dtype; falls back to ``torch.bfloat16``.
        use_offload_folder: Accepted for interface compatibility; not read here.
        model_max_length: Stored as ``self.model_max_length``.
    """
    self.device = torch.device(device)
    self.torch_dtype = torch_dtype if torch_dtype else torch.bfloat16
    # NOTE(review): device_map is treated as part of the default kwargs only,
    # so a caller-supplied dict is passed through untouched -- confirm against
    # the original indentation.
    if t5_model_kwargs is None:
        t5_model_kwargs = {
            'low_cpu_mem_usage': True,
            'torch_dtype': self.torch_dtype,
            'device_map': {'shared': self.device, 'encoder': self.device},
        }

    self.dir_or_name = dir_or_name
    self.hf_token = hf_token
    self.use_text_preprocessing = use_text_preprocessing
    self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')

    tokenizer_files = ('config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json')
    weight_files = (
        'pytorch_model.bin.index.json',
        'pytorch_model-00001-of-00002.bin',
        'pytorch_model-00002-of-00002.bin',
    )

    # Until a branch below says otherwise, the model is loaded from
    # `dir_or_name` itself.
    model_source = dir_or_name
    if local_cache:
        # Files are expected to be on disk already; no hub access.
        tokenizer_source = model_source = os.path.join(self.cache_dir, dir_or_name)
    elif dir_or_name in self.available_models:
        # Known checkpoint: fetch tokenizer files and sharded weights.
        local_dir = os.path.join(self.cache_dir, dir_or_name)
        for name in tokenizer_files + weight_files:
            hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=name, cache_dir=local_dir,
                            force_filename=name, token=self.hf_token)
        tokenizer_source = model_source = local_dir
    else:
        # Unknown name/path: only the t5-v1_1-xxl tokenizer files are fetched;
        # the weights still come from `dir_or_name`.
        local_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
        for name in tokenizer_files:
            hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=name, cache_dir=local_dir,
                            force_filename=name, token=self.hf_token)
        tokenizer_source = local_dir

    print(tokenizer_source)
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    encoder = T5EncoderModel.from_pretrained(model_source, **t5_model_kwargs)
    self.model = encoder.eval()
    self.model_max_length = model_max_length
| 57 | |
| 58 | def get_text_embeddings(self, texts): |
| 59 | texts = [self.text_preprocessing(text) for text in texts] |
| 60 | |
| 61 | text_tokens_and_mask = self.tokenizer( |
| 62 | texts, |
| 63 | max_length=self.model_max_length, |
| 64 | padding='max_length', |
| 65 | truncation=True, |
| 66 | return_attention_mask=True, |
| 67 | add_special_tokens=True, |
| 68 | return_tensors='pt' |
| 69 | ) |
| 70 | |
| 71 | text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids'] |
| 72 | text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask'] |