Add an operation to fill a 1D tensor. The tensor is filled with the values from start (inclusive) to end (exclusive), with a step of 1 between consecutive elements. In pseudo-code, it corresponds to a tensor populated with the values: `output = Tensor([dtype(ii) for ii in range(start, end, 1)])`

`arange(start: Union[Tensor, int], end: Union[Tensor, int], dtype: str) -> Tensor`
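To make the pseudo-code concrete, here is a minimal eager sketch of the values the graph should produce. It does not build any TensorRT layer. `arange_reference`, `linspace_fill_model` and the `_PY_CONVERT` table are hypothetical names, and the table covers only a few dtype names. The second function restates the same values as a LINSPACE-style fill. It assumes that element `i` equals `alpha + i * beta`, with `alpha = start`, `beta = 1` and `end - start` elements, which follows the function's docstring.

```python
from typing import Callable, Dict, List, Union

Number = Union[int, float]

# Hypothetical mapping from a few dtype names to Python element constructors.
# The authoritative list of names is _str_to_trt_dtype_dict in _utils.py.
_PY_CONVERT: Dict[str, Callable[[int], Number]] = {
    "int32": int,
    "int64": int,
    "float16": float,
    "float32": float,
}


def arange_reference(start: int, end: int, dtype: str) -> List[Number]:
    """Eager equivalent of [dtype(ii) for ii in range(start, end, 1)]."""
    if dtype not in _PY_CONVERT:
        raise ValueError(f"unsupported dtype name: {dtype!r}")
    convert = _PY_CONVERT[dtype]
    return [convert(ii) for ii in range(start, end, 1)]


def linspace_fill_model(start: int, end: int, dtype: str) -> List[Number]:
    """Same values as a LINSPACE-style fill (assumed: element i = alpha + i * beta).

    The count is clamped at zero so the result matches range() when end <= start.
    """
    convert = _PY_CONVERT[dtype]
    alpha, beta, count = start, 1, max(end - start, 0)
    return [convert(alpha + i * beta) for i in range(count)]


print(arange_reference(3, 6, "int32"))    # [3, 4, 5]
print(arange_reference(2, 5, "float32"))  # [2.0, 3.0, 4.0]
```

Both functions agree on every input with `end >= start`. The sketch does not settle how the real layer handles `end < start`.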
```python
# TODO: support step.
def arange(start: Union[Tensor, int], end: Union[Tensor, int],
           dtype: str) -> Tensor:
    '''
    Add an operation to fill a 1D tensor.

    The tensor is filled with the values between start and end with a step of 1
    between the different elements. In pseudo-code, it corresponds to a tensor
    populated with the values:

        output = Tensor([dtype(ii) for ii in range(start, end, 1)])

    For example, a call to arange(3, 6, 'int32') will add an operation to the
    TensorRT graph that will produce [3, 4, 5] when executed. The call to
    arange(2, 5, 'float32') will add a layer to generate [2.0, 3.0, 4.0].

    This operation is implemented using a tensorrt.IFillLayer in
    trt.FillOperation.LINSPACE mode.

    Parameters:
        start : Union[Tensor, int]
            The starting point of the range (inclusive).

        end : Union[Tensor, int]
            The end point of the range (exclusive).

        dtype : str
            The type of the elements. See _str_to_trt_dtype_dict in _utils.py
            for a list of supported types and type names.

    Returns:
        The tensor produced by the fill layer. It is a 1D tensor containing
        `end-start` elements of type `dtype`.
    '''
    res_dtype = str_dtype_to_trt(dtype)
    if isinstance(start, int):
        assert isinstance(end, int)
        # Python-int bounds become int32 constants only for an int32 result.
        array_func = int32_array if res_dtype == trt.int32 else int64_array
        start = constant(array_func(start))
        end = constant(array_func(end))
    elif isinstance(start, Tensor):
        assert isinstance(end, Tensor)
        assert start.dtype == trt.int32 or start.dtype == trt.int64
        assert end.dtype == trt.int32 or end.dtype == trt.int64
        if start.dtype != end.dtype:
            if start.dtype == trt.int32:  # end.dtype == trt.int64
                if res_dtype == trt.int32:
                    end = cast(end, "int32")
                else:
                    start = cast(start, "int64")
            else:  # start.dtype == trt.int64 and end.dtype == trt.int32
                if res_dtype == trt.int32:
                    start = cast(start, "int32")
                else:
                    end = cast(end, "int64")
    else:
        raise TypeError("%s is not supported" % type(start))

    assert start.dtype == end.dtype, f"start type ({start.dtype}) != end type ({end.dtype})"
    # Excerpt ends here; the rest of the body, which builds the fill layer, is not shown.
```
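The dtype handling above follows a single rule. When the requested result is `int32`, both bounds end up `int32`. Otherwise both end up `int64`: a mismatched pair narrows its `int64` bound in the first case and widens its `int32` bound in the second. The sketch below models that rule with plain strings standing in for `trt.int32` / `trt.int64`. `int_bounds_dtype` and `reconcile_tensor_bounds` are hypothetical helpers, not part of the library. One difference from the excerpt: the sketch raises `TypeError` where the excerpt uses `assert`.

```python
from typing import Tuple

# Hypothetical string stand-ins for trt.int32 / trt.int64.
INDEX_DTYPES = ("int32", "int64")


def int_bounds_dtype(res_dtype: str) -> str:
    """Dtype used for Python-int bounds: int32 only when the result is int32."""
    return "int32" if res_dtype == "int32" else "int64"


def reconcile_tensor_bounds(start_dtype: str, end_dtype: str,
                            res_dtype: str) -> Tuple[str, str]:
    """Return (start_dtype, end_dtype) after arange's mixed-width casts."""
    if start_dtype not in INDEX_DTYPES or end_dtype not in INDEX_DTYPES:
        raise TypeError("start/end tensors must be int32 or int64")
    if start_dtype != end_dtype:
        if res_dtype == "int32":
            # Narrow whichever bound is int64 down to int32.
            start_dtype = end_dtype = "int32"
        else:
            # Widen whichever bound is int32 up to int64 (also for float results).
            start_dtype = end_dtype = "int64"
    return start_dtype, end_dtype


print(reconcile_tensor_bounds("int32", "int64", "float32"))  # ('int64', 'int64')
```

Seen this way, the four nested branches in the excerpt collapse to one decision on `res_dtype`. The narrowing path also means an `int64` bound outside the 32-bit range cannot survive the cast unchanged when `dtype='int32'`.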