Repository: github.com/NVIDIA/TensorRT-LLM

Function broadcast_helper

tensorrt_llm/functional.py:2920–2958

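Helper function to perform a broadcast. For each input, the function first creates a constant tensor if the input is an integer or a float (in a strongly typed network, the constant takes the dtype of the other operand when that operand is a tensor). Then, if the ranks differ, it expands the lower-rank tensor so that its rank matches the higher-rank one. It returns the pair of tensors, now of equal rank.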

(left: Union[Tensor, int, float], right: Union[Tensor, int, float]) -> Tuple[Tensor, Tensor]
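The rank-alignment rule above can be illustrated outside TensorRT with NumPy arrays standing in for `Tensor`. This is a minimal sketch, not the TensorRT-LLM implementation: `broadcast_helper_np` is a hypothetical name, `np.asarray` stands in for `constant_to_tensor_`, and the sketch assumes that `expand_dims_like` adds the missing size-1 dimensions as leading axes, the way NumPy broadcasting does.

```python
import numpy as np


def broadcast_helper_np(left, right):
    """Hypothetical NumPy analogue of broadcast_helper (not the TensorRT-LLM API).

    Scalars become 0-d arrays; then the lower-rank operand gets leading
    size-1 dimensions until both operands have the same rank.
    """
    # Step 1: wrap Python int/float scalars as arrays (stand-in for constant_to_tensor_).
    left = np.asarray(left)
    right = np.asarray(right)

    # Step 2: equalize ranks by prepending size-1 axes to the lower-rank operand.
    # Prepending (rather than appending) is an assumption about expand_dims_like.
    if left.ndim < right.ndim:
        left = left.reshape((1,) * (right.ndim - left.ndim) + left.shape)
    elif left.ndim > right.ndim:
        right = right.reshape((1,) * (left.ndim - right.ndim) + right.shape)
    return left, right


a, b = broadcast_helper_np(np.zeros((2, 3, 4)), np.ones(4))
print(a.shape, b.shape)  # (2, 3, 4) (1, 1, 4)

s, t = broadcast_helper_np(2.0, np.zeros((5, 6)))
print(s.shape, t.shape)  # (1, 1) (5, 6)
```

Only the rank changes: the inserted axes have size 1, so stretching them to the other operand's sizes is presumably left to the elementwise operation that consumes the pair.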


```python
def broadcast_helper(left: Union[Tensor, int, float],
                     right: Union[Tensor, int, float]) -> Tuple[Tensor, Tensor]:
    '''
    Helper function to perform a broadcast.

    For each input, that function first creates a constant tensor if the input
    is an integer or a float. Then, if needed, it expands the smaller tensor to
    make sure its rank is the same as the larger one.

    Parameters:
        left : Union[Tensor, int, float]
            The first input. If that input is an integer or a float, the
            function creates a constant tensor.

        right : Union[Tensor, int, float]
            The second input. If that input is an integer or a float, the
            function creates a constant tensor.

    Returns:
        A pair of tensors of same rank.
    '''
    if not default_net().strongly_typed:
        left = constant_to_tensor_(left)
        right = constant_to_tensor_(right)
    else:
        left = constant_to_tensor_(
            left, right.dtype if isinstance(right, Tensor) else None)
        right = constant_to_tensor_(right, left.dtype)

    if left.rank() == right.rank():
        return (left, right)

    if left.rank() < right.rank():
        left = expand_dims_like(left, right)
        return (left, right)

    if left.rank() > right.rank():
        right = expand_dims_like(right, left)
        return (left, right)
```
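The two branches of the constant conversion differ only in where a scalar's dtype comes from. The sketch below mimics that choice with NumPy. `to_array` and `convert_operands` are hypothetical stand-ins for `constant_to_tensor_` and for the `default_net().strongly_typed` check; the int32/float32 defaults used when no dtype is given, and the pass-through of existing arrays, are assumptions for illustration.

```python
import numpy as np


def to_array(value, dtype=None):
    """Hypothetical stand-in for constant_to_tensor_.

    Arrays pass through unchanged (assumed behavior); Python scalars become
    0-d arrays of `dtype`, or of an assumed default when `dtype` is None.
    """
    if isinstance(value, np.ndarray):
        return value
    if dtype is None:
        # Assumed defaults, for illustration only.
        dtype = np.int32 if isinstance(value, int) else np.float32
    return np.asarray(value, dtype=dtype)


def convert_operands(left, right, strongly_typed):
    """Mirror of the strongly_typed branch in broadcast_helper (sketch only)."""
    if not strongly_typed:
        # Each operand is converted independently, so dtypes may differ.
        return to_array(left), to_array(right)
    # Strongly typed: a scalar left operand takes the dtype of an array right
    # operand; a scalar right operand then takes left's (now known) dtype.
    left = to_array(left, right.dtype if isinstance(right, np.ndarray) else None)
    right = to_array(right, left.dtype)
    return left, right


x = np.zeros(3, dtype=np.float16)
print([a.dtype.name for a in convert_operands(x, 2, strongly_typed=False)])
# ['float16', 'int32']
print([a.dtype.name for a in convert_operands(x, 2, strongly_typed=True)])
# ['float16', 'float16']
print([a.dtype.name for a in convert_operands(1, x, strongly_typed=True)])
# ['float16', 'float16']
```

In the strongly typed case both operands end up with the same dtype, presumably because a strongly typed TensorRT network does not insert implicit casts between mismatched inputs.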

Callers (6)

__xor__ (method) · 0.85
matmul (function) · 0.85
masked_select (function) · 0.85
masked_scatter (function) · 0.85
elementwise_binary (function) · 0.85
layer_norm (function) · 0.85

Calls (4)

default_net (function) · 0.85
constant_to_tensor_ (function) · 0.85
expand_dims_like (function) · 0.85
rank (method) · 0.45

Tested by

no test coverage detected