github.com/NVIDIA/TensorRT-LLM

Function stack

tensorrt_llm/functional.py:1961–2007

Add an operation to concatenate input tensors along a new dimension. The function creates an operation that inserts a new dimension into every input tensor and then concatenates the tensors along that new dimension. All the tensors in 'inputs' must have the same shape.
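The unsqueeze-then-concatenate mechanism described above can be mirrored in NumPy to see what the operation computes. This is a minimal sketch, not the TensorRT-LLM API: `numpy.expand_dims` and `numpy.concatenate` stand in for TensorRT-LLM's `unsqueeze` and `concat`, `stack_via_concat` is a hypothetical helper name, and the result is compared against `numpy.stack`.

```python
import numpy as np


def stack_via_concat(inputs, dim=0):
    """Hypothetical NumPy analogue of the unsqueeze-then-concat pattern."""
    # Precondition from the docstring: every input must have the same shape.
    first = inputs[0].shape
    assert all(inp.shape == first for inp in inputs), "inputs must share one shape"
    # Give each input a new axis of size 1 at position `dim` ...
    expanded = [np.expand_dims(inp, axis=dim) for inp in inputs]
    # ... then join the inputs along that new axis.
    return np.concatenate(expanded, axis=dim)


a = np.array([[0, 1], [2, 3]])
b = np.array([[4, 5], [6, 7]])
# The pattern agrees with NumPy's own stack for every valid axis, including -1.
for d in (0, 1, 2, -1):
    assert np.array_equal(stack_via_concat([a, b], d), np.stack([a, b], axis=d))
```

Because `dim` is passed unchanged to both steps, a negative value works too in this sketch: the new size-1 axis and the concatenation axis refer to the same position of the rank-increased tensors.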

(inputs: Sequence[Tensor], dim: int = 0)


```python
def stack(inputs: Sequence[Tensor], dim: int = 0) -> Tensor:
    '''
    Add an operation to concatenate input tensors along a new dimension.

    The function creates an operation that inserts a new dim into each of
    the input tensors and then concatenates them along that new dim.

    All the tensors in 'inputs' must have the same shape.

        for ii in range(inputs[0].rank()):
            assert all(inp.shape[ii] == inputs[0].shape[ii] for inp in inputs)

    The shape of the output tensor is defined as:

        output.rank() = inputs[0].rank() + 1

        output.shape[dim] = len(inputs)

        for ii in range(inputs[0].rank()):
            if ii < dim:
                output.shape[ii] = inputs[0].shape[ii]
            else:
                output.shape[ii+1] = inputs[0].shape[ii]

    For example, given a sequence of two 2D tensors [[0, 1], [2, 3]] and
    [[4, 5], [6, 7]] both of shape [2, 2],

        stack(inputs, 0)

    will produce [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] of shape [2, 2, 2] and

        stack(inputs, 1)

    will produce [[[0, 1], [4, 5]], [[2, 3], [6, 7]]] of shape [2, 2, 2].

    Parameters:
        inputs : Sequence[Tensor]
            The sequence of tensors to stack.

        dim : int
            The index of the new dimension along which the tensors are
            stacked (default 0).

    Returns:
        A tensor that contains the input tensors stacked along a new dimension.
    '''
    return concat([unsqueeze(inp, axis=dim) for inp in inputs], dim=dim)
```
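The output-shape rule in the docstring can be written out as a small function and checked against its two worked examples. This is a sketch under stated assumptions: `stacked_shape` is a hypothetical name, `dim` is assumed non-negative (as the docstring's `ii < dim` comparison implies), and `numpy.stack` stands in for the TensorRT-LLM op.

```python
import numpy as np


def stacked_shape(in_shape, num_inputs, dim):
    """Output shape of stacking `num_inputs` tensors of shape `in_shape` at `dim`.

    Follows the docstring's rule; assumes 0 <= dim <= len(in_shape).
    """
    out = [0] * (len(in_shape) + 1)  # output.rank() = inputs[0].rank() + 1
    out[dim] = num_inputs            # output.shape[dim] = len(inputs)
    for ii, size in enumerate(in_shape):
        if ii < dim:
            out[ii] = size           # dimensions before `dim` keep their index
        else:
            out[ii + 1] = size       # dimensions from `dim` on shift right by one
    return tuple(out)


# The two worked examples from the docstring.
inputs = [np.array([[0, 1], [2, 3]]), np.array([[4, 5], [6, 7]])]
out0 = np.stack(inputs, axis=0)
out1 = np.stack(inputs, axis=1)
assert out0.tolist() == [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
assert out1.tolist() == [[[0, 1], [4, 5]], [[2, 3], [6, 7]]]
assert out0.shape == out1.shape == stacked_shape((2, 2), 2, 0) == (2, 2, 2)
```

The square example cannot show where the new dimension lands; a non-square case makes it visible. For instance, `stacked_shape((3, 4), 5, 1)` gives `(3, 5, 4)`, the same shape NumPy reports for `np.stack([np.zeros((3, 4))] * 5, axis=1)`.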

Callers 7

_process_gen_logits · Function · 0.90
conv1d · Function · 0.85
forward · Method · 0.85
rotate_half · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85

Calls 2

concat · Function · 0.85
unsqueeze · Function · 0.85

Tested by

no test coverage detected