MCPcopy
hub / github.com/lm-sys/FastChat / get_embeddings

Method get_embeddings

fastchat/serve/model_worker.py:185–300  ·  view source on GitHub ↗
(self, params)

Source from the content-addressed store, hash-verified

183
184 @torch.inference_mode()
185 def get_embeddings(self, params):
186 self.call_ct += 1
187
188 try:
189 tokenizer = self.tokenizer
190 ret = {"embedding": [], "token_num": 0}
191
192 model_type_dict = {
193 "is_llama": "llama" in str(type(self.model)),
194 "is_t5": "t5" in str(type(self.model)),
195 "is_chatglm": "chatglm" in str(type(self.model)),
196 "is_bert": "bert" in str(type(self.model)),
197 "is_robert": "robert" in str(type(self.model)),
198 }
199
200 if self.embed_in_truncate:
201 encoding = tokenizer.batch_encode_plus(
202 params["input"],
203 padding=True,
204 truncation="longest_first",
205 return_tensors="pt",
206 max_length=self.context_len,
207 )
208 else:
209 encoding = tokenizer.batch_encode_plus(
210 params["input"], padding=True, return_tensors="pt"
211 )
212 input_ids = encoding["input_ids"].to(self.device)
213 attention_mask = input_ids != tokenizer.pad_token_id
214
215 base64_encode = params.get("encoding_format", None)
216
217 if self.embed_in_truncate:
218 embedding, token_num = self.__process_embed_chunk(
219 input_ids, attention_mask, **model_type_dict
220 )
221 if (
222 not hasattr(self.model, "use_cls_pooling")
223 or not self.model.use_cls_pooling
224 ):
225 embedding = embedding / token_num
226 normalized_embeddings = F.normalize(embedding, p=2, dim=1)
227 ret["token_num"] = token_num
228 else:
229 all_embeddings = []
230 all_token_num = 0
231 for i in range(0, input_ids.size(1), self.context_len):
232 chunk_input_ids = input_ids[:, i : i + self.context_len]
233 chunk_attention_mask = attention_mask[:, i : i + self.context_len]
234
235 # add cls token and mask to get cls embedding
236 if (
237 hasattr(self.model, "use_cls_pooling")
238 and self.model.use_cls_pooling
239 ):
240 cls_tokens = (
241 torch.zeros(
242 (chunk_input_ids.size(0), 1),

Callers 1

api_get_embeddings · Function · 0.45

Calls 3

__process_embed_chunk · Method · 0.95
__encode_base64 · Method · 0.95
to · Method · 0.80

Tested by

no test coverage detected