(tensor: torch.Tensor, batch_size: int)
| 446 | |
| 447 | |
def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
    """Tile ``tensor`` along its first dimension until it has ``batch_size`` rows.

    The whole tensor is repeated as a block (``[a, b] -> [a, b, a, b]``), not
    element-wise, because ``Tensor.repeat`` tiles its input.

    Args:
        tensor: Tensor with at least one dimension, or ``None``.
        batch_size: Desired size of the first dimension; must be a
            non-negative multiple of ``tensor.shape[0]``.

    Returns:
        ``None`` if ``tensor`` is ``None``; ``tensor`` itself (not a copy) if
        its first dimension already equals ``batch_size``; otherwise a new
        tensor whose first dimension is ``batch_size`` and whose trailing
        dimensions are unchanged.

    Raises:
        ValueError: If ``tensor`` is 0-dimensional, ``batch_size`` is
            negative, or ``batch_size`` is not a multiple of the first
            dimension (including a non-zero ``batch_size`` with an empty
            first dimension).
    """
    if tensor is None:
        return None

    # A scalar tensor has no first dimension to repeat; fail clearly instead
    # of letting ``tensor.shape[0]`` raise an IndexError.
    if tensor.dim() == 0:
        raise ValueError("Cannot repeat a 0-dimensional tensor to a batch size.")

    first_dim = tensor.shape[0]

    if first_dim == batch_size:
        return tensor

    if batch_size < 0:
        raise ValueError(f"batch_size must be non-negative, got {batch_size}.")

    # Check first_dim == 0 explicitly: otherwise the modulo would raise
    # ZeroDivisionError instead of the documented ValueError.
    if first_dim == 0 or batch_size % first_dim != 0:
        raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")

    repeat_times = batch_size // first_dim

    # Repeat only along dim 0; every trailing dimension keeps a factor of 1.
    return tensor.repeat(repeat_times, *([1] * (tensor.dim() - 1)))
| 463 | |
| 464 | |
| 465 | def dim5(x): |