MCPcopy
hub / github.com/lm-sys/FastChat / __process_embed_chunk

Method __process_embed_chunk

fastchat/serve/model_worker.py:151–176  ·  view source on GitHub ↗
(self, input_ids, attention_mask, **model_type_dict)

Source retrieved from the content-addressed store (hash-verified)

149 return json.loads(x[:-1].decode())
150
def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
    """Run ``input_ids`` through ``self.model`` and pool its final hidden states.

    Args:
        input_ids: Token-id tensor passed unchanged as the first positional
            argument to ``self.model``.
        attention_mask: Tensor marking the tokens to keep when pooling and
            counting. It is NOT forwarded to ``self.model``. It is expanded to
            the hidden-state size after ``unsqueeze(-1)``, so it presumably has
            the (batch, seq) layout of the first two hidden-state axes --
            confirm against the caller.
        **model_type_dict: Boolean flags that pick how the model is called
            and where the last hidden layer is read from. Recognized keys are
            ``is_bert`` (optionally with ``is_robert``), ``is_t5`` and
            ``is_chatglm``. Any other model is called with
            ``output_hidden_states=True``.

    Returns:
        A ``(embeddings, token_num)`` tuple.

        ``embeddings`` is one of two things:

        * the hidden state at sequence position 0 for every row, when
          ``self.model`` has a truthy ``use_cls_pooling`` attribute;
        * otherwise, the attention-masked SUM (not mean) over axis 1. Any
          averaging is left to the caller.

        ``token_num`` is ``attention_mask`` summed over the whole chunk and
        converted to a Python scalar with ``.item()``. It is a single total,
        not a per-row count.
    """
    # NOTE(review): no branch passes attention_mask to self.model, so the
    # model itself sees the padding positions; the mask is applied only
    # during pooling below. Confirm this is intended.
    flags = model_type_dict

    if flags.get("is_bert"):
        bert_out = self.model(input_ids)
        # RoBERTa-flagged models read the named field; other BERT-style
        # models take the first element of the output.
        hidden = bert_out.last_hidden_state if flags.get("is_robert") else bert_out[0]
    elif flags.get("is_t5"):
        # The decoder is fed the same ids, presumably just so the forward pass
        # has decoder inputs. The embedding comes from the encoder side.
        seq2seq_out = self.model(input_ids, decoder_input_ids=input_ids)
        hidden = seq2seq_out.encoder_last_hidden_state
    else:
        final_layer = self.model(input_ids, output_hidden_states=True).hidden_states[-1]
        # Swap the first two axes so axis 0 lines up with the other branches.
        # Presumably ChatGLM reports hidden states sequence-first; confirm.
        hidden = final_layer.transpose(0, 1) if flags.get("is_chatglm") else final_layer

    if getattr(self.model, "use_cls_pooling", False):
        # CLS-style pooling: keep only the first position of each row.
        pooled = hidden[:, 0]
    else:
        keep = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
        pooled = (hidden * keep).sum(dim=1)

    return pooled, attention_mask.sum().item()
177
178 def __encode_base64(self, embeddings: torch.Tensor) -> List[str]:
179 embeddings = embeddings.cpu()

Callers 1

get_embeddings (method) · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected