Returns common shape for a sequence of Tensors. The common shape is the smallest-rank shape to which all tensors are broadcastable.

#### Example

```python
import tensorflow as tf
import tf_quant_finance as tff

args = [tf.ones([1, 2], dtype=tf.float64), tf.constant([[True], [False]])]
tff.utils.common_shape(*args)
# Expected: [2, 2]
```
(
*args: Sequence[tf.Tensor],
name: Optional[str] = None)
| 61 | |
| 62 | |
def common_shape(
    *args: Sequence[tf.Tensor],
    name: Optional[str] = None) -> Union[tf.TensorShape, tf.Tensor]:
  """Returns common shape for a sequence of Tensors.

  The common shape is the smallest-rank shape to which all tensors are
  broadcastable.

  #### Example
  ```python
  import tensorflow as tf
  import tf_quant_finance as tff

  args = [tf.ones([1, 2], dtype=tf.float64), tf.constant([[True], [False]])]
  tff.utils.common_shape(*args)
  # Expected: [2, 2]
  ```

  Args:
    *args: A sequence of `Tensor`s (or objects accepted by
      `tf.convert_to_tensor`, e.g. Python lists or NumPy arrays) of compatible
      shapes and any `dtype`s.
    name: Python string. The name to give to the ops created by this function.
      Default value: `None` which maps to the default name `common_shape`.

  Returns:
    A common shape for the input `Tensor`s, which is an instance of
    `TensorShape` if all input shapes are fully defined, or a `Tensor` for
    dynamically shaped inputs. When called with no inputs, the scalar shape
    `TensorShape([])` is returned.

  Raises:
    ValueError: If the inputs have fully defined but incompatible shapes.
      For dynamically shaped inputs, incompatibility is reported by
      `tf.broadcast_dynamic_shape` instead.
  """
  name = 'common_shape' if name is None else name
  with tf.name_scope(name):
    # Keep the converted tensors: non-`Tensor` inputs (lists, NumPy arrays)
    # do not carry a `TensorShape`-valued `.shape`, which both the static
    # check and `tf.broadcast_static_shape` below rely on.
    tensors = [tf.convert_to_tensor(arg) for arg in args]
    if not tensors:
      # The scalar shape is the identity element of broadcasting, so it is
      # the common shape of an empty collection.
      return tf.TensorShape([])
    if all(t.shape.is_fully_defined() for t in tensors):
      # Every shape is known at trace time, so a static TensorShape suffices.
      output_shape = tensors[0].shape
      for t in tensors[1:]:
        try:
          output_shape = tf.broadcast_static_shape(output_shape, t.shape)
        except ValueError as err:
          raise ValueError(f'Shapes of {args} are incompatible') from err
      return output_shape
    # At least one shape is only known at run time: broadcast shape tensors.
    output_shape = tf.shape(tensors[0])
    for t in tensors[1:]:
      output_shape = tf.broadcast_dynamic_shape(output_shape, tf.shape(t))
    return output_shape
| 115 | |
| 116 | |
| 117 | def broadcast_tensors( |
no test coverage detected