| 149 | return json.loads(x[:-1].decode()) |
| 150 | |
def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
    """Run ``self.model`` on one chunk and pool its final hidden states.

    Args:
        input_ids: Token ids, passed to ``self.model`` as its first
            positional argument (and also as ``decoder_input_ids`` on the
            "is_t5" path).
        attention_mask: Mask with one entry per token. It is broadcast over
            the hidden dimension to zero out masked positions, and summed
            to produce the token count.
            NOTE(review): presumably 1 = real token, 0 = padding -- confirm.
        **model_type_dict: Truthy flags choosing how the hidden states are
            read from the model output: ``is_bert`` (optionally with
            ``is_robert``), ``is_t5``, ``is_chatglm``. With no flag set the
            last entry of ``hidden_states`` is used.

    Returns:
        A ``(sum_embeddings, token_num)`` tuple. ``sum_embeddings`` is
        ``data[:, 0]`` when ``self.model.use_cls_pooling`` is truthy,
        otherwise the mask-weighted sum of ``data`` over ``dim=1``.
        ``token_num`` is the Python scalar ``attention_mask.sum()`` taken
        over the *whole* mask, not per row.
    """
    if model_type_dict.get("is_bert"):
        model_output = self.model(input_ids)
        if model_type_dict.get("is_robert"):
            data = model_output.last_hidden_state
        else:
            data = model_output[0]
    elif model_type_dict.get("is_t5"):
        # The input ids are reused as decoder input; only the encoder's
        # final hidden state is read from the output.
        model_output = self.model(input_ids, decoder_input_ids=input_ids)
        data = model_output.encoder_last_hidden_state
    else:
        model_output = self.model(input_ids, output_hidden_states=True)
        data = model_output.hidden_states[-1]
        if model_type_dict.get("is_chatglm"):
            # Swap the first two axes so the pooling below sees the same
            # layout as on the other paths.
            # NOTE(review): presumably ChatGLM emits (seq, batch, hidden)
            # -- confirm.
            data = data.transpose(0, 1)

    # A model that lacks the attribute is treated as "no CLS pooling".
    if getattr(self.model, "use_cls_pooling", False):
        # Take index 0 along dim 1 (assumed to be the first token's
        # hidden state -- TODO confirm layout).
        sum_embeddings = data[:, 0]
    else:
        mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
        sum_embeddings = torch.sum(data * mask, dim=1)

    # Computed on both pooling paths: the return statement always needs it.
    token_num = torch.sum(attention_mask).item()

    return sum_embeddings, token_num
| 177 | |
| 178 | def __encode_base64(self, embeddings: torch.Tensor) -> List[str]: |
| 179 | embeddings = embeddings.cpu() |