Wraps XLA's Reshape operator (https://www.openxla.org/xla/operation_semantics#reshape). For inserting/removing dimensions of size 1, prefer ``lax.squeeze`` / ``lax.expand_dims``; these preserve information about axis identity that may be useful for advanced transformation rules.
jax.lax.reshape(operand: ArrayLike, new_sizes: Shape,
                dimensions: Sequence[int] | None = None,
                *, out_sharding: NamedSharding | P | None = None) -> Array
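To illustrate the advice to prefer ``lax.squeeze`` / ``lax.expand_dims``, the sketch below inserts and then removes a leading size-1 axis both ways. It assumes a recent JAX release in which `lax.expand_dims(array, dimensions)` and `lax.squeeze(array, dimensions)` take a sequence of axis indices. Both routes produce identical arrays; the difference is only that the squeeze/expand_dims calls state which axis changed, while `reshape` states just the target shape.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)

# Adding a leading size-1 axis: expand_dims names the new axis explicitly,
# whereas reshape only gives the resulting shape.
a = lax.expand_dims(x, (0,))   # shape (1, 2, 3)
b = lax.reshape(x, (1, 2, 3))  # same shape and values as `a`

# Removing that axis again: squeeze names the axis being dropped.
c = lax.squeeze(a, (0,))       # shape (2, 3)
d = lax.reshape(b, (2, 3))     # same shape and values as `c`

print(a.shape, c.shape)        # (1, 2, 3) (2, 3)
```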
```python
def reshape(operand: ArrayLike, new_sizes: Shape,
            dimensions: Sequence[int] | None = None,
            *, out_sharding: NamedSharding | P | None = None) -> Array:
  """Wraps XLA's `Reshape
  <https://www.openxla.org/xla/operation_semantics#reshape>`_
  operator.

  For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` /
  ``lax.expand_dims``. These preserve information about axis identity that may
  be useful for advanced transformation rules.

  Args:
    operand: array to be reshaped.
    new_sizes: sequence of integers specifying the resulting shape. The total
      number of elements must match that of the input.
    dimensions: optional sequence of integers specifying a permutation of the
      input dimensions, applied before reshaping. If specified, its length must
      equal the number of dimensions of ``operand``.

  Returns:
    out: reshaped array.

  Examples:
    Simple reshaping from one to two dimensions:

    >>> x = jnp.arange(6)
    >>> y = reshape(x, (2, 3))
    >>> y
    Array([[0, 1, 2],
           [3, 4, 5]], dtype=int32)

    Reshaping back to one dimension:

    >>> reshape(y, (6,))
    Array([0, 1, 2, 3, 4, 5], dtype=int32)

    Reshaping to one dimension with permutation of dimensions:

    >>> reshape(y, (6,), (1, 0))
    Array([0, 3, 1, 4, 2, 5], dtype=int32)
  """
  new_sizes = canonicalize_shape(new_sizes)
  new_sizes = tuple(new_sizes)
  same_shape = core.definitely_equal_shape(np.shape(operand), new_sizes)
  if dimensions is None:
    same_dims = True
    dims = None
  else:
    dims = api_util._ensure_index_tuple(dimensions)
    same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
  out_sharding = canonicalize_sharding(out_sharding, 'reshape')
  same_sharding = (out_sharding is None or
                   typeof(operand).sharding == out_sharding)

  if (np.shape(operand) and same_shape and same_dims and same_sharding and
      isinstance(operand, Array)):
    return operand
  else:
    return reshape_p.bind(
```
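The following sketch exercises two behaviours visible in the source above, assuming a JAX release whose `lax.reshape` matches it. First, passing `dimensions=(1, 0)` gives the same values as calling `lax.transpose` first and then reshaping, as in the docstring's last example. Second, a call that changes nothing (same non-empty shape, identity permutation, no `out_sharding`, `Array` input) takes the early-return branch and hands back the operand object itself rather than binding `reshape_p`.

```python
import jax.numpy as jnp
from jax import lax

y = jnp.arange(6).reshape(2, 3)

# 1) `dimensions`: the operand's axes are taken in the order (1, 0) before
#    the result is laid out in `new_sizes`, which matches an explicit
#    transpose followed by a plain reshape.
with_dims = lax.reshape(y, (6,), (1, 0))
via_transpose = lax.reshape(lax.transpose(y, (1, 0)), (6,))
print(with_dims)   # [0 3 1 4 2 5]

# 2) Early return: same non-empty shape, dimensions=None (treated as the
#    identity permutation), out_sharding=None, and an Array input, so the
#    source shown above returns `y` itself.
same = lax.reshape(y, (2, 3))
print(same is y)   # True for the version of the source shown above
```

Because the early-return condition begins with `np.shape(operand)`, a scalar input (shape `()`) never takes that branch and always goes through `reshape_p.bind`.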