```python
from typing import Tuple, Union

# Tensor, default_net, constant_to_tensor_ and expand_dims_like are provided
# by the surrounding module; they are not shown in this excerpt.


def broadcast_helper(left: Union[Tensor, int, float],
                     right: Union[Tensor, int, float]) -> Tuple[Tensor, Tensor]:
    '''
    Helper function to perform a broadcast.

    For each input, this function first creates a constant tensor if the input
    is an integer or a float. When the network is strongly typed, a scalar is
    converted using the dtype of the other operand. Then, if needed, it expands
    the lower-rank tensor so that its rank matches that of the higher-rank one.

    Parameters:
        left : Union[Tensor, int, float]
            The first input. If that input is an integer or a float, the
            function creates a constant tensor.

        right : Union[Tensor, int, float]
            The second input. If that input is an integer or a float, the
            function creates a constant tensor.

    Returns:
        A pair of tensors of the same rank.
    '''
    if not default_net().strongly_typed:
        # Wrap int/float inputs in constant tensors; the dtype is chosen by
        # constant_to_tensor_ itself.
        left = constant_to_tensor_(left)
        right = constant_to_tensor_(right)
    else:
        # Strongly typed network: a scalar input is converted using the dtype
        # of the other operand (left first, then right using left's dtype).
        left = constant_to_tensor_(
            left, right.dtype if isinstance(right, Tensor) else None)
        right = constant_to_tensor_(right, left.dtype)

    # Ranks already agree: nothing to expand.
    if left.rank() == right.rank():
        return (left, right)

    # Expand the lower-rank operand so that its rank matches the other one.
    if left.rank() < right.rank():
        left = expand_dims_like(left, right)
        return (left, right)

    if left.rank() > right.rank():
        right = expand_dims_like(right, left)
        return (left, right)
```
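The helper's control flow can be exercised without a real network. The following is a minimal, self-contained model that tracks only shapes and dtypes. `FakeTensor`, `to_tensor`, and the `strongly_typed` flag are hypothetical stand-ins for `Tensor`, `constant_to_tensor_`, and `default_net().strongly_typed`. The model rests on three assumptions that are not the module's documented behavior: a scalar constant has shape `(1,)`, its default dtype is `int32`/`float32`, and `expand_dims_like` prepends size-1 dimensions.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass
class FakeTensor:
    """Hypothetical stand-in for Tensor: records only shape and dtype."""
    shape: Tuple[int, ...]
    dtype: str

    def rank(self) -> int:
        return len(self.shape)


def to_tensor(value: Union[FakeTensor, int, float],
              dtype: Optional[str] = None) -> FakeTensor:
    """Hypothetical analogue of constant_to_tensor_.

    Tensors pass through unchanged. A scalar becomes a shape-(1,) constant;
    the shape and the int32/float32 defaults are assumptions of this sketch.
    """
    if isinstance(value, FakeTensor):
        return value
    if dtype is None:
        dtype = "int32" if isinstance(value, int) else "float32"
    return FakeTensor(shape=(1,), dtype=dtype)


def expand_dims_like(t: FakeTensor, other: FakeTensor) -> FakeTensor:
    """Assumed behavior: prepend size-1 dims until t's rank equals other's."""
    missing = max(other.rank() - t.rank(), 0)
    return FakeTensor(shape=(1,) * missing + t.shape, dtype=t.dtype)


def broadcast_helper(left, right, strongly_typed: bool = False):
    """Mirrors the structure of the real helper on FakeTensor values."""
    if not strongly_typed:
        left = to_tensor(left)
        right = to_tensor(right)
    else:
        # A scalar borrows the dtype of the other operand.
        left = to_tensor(
            left, right.dtype if isinstance(right, FakeTensor) else None)
        right = to_tensor(right, left.dtype)
    if left.rank() < right.rank():
        left = expand_dims_like(left, right)
    elif left.rank() > right.rank():
        right = expand_dims_like(right, left)
    return left, right
```

For example, `broadcast_helper(FakeTensor((4, 3, 8), "float16"), 2.0, strongly_typed=True)` yields a scalar constant of shape `(1, 1, 1)` with dtype `float16`. Without the flag, the same scalar falls back to the assumed `float32` default. The `if`/`elif` form is equivalent to the original's three separate rank checks, because those checks are exhaustive.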
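`broadcast_helper` only aligns ranks; it never changes a dimension's size. The size-level broadcast happens in the operation that consumes the pair. That operation follows the familiar per-dimension rule: sizes must match, or one of them must be 1. NumPy uses this rule, and as far as we understand, TensorRT's elementwise layer also applies it to equal-rank inputs. The pure-Python sketch below shows both steps. `align_ranks` and `broadcast_shape` are hypothetical names, and the leading-1 padding is an assumption about how `expand_dims_like` inserts dimensions.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def align_ranks(a: Shape, b: Shape) -> Tuple[Shape, Shape]:
    """Pad the shorter shape with leading 1s (assumed expand_dims_like behavior)."""
    rank = max(len(a), len(b))
    return (1,) * (rank - len(a)) + a, (1,) * (rank - len(b)) + b


def broadcast_shape(a: Shape, b: Shape) -> Shape:
    """Per-dimension broadcast of two equal-rank shapes."""
    if len(a) != len(b):
        raise ValueError("ranks differ; align them first")
    out = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"incompatible dimensions {da} and {db}")
    return tuple(out)
```

For example, `broadcast_shape(*align_ranks((4, 3, 8), (8,)))` first pads `(8,)` to `(1, 1, 8)` and then produces `(4, 3, 8)`. Shapes such as `(2, 3)` and `(4, 3)` are rejected, because neither 2 nor 4 is 1.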