Add an operation to concatenate tensors. The function creates an operation that concatenates the tensors from the sequence 'inputs' along the dimension 'dim'. All the tensors in 'inputs' must have the same shape except in the dimension 'dim'.
(inputs: Sequence[Union[Tensor, int]], dim: int = 0)
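To make the shape constraint concrete, here is a minimal pure-Python sketch that works on plain shapes and nested lists rather than on TensorRT tensors. The helpers `concat_output_shape` and `concat_lists` are hypothetical illustrations, not part of the library API. Treating a rank-0 shape as `[1]` mirrors the `i.view([1])` step in the implementation, and negative `dim` values are deliberately not handled in this sketch.

```python
from typing import List, Sequence

Shape = List[int]


def concat_output_shape(shapes: Sequence[Shape], dim: int = 0) -> Shape:
    """Hypothetical helper (not a library API): output shape of a concat.

    Every dimension other than `dim` must match across inputs, and the
    sizes along `dim` are summed. Negative `dim` is not supported here.
    """
    if len(shapes) == 0:
        raise ValueError("at least one input shape is required")
    # Treat rank-0 (scalar) shapes as [1], like concat() does with view([1]).
    normalized = [list(s) if len(s) > 0 else [1] for s in shapes]
    rank = len(normalized[0])
    if any(len(s) != rank for s in normalized):
        raise ValueError("all inputs must have the same rank")
    if not 0 <= dim < rank:
        raise ValueError(f"dim {dim} is out of range for rank {rank}")
    output = list(normalized[0])
    for ii in range(rank):
        if ii == dim:
            # Sum of the sizes of all inputs in dimension `dim`.
            output[ii] = sum(s[ii] for s in normalized)
        elif any(s[ii] != normalized[0][ii] for s in normalized):
            # Every other dimension must agree across all inputs.
            raise ValueError(f"inputs differ in dimension {ii}")
    return output


def concat_lists(a: list, b: list, dim: int) -> list:
    """Hypothetical reference: concatenate two nested lists along `dim`."""
    if dim == 0:
        return a + b
    # Recurse into matching sub-lists until the target dimension is reached.
    return [concat_lists(x, y, dim - 1) for x, y in zip(a, b)]


if __name__ == "__main__":
    a = [[0, 1], [2, 3]]
    b = [[4, 5], [6, 7]]
    print(concat_lists(a, b, 0))                      # [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(concat_lists(a, b, 1))                      # [[0, 1, 4, 5], [2, 3, 6, 7]]
    print(concat_output_shape([[2, 2], [2, 2]], 0))   # [4, 2]
    print(concat_output_shape([[2, 2], [2, 2]], 1))   # [2, 4]
```

The results reproduce the docstring example: concatenating along dimension 0 stacks the rows, while dimension 1 extends each row.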

def concat(inputs: Sequence[Union[Tensor, int]], dim: int = 0) -> Tensor:
    '''
    Add an operation to concatenate tensors.

    The function creates an operation that concatenates the tensors from the
    sequence 'inputs'. The concatenation is done along the dimension 'dim'.

    All the tensors in 'inputs' must have the same shape except for the
    dimension 'dim':

        for ii in range(inputs[0].rank()):
            assert (ii == dim) or all(inp.shape[ii] == inputs[0].shape[ii] for inp in inputs)

    The shape of the output tensor is defined as:

        for ii in range(inputs[0].rank()):
            # Same size as all the inputs in dimension ii != dim.
            output.shape[ii] = inputs[0].shape[ii]

            # Sum of the sizes in the different inputs in dimension 'dim'.
            if ii == dim:
                for jj in range(1, len(inputs)):
                    output.shape[ii] += inputs[jj].shape[ii]

    For example, given a sequence of two 2D tensors [[0, 1], [2, 3]] and
    [[4, 5], [6, 7]], both of shape [2, 2],

        concat(inputs, 0)

    will produce [[0, 1], [2, 3], [4, 5], [6, 7]] of shape [4, 2] and

        concat(inputs, 1)

    will produce [[0, 1, 4, 5], [2, 3, 6, 7]] of shape [2, 4].

    Parameters:
        inputs : Sequence[Union[Tensor, int]]
            The sequence of tensors to concatenate. Integer entries are
            converted to constant tensors. Rank-0 (scalar) inputs are
            reshaped to shape [1] before concatenation.

        dim : int
            The dimension along which the concatenation is performed.

    Returns:
        A tensor that contains the concatenation of the tensors.
    '''
    assert len(inputs) > 0, (
        f"Number of inputs ({len(inputs)}) to the concatenation layer must be > 0.")
    # Convert integer entries to constant tensors.
    inputs = constants_to_tensors_(*inputs)
    # Reshape rank-0 (scalar) tensors to shape [1] so they can be concatenated.
    tmp = []
    for i in inputs:
        if i.rank() == 0:
            tmp.append(i.view([1]))
        else:
            tmp.append(i)

    # Add a concatenation layer over the underlying TensorRT tensors.
    layer = default_trtnet().add_concatenation([i.trt_tensor for i in tmp])