github.com/EleutherAI/gpt-neox

Function split_into_partitions

tools/merge_mp_partitions.py:37–51
(tensor, num_partitions, partition_dim, stride)


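def split_into_partitions(tensor, num_partitions, partition_dim, stride):
    # Size of one partition along partition_dim (mpu.utils.divide requires exact divisibility).
    per_partition_size = mpu.utils.divide(tensor.size(partition_dim), num_partitions)
    # Each partition is built from `stride` interleaved chunks of this size.
    per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)

    # Cut the tensor into num_partitions * stride equal chunks along partition_dim.
    partitions_list = torch.split(
        tensor, per_partition_per_stride_size, dim=partition_dim
    )

    # Partition i gathers chunks i, i + num_partitions, i + 2 * num_partitions, and so on.
    partitions = []
    for i in range(num_partitions):
        partition = torch.cat(partitions_list[i::num_partitions], dim=partition_dim)
        partitions.append(partition)

    return partitions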
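The function cuts `tensor` into `num_partitions * stride` equal chunks along `partition_dim` and hands partition `i` every `num_partitions`-th chunk starting at chunk `i`. With `stride=1` that is a plain contiguous split; with `stride > 1` each partition is assembled from interleaved blocks. The sketch below reproduces the same logic so it runs on its own. `divide` is a hypothetical stand-in for `mpu.utils.divide`, assumed here to be exact integer division that fails on sizes that do not divide evenly.

```python
import torch


def divide(numerator: int, denominator: int) -> int:
    # Hypothetical stand-in for mpu.utils.divide (assumed behavior):
    # exact integer division that fails loudly on uneven sizes.
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
    return numerator // denominator


def split_into_partitions(tensor, num_partitions, partition_dim, stride):
    per_partition_size = divide(tensor.size(partition_dim), num_partitions)
    per_partition_per_stride_size = divide(per_partition_size, stride)
    # num_partitions * stride equal chunks along partition_dim.
    chunks = torch.split(tensor, per_partition_per_stride_size, dim=partition_dim)
    # Partition i takes every num_partitions-th chunk, starting at chunk i.
    return [
        torch.cat(chunks[i::num_partitions], dim=partition_dim)
        for i in range(num_partitions)
    ]


# 8 rows labelled 0..7 so it is easy to see which rows land in which partition.
weight = torch.arange(8).unsqueeze(1)  # shape (8, 1)

contiguous = split_into_partitions(weight, num_partitions=2, partition_dim=0, stride=1)
strided = split_into_partitions(weight, num_partitions=2, partition_dim=0, stride=2)

print([p.flatten().tolist() for p in contiguous])  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print([p.flatten().tolist() for p in strided])     # [[0, 1, 4, 5], [2, 3, 6, 7]]
```

Both calls produce two partitions of four rows each; only the assignment of rows differs. With `stride=2` the tensor is cut into four 2-row chunks, and partition 0 receives chunks 0 and 2 while partition 1 receives chunks 1 and 3.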

Callers (1): test_split_merge (function)
Calls (1): size (method)
Tested by (1): test_split_merge (function)
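The body of `test_split_merge` is not shown on this page, and neither is the body of `merge_partitions` from the same file. As an illustration only, the sketch below checks the property a split/merge test would naturally rely on: splitting and then re-interleaving recovers the original tensor. `merge_strided` is a hypothetical inverse written for this example, not the repository's `merge_partitions`; the split logic is repeated with plain integer arithmetic in place of `mpu.utils.divide`.

```python
import torch


def split_into_partitions(tensor, num_partitions, partition_dim, stride):
    # Same chunking as above; one divisibility check replaces the two
    # mpu.utils.divide calls (equivalent condition).
    size = tensor.size(partition_dim)
    assert size % (num_partitions * stride) == 0
    chunk = size // (num_partitions * stride)
    chunks = torch.split(tensor, chunk, dim=partition_dim)
    return [torch.cat(chunks[i::num_partitions], dim=partition_dim)
            for i in range(num_partitions)]


def merge_strided(partitions, partition_dim, stride):
    # Hypothetical inverse: cut each partition back into its `stride` chunks,
    # then emit chunk j of every partition before chunk j + 1, which restores
    # the original chunk order j * num_partitions + i.
    pieces = [torch.chunk(p, stride, dim=partition_dim) for p in partitions]
    ordered = [pieces[i][j] for j in range(stride) for i in range(len(partitions))]
    return torch.cat(ordered, dim=partition_dim)


torch.manual_seed(0)
weight = torch.randn(12, 3)
results = []
for dim, t in [(0, weight), (1, weight.t())]:
    for num_partitions, stride in [(2, 1), (2, 3), (3, 2), (4, 3)]:
        parts = split_into_partitions(t, num_partitions, partition_dim=dim, stride=stride)
        restored = merge_strided(parts, partition_dim=dim, stride=stride)
        results.append(torch.equal(restored, t))

print(all(results))
```

Every configuration divides the 12-element dimension evenly, so each split is valid, and re-interleaving the chunks in order `j * num_partitions + i` reverses the `[i::num_partitions]` slicing used by the split.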