```python
def repeat_interleave(tensor: Tensor, repeats: int, dim: int) -> Tensor:
    '''
    Repeats elements of a tensor along an axis.

    Parameters:
        tensor : Tensor
            The input tensor.

        repeats : int
            The number of times each element is repeated along `dim`.

        dim : int
            The dimension along which repetitions are performed.

    Returns:
        A tensor with the same shape as the input, except that dimension
        `dim` is `repeats` times larger.

    TODO: Allow repeats to be a list of integers and dim to be unspecified.
    '''
    # Normalize a negative dim; otherwise `dim + 1` below targets the wrong axis.
    if dim < 0:
        dim += tensor.ndim()
    # Insert a size-1 axis immediately after `dim`.
    expanded_tensor = expand_dims(tensor, dim + 1)
    # Target shape: the new axis becomes `repeats`; all other axes are unchanged.
    tile_output_size = concat([
        repeats if i == (dim + 1) else shape(expanded_tensor, i)
        for i in range(expanded_tensor.ndim())
    ])
    # Broadcast so every element is copied `repeats` times along the new axis.
    tile = expand(expanded_tensor, tile_output_size)
    # Merge `dim` with the new axis, so the size along `dim` becomes n * repeats.
    tile_reshape_size = [shape(tensor, i) for i in range(tensor.ndim())]
    tile_reshape_size[dim] = tile_reshape_size[dim] * repeats
    tensor = tile.view(concat(tile_reshape_size))
    return tensor
```
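`repeat_interleave` works by inserting a size-1 axis after `dim`, broadcasting that axis to `repeats`, and then reshaping the two axes back into one. The following is a minimal NumPy sketch of the same expand-broadcast-reshape technique, assuming NumPy is available. `repeat_interleave_np` is a hypothetical helper written for illustration, not part of this module; `np.repeat` serves as the reference result.

```python
import numpy as np


def repeat_interleave_np(x: np.ndarray, repeats: int, dim: int) -> np.ndarray:
    """Hypothetical NumPy sketch of the expand -> broadcast -> reshape trick."""
    if dim < 0:
        dim += x.ndim  # normalize negative dims before computing dim + 1
    # Insert a size-1 axis right after `dim`: shape (..., n, 1, ...).
    expanded = np.expand_dims(x, dim + 1)
    # Broadcast the new axis to `repeats`: shape (..., n, repeats, ...).
    target = list(expanded.shape)
    target[dim + 1] = repeats
    tiled = np.broadcast_to(expanded, target)
    # Merge `dim` with the new axis; row-major order interleaves the copies.
    out_shape = list(x.shape)
    out_shape[dim] *= repeats
    return tiled.reshape(out_shape)


x = np.array([[1, 2], [3, 4]])
print(repeat_interleave_np(x, 2, dim=0).tolist())   # → [[1, 2], [1, 2], [3, 4], [3, 4]]
print(repeat_interleave_np(x, 2, dim=-1).tolist())  # → [[1, 1, 2, 2], [3, 3, 4, 4]]
```

The reshape is what produces interleaving rather than tiling: because the repeated axis sits directly after `dim`, the copies of each element end up adjacent in memory order. Placing the new axis before `dim` instead would yield `[1, 2, 1, 2]`-style tiling.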
```python
def meshgrid2d(x: Tensor, y: Tensor) -> Tuple[Tensor]:
```