github.com/NVIDIA/TensorRT-LLM

Function arange

tensorrt_llm/functional.py:1498–1569

Add an operation to fill a 1D tensor. The tensor is filled with the values from start (inclusive) up to end (exclusive), with a step of 1 between consecutive elements. In pseudo-code, it corresponds to a tensor populated with the values: output = Tensor([dtype(ii) for ii in range(start, end, 1)])
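As a host-side reference for the pseudo-code above, the sketch below computes the values the graph is documented to produce, using plain Python `range`. `arange_reference` and the `FLOAT_DTYPES` set are hypothetical, chosen for illustration only; the real `arange` adds a TensorRT layer to the network rather than returning a list.

```python
from typing import List, Union

# Assumed set of floating-point dtype names, for illustration only.
FLOAT_DTYPES = {"float16", "bfloat16", "float32"}


def arange_reference(start: int, end: int, dtype: str) -> List[Union[int, float]]:
    """Host-side values of range(start, end, 1) converted to `dtype`."""
    values = range(start, end, 1)  # start is included, end is excluded
    if dtype in FLOAT_DTYPES:
        return [float(v) for v in values]
    return list(values)


print(arange_reference(3, 6, "int32"))    # [3, 4, 5]
print(arange_reference(2, 5, "float32"))  # [2.0, 3.0, 4.0]
```

Both printed results match the examples given in the docstring, and the result always has `end - start` elements when `end >= start`.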

arange(start: Union[Tensor, int], end: Union[Tensor, int], dtype: str) -> Tensor


```python
# TODO: support step.
def arange(start: Union[Tensor, int], end: Union[Tensor, int],
           dtype: str) -> Tensor:
    '''
    Add an operation to fill a 1D tensor.

    The tensor is filled with the values between start and end with a step of 1
    between the different elements. In pseudo-code, it corresponds to a tensor
    populated with the values:

        output = Tensor([dtype(ii) for ii in range(start, end, 1)])

    For example, a call to arange(3, 6, 'int32') will add an operation to the
    TensorRT graph that will produce [3, 4, 5] when executed. The call to
    arange(2, 5, 'float32') will add a layer to generate [2.0, 3.0, 4.0].

    This operation is implemented using a tensorrt.IFillLayer in
    trt.FillOperation.LINSPACE mode.

    Parameters:
        start : Union[Tensor, int]
            The starting point of the range.

        end : Union[Tensor, int]
            The end point of the range.

        dtype : str
            The type of the elements. See _str_to_trt_dtype_dict in _utils.py
            for a list of supported types and type names.

    Returns:
        The tensor produced by the fill layer. It is a 1D tensor containing
        `end-start` elements of type `dtype`.
    '''
    res_dtype = str_dtype_to_trt(dtype)
    if isinstance(start, int):
        assert isinstance(end, int)
        array_func = int32_array if res_dtype == trt.int32 else int64_array
        start = constant(array_func(start))
        end = constant(array_func(end))
    elif isinstance(start, Tensor):
        assert isinstance(end, Tensor)
        assert start.dtype == trt.int32 or start.dtype == trt.int64
        assert end.dtype == trt.int32 or end.dtype == trt.int64
        if start.dtype != end.dtype:
            if start.dtype == trt.int32:  # end == trt.int64
                if res_dtype == trt.int32:
                    end = cast(end, "int32")
                else:
                    start = cast(start, "int64")
            else:  # start == trt.int64 and end == trt.int32
                if res_dtype == trt.int32:
                    start = cast(start, "int32")
                else:
                    end = cast(end, "int64")
    else:
        raise TypeError("%s is not supported" % type(start))

    assert start.dtype == end.dtype, f"start type ({start.dtype}) != end type ({end.dtype})"
    # ... (the remaining lines of the function, through 1569, are not shown)
```
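The branch logic in the listing decides the dtype of the two bounds before the fill layer is built. The sketch below restates that decision table with dtype names as plain strings; `int_bound_dtype` and `tensor_bound_dtypes` are hypothetical helpers written for illustration, not TensorRT-LLM APIs, and the sketch raises `ValueError` where `arange` itself uses `assert`.

```python
from typing import Tuple

_INT_DTYPES = ("int32", "int64")


def int_bound_dtype(res_dtype: str) -> str:
    """Dtype of the constants built when start/end are Python ints.

    Mirrors `int32_array if res_dtype == trt.int32 else int64_array`:
    every non-int32 output dtype (including floats) gets int64 bounds.
    """
    return "int32" if res_dtype == "int32" else "int64"


def tensor_bound_dtypes(start_dtype: str, end_dtype: str,
                        res_dtype: str) -> Tuple[str, str]:
    """(start, end) dtypes after arange's mixed-width reconciliation."""
    # arange asserts that both bounds are int32 or int64 tensors.
    if start_dtype not in _INT_DTYPES or end_dtype not in _INT_DTYPES:
        raise ValueError("start and end must be int32 or int64")
    if start_dtype == end_dtype:
        return start_dtype, end_dtype  # nothing to cast
    # Mixed widths: narrow the int64 operand when the output is int32,
    # otherwise widen the int32 operand to int64.
    target = "int32" if res_dtype == "int32" else "int64"
    return target, target


# An int32 start with an int64 end, requesting an int32 output.
print(tensor_bound_dtypes("int32", "int64", "int32"))  # ('int32', 'int32')
```

Collapsing the four nested branches into one `target` works because, in every mixed-width case, the bound that gets cast is always the one whose width differs from the requested output.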

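The docstring states that the operation is built on a `tensorrt.IFillLayer` in `LINSPACE` mode. TensorRT describes that mode as an affine function of the element indices; for a 1D output, element i is alpha + i * beta, where alpha is the start value and beta the step. The sketch emulates that in pure Python. The wiring in the hypothetical `arange_via_linspace` (length `end - start`, alpha `start`, beta 1) is an assumption consistent with the documented step of 1, because the listing stops before the layer is created.

```python
from typing import List, Union

Number = Union[int, float]


def linspace_fill_1d(length: int, alpha: Number, beta: Number) -> List[Number]:
    """Emulate a 1D LINSPACE fill: element i equals alpha + i * beta."""
    if length < 0:
        raise ValueError("fill length must be non-negative")
    return [alpha + i * beta for i in range(length)]


def arange_via_linspace(start: int, end: int) -> List[Number]:
    """Assumed wiring for arange: length end - start, alpha start, beta 1."""
    return linspace_fill_1d(end - start, start, 1)


print(arange_via_linspace(3, 6))  # [3, 4, 5]
```

Because the step is fixed at 1, the output length alone (`end - start`) determines where the sequence stops, which is why the docstring's `Returns` section can state the element count directly.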
Callers (15)

cumsum (function) · 0.85
gegelu (function) · 0.85
generate_alibi_biases (function) · 0.85
make_causal_mask (function) · 0.85
compute_relative_bias (function) · 0.85
get_2d_sincos_pos_embed (function) · 0.85
forward (method) · 0.85
forward (method) · 0.85
forward (method) · 0.85

Calls (9)

str_dtype_to_trt (function) · 0.85
constant (function) · 0.85
cast (function) · 0.85
constant_to_tensor_ (function) · 0.85
default_trtnet (function) · 0.85
_create_tensor (function) · 0.85
cast (method) · 0.80
view (method) · 0.45
get_output (method) · 0.45

Tested by

no test coverage detected