Appends dimensions to the end of a tensor until it has target_dims dimensions.
(x, target_dims)
| 190 | |
| 191 | |
def append_dims(x, target_dims):
    """Append trailing size-1 dimensions to ``x`` until it has ``target_dims`` dims.

    Args:
        x: Array-like exposing ``.ndim`` and supporting NumPy-style indexing
            with ``...`` and ``None`` (e.g. a torch tensor or NumPy array).
        target_dims: Desired number of dimensions of the result.

    Returns:
        ``x`` indexed so that ``target_dims - x.ndim`` new axes of size 1 are
        appended after its existing axes. If ``x`` already has exactly
        ``target_dims`` dimensions, no axes are added.

    Raises:
        ValueError: If ``x`` already has more than ``target_dims`` dimensions.
    """
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, "
            f"which is less than the input's dims"
        )
    # The leading Ellipsis keeps every existing axis; each trailing None
    # inserts a new axis of size 1 at the end.
    return x[(...,) + (None,) * dims_to_append]
| 200 | |
| 201 | |
| 202 | def load_model_from_config(config, ckpt, verbose=True, freeze=True): |
no outgoing calls
no test coverage detected