Function common_shape

tf_quant_finance/utils/shape_utils.py:63–114 in github.com/google/tf-quant-finance

Returns the common shape of a sequence of Tensors. The common shape is the smallest-rank shape to which all tensors are broadcastable.

(
    *args: Sequence[tf.Tensor],
    name: Optional[str] = None)
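For fully known shapes, the "smallest-rank shape to which all tensors are broadcastable" follows the NumPy-style broadcasting rule that TensorFlow also uses: align the shapes on their trailing dimensions, treat missing leading dimensions as 1, and require every non-1 size in a column to agree. The plain-Python sketch below illustrates that rule only; `broadcast_shapes_sketch` is a hypothetical helper, not part of tf-quant-finance or TensorFlow.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast_shapes_sketch(*shapes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: common broadcast shape of fully known shapes."""
    result = []
    # Walk the dimensions right to left; shorter shapes are padded with 1s,
    # which is how missing leading dimensions behave under broadcasting.
    for column in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in column if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"Shapes {shapes} are incompatible")
        # A column of all 1s stays 1; otherwise the single non-1 size wins.
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


print(broadcast_shapes_sketch([1, 2], [2, 1]))  # (2, 2), as in the docstring example
print(broadcast_shapes_sketch([3, 1], [4]))     # (3, 4)
```

The result always has the largest input rank, which is why it is the smallest-rank shape every input can broadcast to.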


````python
# Module-level imports from shape_utils.py (outside the 63–114 range shown).
from typing import Optional, Sequence, Union

import tensorflow as tf


def common_shape(
    *args: Sequence[tf.Tensor],
    name: Optional[str] = None) -> Union[tf.TensorShape, tf.Tensor]:
  """Returns the common shape of a sequence of Tensors.

  The common shape is the smallest-rank shape to which all tensors are
  broadcastable.

  #### Example
  ```python
  import tensorflow as tf
  import tf_quant_finance as tff

  args = [tf.ones([1, 2], dtype=tf.float64), tf.constant([[True], [False]])]
  tff.utils.common_shape(*args)
  # Expected: TensorShape([2, 2])
  ```

  Args:
    *args: A sequence of `Tensor`s (or objects convertible to `Tensor`) of
      compatible shapes and any `dtype`s.
    name: Python string. The name to give to the ops created by this function.
      Default value: `None` which maps to the default name `common_shape`.

  Returns:
    A common shape for the input `Tensor`s, which is an instance of
    `TensorShape` if all input shapes are fully defined, or an int32 `Tensor`
    for dynamically shaped inputs. With no inputs, the scalar shape
    `TensorShape([])` is returned.

  Raises:
    ValueError: If fully defined input shapes are incompatible. Incompatible
      shapes that are not fully defined are reported by
      `tf.broadcast_dynamic_shape` instead.
  """
  name = 'common_shape' if name is None else name
  with tf.name_scope(name):
    # Convert once and keep the results, so `.shape` below is always a
    # `TensorShape`, even for NumPy arrays or Python lists.
    args = [tf.convert_to_tensor(arg) for arg in args]
    if not args:
      # Broadcasting zero tensors yields the scalar shape.
      return tf.TensorShape([])
    # Flag to decide whether input Tensors have fully defined shapes
    is_fully_defined = all(arg.shape.is_fully_defined() for arg in args)
    if is_fully_defined:
      # Static path: compute the broadcast shape at trace time.
      output_shape = args[0].shape
      for arg in args[1:]:
        try:
          output_shape = tf.broadcast_static_shape(output_shape, arg.shape)
        except ValueError as e:
          shapes = [a.shape for a in args]
          raise ValueError(f'Shapes {shapes} are incompatible') from e
      return output_shape
    # Dynamic path: compute the broadcast shape as a Tensor at run time.
    output_shape = tf.shape(args[0])
    for arg in args[1:]:
      output_shape = tf.broadcast_dynamic_shape(output_shape, tf.shape(arg))
    return output_shape
````
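The two branches of `common_shape` delegate to `tf.broadcast_static_shape` and `tf.broadcast_dynamic_shape`. The sketch below, assuming TensorFlow 2.x with eager execution, calls those primitives directly to show what each branch produces. `dynamic_common_shape` is a hypothetical wrapper whose `input_signature` leaves the leading dimensions unknown, which is the situation that sends `common_shape` down its dynamic path.

```python
import tensorflow as tf

# Static branch: fully defined shapes are broadcast at trace time and the
# result is a tf.TensorShape.
static_shape = tf.broadcast_static_shape(
    tf.TensorShape([1, 2]), tf.TensorShape([2, 1]))
print(static_shape.as_list())  # [2, 2]

# Incompatible fully defined shapes raise ValueError, which common_shape
# catches and re-raises with its own message.
try:
    tf.broadcast_static_shape(tf.TensorShape([2]), tf.TensorShape([3]))
except ValueError as e:
    print("incompatible:", e)


# Dynamic branch (hypothetical wrapper): inside the traced function the
# leading dimensions are unknown, so the shape must be computed as a Tensor.
@tf.function(input_signature=[
    tf.TensorSpec([None, 1], tf.float64),
    tf.TensorSpec([None], tf.float64),
])
def dynamic_common_shape(x, y):
    return tf.broadcast_dynamic_shape(tf.shape(x), tf.shape(y))


dynamic_shape = dynamic_common_shape(
    tf.zeros([3, 1], tf.float64), tf.zeros([4], tf.float64))
print(dynamic_shape.numpy())  # [3 4]
```

The static result can be used directly in Python (for example, with `as_list()`), while the dynamic result is an int32 `Tensor` that is only known when the graph runs.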

Callers (2)

broadcast_tensors · function · 0.85

Calls (1)

shape · method · 0.45

Tested by

No test coverage detected.