Cast an array to the arrow type that corresponds to the requested feature type. For custom features like [`Audio`] or [`Image`], it takes into account the "cast_storage" methods they defined to enable casting from other arrow types. Args: array (`pa.Array`): The PyArrow array to cast.
(
array: pa.Array, feature: "FeatureType", allow_primitive_to_str: bool = True, allow_decimal_to_str: bool = True
)
| 2009 | |
| 2010 | @_wrap_for_chunked_arrays |
| 2011 | def cast_array_to_feature( |
| 2012 | array: pa.Array, feature: "FeatureType", allow_primitive_to_str: bool = True, allow_decimal_to_str: bool = True |
| 2013 | ) -> pa.Array: |
| 2014 | """Cast an array to the arrow type that corresponds to the requested feature type. |
| 2015 | For custom features like [`Audio`] or [`Image`], it takes into account the "cast_storage" methods |
| 2016 | they defined to enable casting from other arrow types. |
| 2017 | |
| 2018 | Args: |
| 2019 | array (`pa.Array`): |
| 2020 | The PyArrow array to cast. |
| 2021 | feature (`datasets.features.FeatureType`): |
| 2022 | The target feature type. |
| 2023 | allow_primitive_to_str (`bool`, defaults to `True`): |
| 2024 | Whether to allow casting primitives to strings. |
| 2025 | Defaults to `True`. |
| 2026 | allow_decimal_to_str (`bool`, defaults to `True`): |
| 2027 | Whether to allow casting decimals to strings. |
| 2028 | Defaults to `True`. |
| 2029 | |
| 2030 | Raises: |
| 2031 | `pa.ArrowInvalidError`: if the arrow data casting fails |
| 2032 | `TypeError`: if the target type is not supported according, e.g. |
| 2033 | |
| 2034 | - if a field is missing |
| 2035 | - if casting from primitives and `allow_primitive_to_str` is `False` |
| 2036 | - if casting from decimals and `allow_decimal_to_str` is `False` |
| 2037 | |
| 2038 | Returns: |
| 2039 | array (`pyarrow.Array`): the casted array |
| 2040 | """ |
| 2041 | from .features.features import LargeList, List, get_nested_type |
| 2042 | |
| 2043 | _c = partial( |
| 2044 | cast_array_to_feature, |
| 2045 | allow_primitive_to_str=allow_primitive_to_str, |
| 2046 | allow_decimal_to_str=allow_decimal_to_str, |
| 2047 | ) |
| 2048 | |
| 2049 | if isinstance(array, pa.ExtensionArray): |
| 2050 | array = array.storage |
| 2051 | if hasattr(feature, "cast_storage"): |
| 2052 | return feature.cast_storage(array) |
| 2053 | |
| 2054 | if pa.types.is_struct(array.type): |
| 2055 | # feature must be a dict |
| 2056 | if isinstance(feature, dict) and (array_fields := {field.name for field in array.type}) <= set(feature): |
| 2057 | null_array = pa.array([None] * len(array)) |
| 2058 | arrays = [ |
| 2059 | _c(array.field(name) if name in array_fields else null_array, subfeature) |
| 2060 | for name, subfeature in feature.items() |
| 2061 | ] |
| 2062 | return pa.StructArray.from_arrays(arrays, names=list(feature), mask=array.is_null()) |
| 2063 | elif pa.types.is_list(array.type) or pa.types.is_large_list(array.type): |
| 2064 | # feature must be either List(subfeature) or LargeList(subfeature) |
| 2065 | if isinstance(feature, LargeList): |
| 2066 | casted_array_values = _c(array.values, feature.feature) |
| 2067 | if pa.types.is_large_list(array.type) and casted_array_values.type == array.values.type: |
| 2068 | # Both array and feature have equal large_list type and values (within the list) type |