Appends dimensions to the end of a tensor until it has target_dims dimensions.
append_dims(x, target_dims)
| 8 | |
| 9 | |
def append_dims(x, target_dims):
    """Append trailing dimensions to ``x`` until it has ``target_dims`` dimensions.

    ``x`` must expose ``ndim`` and accept a tuple index made of ``Ellipsis``
    and ``None`` entries. When ``x`` already has exactly ``target_dims``
    dimensions it is indexed with a lone ``Ellipsis``.

    Raises:
        ValueError: if ``x`` has more dimensions than ``target_dims``.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        # Ellipsis stands for every existing dimension; one None follows per
        # dimension that still has to be appended at the end.
        trailing = (None,) * missing
        return x[(..., *trailing)]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
| 16 | |
| 17 | |
| 18 | def append_zero(x): |