Save topic embeddings, either safely (using safetensors) or using legacy PyTorch serialization.
(model, save_directory, serialization: str)
| 316 | |
| 317 | |
def save_hf(model, save_directory, serialization: str):
    """Save topic embeddings, either safely (using safetensors) or using legacy pytorch.

    The embeddings are converted to a float32 NumPy array and written under the
    key ``"topic_embeddings"`` in either format.

    Args:
        model: Object exposing ``topic_embeddings_``, which must be convertible
            to a float32 NumPy array (presumably a fitted topic model — confirm
            against callers).
        save_directory: Directory to write into. ``str`` and ``pathlib.Path``
            are both accepted; the value is normalized to ``Path`` before the
            file name is joined onto it.
        serialization: ``"safetensors"`` writes ``HF_SAFE_WEIGHTS_NAME`` via
            ``save_safetensors``; ``"pytorch"`` writes ``HF_WEIGHTS_NAME`` via
            ``torch.save``.

    Raises:
        ValueError: If ``serialization`` is not ``"safetensors"`` or ``"pytorch"``
            (previously such values silently wrote nothing).
        ImportError: If ``serialization == "pytorch"`` but torch is not available.
    """
    from pathlib import Path

    # Validate up front so an unsupported format fails loudly instead of
    # silently producing no file.
    if serialization not in ("safetensors", "pytorch"):
        raise ValueError(
            f"Unsupported serialization {serialization!r}; "
            "expected 'safetensors' or 'pytorch'."
        )

    save_directory = Path(save_directory)
    embeddings = np.array(model.topic_embeddings_, dtype=np.float32)

    if serialization == "safetensors":
        save_safetensors(
            save_directory / HF_SAFE_WEIGHTS_NAME,
            {"topic_embeddings": embeddings},
        )
    else:
        # An ``assert`` is removed under ``python -O``, which would turn this
        # into a confusing NameError on ``torch``; raise explicitly instead.
        if not _has_torch:
            raise ImportError("`pip install torch` to save as bin")
        torch.save(
            {"topic_embeddings": torch.from_numpy(embeddings)},
            save_directory / HF_WEIGHTS_NAME,
        )
| 329 | |
| 330 | |
| 331 | def save_ctfidf(model, save_directory: str, serialization: str): |
nothing calls this directly
no test coverage detected