MCPcopy
hub / github.com/FoundationVision/LlamaGen / T5Embedder

Class T5Embedder

language/t5.py:15–201  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

13
14
15class T5Embedder:
# Checkpoint names that __init__ downloads from the matching
# 'DeepFloyd/<name>' Hugging Face repository.
available_models = ['t5-v1_1-xxl', 't5-v1_1-xl', 'flan-t5-xl']

# One or more consecutive "bad" punctuation characters: assorted symbols,
# brackets/braces/parentheses, pipe, backslash, slash and asterisk.
# Written as a single raw string so that no escape is interpreted by the
# Python string parser; the compiled pattern is unchanged.
bad_punct_regex = re.compile(r'[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]{1,}')
18
def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None,
             use_text_preprocessing=True, t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None,
             model_max_length=120):
    """Load the tokenizer and the T5 encoder used for text embedding.

    Args:
        device: Anything accepted by ``torch.device``.
        dir_or_name: Checkpoint name or path. Names listed in
            ``available_models`` are downloaded from ``DeepFloyd/<name>``.
        local_cache: When true, tokenizer and weights are both read from
            ``<cache_dir>/<dir_or_name>`` and ``hf_hub_download`` is not
            called.
        cache_dir: Root directory for cached files; falls back to
            ``~/.cache/IF_`` (user-expanded) when falsy.
        hf_token: Passed as ``token`` to ``hf_hub_download``.
        use_text_preprocessing: Stored on the instance.
        t5_model_kwargs: Keyword arguments forwarded to
            ``T5EncoderModel.from_pretrained``. When ``None``, defaults are
            built that place ``shared`` and ``encoder`` on ``device``.
        torch_dtype: Model dtype; ``torch.bfloat16`` when falsy.
        use_offload_folder: Accepted but not used by this constructor.
        model_max_length: Stored on the instance.
    """
    self.device = torch.device(device)
    self.torch_dtype = torch_dtype or torch.bfloat16

    if t5_model_kwargs is None:
        t5_model_kwargs = {
            'low_cpu_mem_usage': True,
            'torch_dtype': self.torch_dtype,
            'device_map': {'shared': self.device, 'encoder': self.device},
        }

    self.use_text_preprocessing = use_text_preprocessing
    self.hf_token = hf_token
    self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
    self.dir_or_name = dir_or_name

    tokenizer_files = ['config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json']
    weight_files = [
        'pytorch_model.bin.index.json',
        'pytorch_model-00001-of-00002.bin',
        'pytorch_model-00002-of-00002.bin',
    ]

    if local_cache:
        # Everything is expected to already sit under the cache root.
        local_root = os.path.join(self.cache_dir, dir_or_name)
        tokenizer_source = weights_source = local_root
    elif dir_or_name in self.available_models:
        # Known checkpoint: fetch tokenizer files and sharded weights.
        local_root = os.path.join(self.cache_dir, dir_or_name)
        for wanted in tokenizer_files + weight_files:
            hf_hub_download(
                repo_id=f'DeepFloyd/{dir_or_name}',
                filename=wanted,
                cache_dir=local_root,
                force_filename=wanted,
                token=self.hf_token,
            )
        tokenizer_source = weights_source = local_root
    else:
        # Any other value: the weights are loaded from dir_or_name itself,
        # while the tokenizer files come from the t5-v1_1-xxl repository.
        local_root = os.path.join(self.cache_dir, 't5-v1_1-xxl')
        for wanted in tokenizer_files:
            hf_hub_download(
                repo_id='DeepFloyd/t5-v1_1-xxl',
                filename=wanted,
                cache_dir=local_root,
                force_filename=wanted,
                token=self.hf_token,
            )
        tokenizer_source, weights_source = local_root, dir_or_name

    print(tokenizer_source)
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    self.model = T5EncoderModel.from_pretrained(weights_source, **t5_model_kwargs).eval()
    self.model_max_length = model_max_length
57
58 def get_text_embeddings(self, texts):
59 texts = [self.text_preprocessing(text) for text in texts]
60
61 text_tokens_and_mask = self.tokenizer(
62 texts,
63 max_length=self.model_max_length,
64 padding='max_length',
65 truncation=True,
66 return_attention_mask=True,
67 add_special_tokens=True,
68 return_tensors='pt'
69 )
70
71 text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']
72 text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']

Callers 3

main · Function · 0.90
main · Function · 0.90
main · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected