Cast PyTorch/TensorFlow/pandas objects to Python objects (NumPy arrays/lists). It works recursively. If `optimize_list_casting` is True, then to avoid iterating over possibly long lists, it first checks (recursively) whether the first element that is not None or empty (if it is a sequence) needs to be cast.
(obj: Any, only_1d_for_numpy: bool, optimize_list_casting: bool)
| 275 | |
| 276 | |
| 277 | def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_casting: bool) -> tuple[Any, bool]: |
| 278 | """ |
| 279 | Cast pytorch/tensorflow/pandas objects to python numpy array/lists. |
| 280 | It works recursively. |
| 281 | |
| 282 | If `optimize_list_casting` is True, to avoid iterating over possibly long lists, it first checks (recursively) if the first element that is not None or empty (if it is a sequence) has to be casted. |
| 283 | If the first element needs to be casted, then all the elements of the list will be casted, otherwise they'll stay the same. |
| 284 | This trick allows to cast objects that contain tokenizers outputs without iterating over every single token for example. |
| 285 | |
| 286 | Args: |
| 287 | obj: the object (nested struct) to cast. |
| 288 | only_1d_for_numpy (bool): whether to keep the full multi-dim tensors as multi-dim numpy arrays, or convert them to |
| 289 | nested lists of 1-dimensional numpy arrays. This can be useful to keep only 1-d arrays to instantiate Arrow arrays. |
| 290 | Indeed Arrow only support converting 1-dimensional array values. |
| 291 | optimize_list_casting (bool): whether to optimize list casting by checking the first non-null element to see if it needs to be casted |
| 292 | and if it doesn't, not checking the rest of the list elements. |
| 293 | |
| 294 | Returns: |
| 295 | casted_obj: the casted object |
| 296 | has_changed (bool): True if the object has been changed, False if it is identical |
| 297 | """ |
| 298 | |
| 299 | if config.TF_AVAILABLE and "tensorflow" in sys.modules: |
| 300 | import tensorflow as tf |
| 301 | |
| 302 | if config.TORCH_AVAILABLE and "torch" in sys.modules: |
| 303 | import torch |
| 304 | |
| 305 | if config.JAX_AVAILABLE and "jax" in sys.modules: |
| 306 | import jax.numpy as jnp |
| 307 | |
| 308 | if config.PIL_AVAILABLE and "PIL" in sys.modules: |
| 309 | import PIL.Image |
| 310 | |
| 311 | if config.PDFPLUMBER_AVAILABLE and "pdfplumber" in sys.modules: |
| 312 | import pdfplumber |
| 313 | |
| 314 | if config.NIBABEL_AVAILABLE and "nibabel" in sys.modules: |
| 315 | import nibabel as nib |
| 316 | |
| 317 | if config.TORCHCODEC_AVAILABLE and "torchcodec" in sys.modules: |
| 318 | from torchcodec.decoders import AudioDecoder, VideoDecoder |
| 319 | |
| 320 | if isinstance(obj, np.ndarray): |
| 321 | if obj.ndim == 0: |
| 322 | return obj[()], True |
| 323 | elif not only_1d_for_numpy or obj.ndim == 1: |
| 324 | return obj, False |
| 325 | else: |
| 326 | return ( |
| 327 | [ |
| 328 | _cast_to_python_objects( |
| 329 | x, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting |
| 330 | )[0] |
| 331 | for x in obj |
| 332 | ], |
| 333 | True, |
| 334 | ) |
no test coverage detected