MCPcopy
hub / github.com/deep-floyd/IF / get_text_embeddings

Method get_text_embeddings

deepfloyd_if/modules/t5.py:79–100  ·  view source on GitHub ↗
(self, texts)

Source from the content-addressed store, hash-verified

77 self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
78
def get_text_embeddings(self, texts, max_length=77):
    """Encode text prompts into T5 encoder hidden states.

    Each prompt is passed through ``self.text_preprocessing``, tokenized with
    ``self.tokenizer`` (padded and truncated to exactly ``max_length`` tokens,
    special tokens added), and run through ``self.model`` with gradient
    tracking disabled.

    Args:
        texts: A list (or other iterable) of prompt strings. A single ``str``
            is also accepted and treated as a one-element batch; previously
            it would have been iterated character by character.
        max_length: Number of tokens every prompt is padded/truncated to.
            Defaults to 77, the value this method always used.

    Returns:
        The ``'last_hidden_state'`` entry of the model output, computed from
        inputs moved to ``self.device``. Presumably shaped
        (num_prompts, max_length, hidden_size) for a T5 encoder — TODO confirm
        against the loaded model.
    """
    # A bare string is itself iterable; without this guard every character
    # would be preprocessed and encoded as a separate prompt.
    if isinstance(texts, str):
        texts = [texts]
    texts = [self.text_preprocessing(text) for text in texts]

    tokens = self.tokenizer(
        texts,
        max_length=max_length,
        padding='max_length',  # fixed-length batch regardless of prompt length
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt',
    )

    # Inference only: no_grad skips building an autograd graph, so the result
    # carries no gradient history and needs no explicit detach().
    with torch.no_grad():
        text_encoder_embs = self.model(
            input_ids=tokens['input_ids'].to(self.device),
            attention_mask=tokens['attention_mask'].to(self.device),
        )['last_hidden_state']

    return text_encoder_embs
101
102 def text_preprocessing(self, text):
103 if self.use_text_preprocessing:

Callers 4

style_transfer — Function · 0.80
super_resolution — Function · 0.80
inpainting — Function · 0.80
dream — Function · 0.80

Calls 1

text_preprocessing — Method · 0.95

Tested by

no test coverage detected