hub / github.com/jax-ml/jax / reshape

Function reshape

jax/_src/lax/lax.py:2972–3032

Wraps XLA's `Reshape <https://www.openxla.org/xla/operation_semantics#reshape>`_ operator. For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` / ``lax.expand_dims``. These preserve information about axis identity that may be useful for advanced transformation rules.
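
To make the recommendation above concrete, here is a minimal sketch comparing the two approaches for a size-1 axis. The array `x` and the variable names are illustrative assumptions, not taken from the JAX source; only `lax.reshape`, `lax.expand_dims`, and `lax.squeeze` come from the library.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative input (an assumption for this sketch): a 2x3 integer array.
x = jnp.arange(6).reshape(2, 3)

# expand_dims names the axis being inserted (axis 0 here).
expanded = lax.expand_dims(x, (0,))        # shape (1, 2, 3)

# A generic reshape reaches the same shape but only states the target sizes.
via_reshape = lax.reshape(x, (1, 2, 3))    # shape (1, 2, 3)

# squeeze removes the size-1 axis again by naming it explicitly.
squeezed = lax.squeeze(expanded, (0,))     # shape (2, 3)

print(expanded.shape, via_reshape.shape, squeezed.shape)
```

Both routes yield the same values; the difference the description points to is that `expand_dims` and `squeeze` state which axis is involved, while `reshape` does not.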

(operand: ArrayLike, new_sizes: Shape,
            dimensions: Sequence[int] | None = None,
            *, out_sharding: NamedSharding | P | None = None)
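
The optional `dimensions` argument in this signature can be read as "permute the input axes first, then reshape". The sketch below checks that reading against an explicit `lax.transpose`; the equivalence is an assumption based on the docstring's permutation example, and the variable names are illustrative.

```python
import jax.numpy as jnp
from jax import lax

# Same 2x3 input as in the function's docstring example.
y = jnp.arange(6).reshape(2, 3)

# Flatten while reading the input axes in the order (1, 0).
flat_permuted = lax.reshape(y, (6,), (1, 0))

# Assumed equivalent: transpose with the same permutation, then flatten.
flat_manual = lax.reshape(lax.transpose(y, (1, 0)), (6,))

print(flat_permuted)
print(flat_manual)
```

Passing `dimensions=None` (the default) leaves the axis order untouched, so the call behaves like a plain row-major reshape.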

Source

def reshape(operand: ArrayLike, new_sizes: Shape,
            dimensions: Sequence[int] | None = None,
            *, out_sharding: NamedSharding | P | None = None) -> Array:
  """Wraps XLA's `Reshape
  <https://www.openxla.org/xla/operation_semantics#reshape>`_
  operator.

  For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` /
  ``lax.expand_dims``. These preserve information about axis identity that may
  be useful for advanced transformation rules.

  Args:
    operand: array to be reshaped.
    new_sizes: sequence of integers specifying the resulting shape. The size
      of the final array must match the size of the input.
    dimensions: optional sequence of integers specifying the permutation order of
      the input shape. If specified, the length must match ``operand.shape``.

  Returns:
    out: reshaped array.

  Examples:
    Simple reshaping from one to two dimensions:

    >>> x = jnp.arange(6)
    >>> y = reshape(x, (2, 3))
    >>> y
    Array([[0, 1, 2],
           [3, 4, 5]], dtype=int32)

    Reshaping back to one dimension:

    >>> reshape(y, (6,))
    Array([0, 1, 2, 3, 4, 5], dtype=int32)

    Reshaping to one dimension with permutation of dimensions:

    >>> reshape(y, (6,), (1, 0))
    Array([0, 3, 1, 4, 2, 5], dtype=int32)
  """
  new_sizes = canonicalize_shape(new_sizes)
  new_sizes = tuple(new_sizes)
  same_shape = core.definitely_equal_shape(np.shape(operand), new_sizes)
  if dimensions is None:
    same_dims = True
    dims = None
  else:
    dims = api_util._ensure_index_tuple(dimensions)
    same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
  out_sharding = canonicalize_sharding(out_sharding, 'reshape')
  same_sharding = (out_sharding is None or
                   typeof(operand).sharding == out_sharding)

  if (np.shape(operand) and same_shape and same_dims and same_sharding and
      isinstance(operand, Array)):
    return operand
  else:
    return reshape_p.bind(
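
The final `if` in the listing is a fast path: when the operand is a non-scalar `Array` and the requested shape, dimension order, and sharding all match what it already has, the operand is returned without binding `reshape_p`. The sketch below probes that behavior from the outside. It assumes the installed JAX matches the listing shown here, so the identity check is printed rather than relied upon; variable names are illustrative.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6)

# Same shape, no permutation, no out_sharding: the fast path should apply,
# in which case the very same object comes back.
same = lax.reshape(x, (6,))
print(same is x)

# A different target shape cannot take the fast path and yields a new array.
changed = lax.reshape(x, (2, 3))
print(changed is x, changed.shape)
```

Note that the `np.shape(operand)` term in the condition is falsy for scalars (shape `()`), so per the listing a scalar operand always goes through the primitive.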

Callers 9

collapse · Function · 0.70
_unbroadcast · Function · 0.70
_maybe_broadcast · Function · 0.70
_tile_transpose_rule · Function · 0.70
_reshape_transpose_rule · Function · 0.70
_reshape_batch_rule · Function · 0.70
_reduce_jvp · Function · 0.70
_reduce_chooser_jvp_rule · Function · 0.70
_top_k_jvp · Function · 0.70

Calls 6

canonicalize_shape · Function · 0.90
canonicalize_sharding · Function · 0.90
typeof · Function · 0.90
bind · Method · 0.80
shape · Method · 0.45
ndim · Method · 0.45

Tested by

no test coverage detected
