(self, params)
| 183 | |
| 184 | @torch.inference_mode() |
| 185 | def get_embeddings(self, params): |
| 186 | self.call_ct += 1 |
| 187 | |
| 188 | try: |
| 189 | tokenizer = self.tokenizer |
| 190 | ret = {"embedding": [], "token_num": 0} |
| 191 | |
| 192 | model_type_dict = { |
| 193 | "is_llama": "llama" in str(type(self.model)), |
| 194 | "is_t5": "t5" in str(type(self.model)), |
| 195 | "is_chatglm": "chatglm" in str(type(self.model)), |
| 196 | "is_bert": "bert" in str(type(self.model)), |
| 197 | "is_robert": "robert" in str(type(self.model)), |
| 198 | } |
| 199 | |
| 200 | if self.embed_in_truncate: |
| 201 | encoding = tokenizer.batch_encode_plus( |
| 202 | params["input"], |
| 203 | padding=True, |
| 204 | truncation="longest_first", |
| 205 | return_tensors="pt", |
| 206 | max_length=self.context_len, |
| 207 | ) |
| 208 | else: |
| 209 | encoding = tokenizer.batch_encode_plus( |
| 210 | params["input"], padding=True, return_tensors="pt" |
| 211 | ) |
| 212 | input_ids = encoding["input_ids"].to(self.device) |
| 213 | attention_mask = input_ids != tokenizer.pad_token_id |
| 214 | |
| 215 | base64_encode = params.get("encoding_format", None) |
| 216 | |
| 217 | if self.embed_in_truncate: |
| 218 | embedding, token_num = self.__process_embed_chunk( |
| 219 | input_ids, attention_mask, **model_type_dict |
| 220 | ) |
| 221 | if ( |
| 222 | not hasattr(self.model, "use_cls_pooling") |
| 223 | or not self.model.use_cls_pooling |
| 224 | ): |
| 225 | embedding = embedding / token_num |
| 226 | normalized_embeddings = F.normalize(embedding, p=2, dim=1) |
| 227 | ret["token_num"] = token_num |
| 228 | else: |
| 229 | all_embeddings = [] |
| 230 | all_token_num = 0 |
| 231 | for i in range(0, input_ids.size(1), self.context_len): |
| 232 | chunk_input_ids = input_ids[:, i : i + self.context_len] |
| 233 | chunk_attention_mask = attention_mask[:, i : i + self.context_len] |
| 234 | |
| 235 | # add cls token and mask to get cls embedding |
| 236 | if ( |
| 237 | hasattr(self.model, "use_cls_pooling") |
| 238 | and self.model.use_cls_pooling |
| 239 | ): |
| 240 | cls_tokens = ( |
| 241 | torch.zeros( |
| 242 | (chunk_input_ids.size(0), 1), |
no test coverage detected