Save c-TF-IDF sparse matrix.
(model, save_directory: str, serialization: str)
| 329 | |
| 330 | |
def save_ctfidf(model, save_directory: str, serialization: str):
    """Save the c-TF-IDF sparse matrix and its IDF diagonal to disk.

    The CSR components of ``model.c_tf_idf_`` (``indptr``, ``indices``,
    ``data`` and ``shape``) plus the ``data`` of
    ``model.ctfidf_model._idf_diag`` are written as one flat mapping of
    named arrays/tensors.

    Arguments:
        model: Fitted topic model exposing ``c_tf_idf_`` (an object with
            CSR-style ``indptr``/``indices``/``data``/``shape`` attributes)
            and ``ctfidf_model._idf_diag`` (an object with a ``data``
            attribute).
        save_directory: Directory to write into. Either a ``str`` or a
            ``pathlib.Path``; it is converted to ``Path`` before joining.
        serialization: ``"safetensors"`` writes ``CTFIDF_SAFE_WEIGHTS_NAME``
            via ``save_safetensors``. ``"pytorch"`` writes
            ``CTFIDF_WEIGHTS_NAME`` via ``torch.save``.

    Raises:
        ValueError: If ``serialization`` is not ``"safetensors"`` or
            ``"pytorch"``.
        ImportError: If ``serialization == "pytorch"`` and torch is not
            available.
    """
    from pathlib import Path

    # Fail fast on an unknown/typo'd format instead of silently writing nothing.
    if serialization not in ("safetensors", "pytorch"):
        raise ValueError(
            f"Unsupported serialization {serialization!r}; "
            "expected 'safetensors' or 'pytorch'."
        )

    # The signature advertises `str`, but the `/` join below only works on Path.
    save_directory = Path(save_directory)

    ctfidf = model.c_tf_idf_
    arrays = {
        "indptr": ctfidf.indptr,
        "indices": ctfidf.indices,
        "data": ctfidf.data,
        "shape": np.array(ctfidf.shape),
        "diag": np.array(model.ctfidf_model._idf_diag.data),
    }

    if serialization == "safetensors":
        save_safetensors(save_directory / CTFIDF_SAFE_WEIGHTS_NAME, arrays)
    else:  # serialization == "pytorch"
        # An explicit check, because `assert` is stripped under `python -O`.
        # The PyPI package name is `torch`, not `pytorch`.
        if not _has_torch:
            raise ImportError("`pip install torch` to save as .bin")
        tensors = {name: torch.from_numpy(arr) for name, arr in arrays.items()}
        torch.save(tensors, save_directory / CTFIDF_WEIGHTS_NAME)
| 358 | |
| 359 | |
| 360 | def save_ctfidf_config(model, path): |
nothing calls this directly
no test coverage detected