| 35 | |
| 36 | |
def split_into_partitions(tensor, num_partitions, partition_dim, stride):
    """Split ``tensor`` into ``num_partitions`` strided partitions.

    The tensor is cut along ``partition_dim`` into chunks of a size derived
    from ``num_partitions`` and ``stride``. Partition ``i`` is then the
    concatenation (along the same dimension) of chunks ``i``,
    ``i + num_partitions``, ``i + 2 * num_partitions``, and so on.

    Args:
        tensor: Tensor to split.
        num_partitions: Number of partitions to produce.
        partition_dim: Dimension along which to split and re-concatenate.
        stride: Divisor applied to the per-partition size to obtain the
            size of each individual chunk.

    Returns:
        A list containing ``num_partitions`` tensors.
    """
    # NOTE(review): mpu.utils.divide is defined outside this block; presumably
    # it is an exact integer division that rejects non-divisible sizes --
    # confirm against its definition.
    dim_size = tensor.size(partition_dim)
    chunk_size = mpu.utils.divide(
        mpu.utils.divide(dim_size, num_partitions), stride
    )

    chunks = torch.split(tensor, chunk_size, dim=partition_dim)

    # Taking every num_partitions-th chunk starting at `start` interleaves the
    # chunks across partitions instead of giving each one a contiguous slab.
    return [
        torch.cat(chunks[start::num_partitions], dim=partition_dim)
        for start in range(num_partitions)
    ]
| 52 | |
| 53 | |
| 54 | def merge_partitions(merged, partitions, partition_dim, stride): |