Repository: github.com/hpcaitech/ColossalAI

Function check_padded_tensor

tests/test_tensor/test_padded_tensor.py, lines 11–36
Signature: check_padded_tensor(rank, world_size, port)
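The (rank, world_size, port) signature marks this as a per-rank worker. A launcher has to start one copy per rank, and all copies share the same world_size and rendezvous port. Because the test builds a 2 × 2 DeviceMesh over ranks 0–3, world_size must be 4, and with backend="nccl" that typically means four CUDA devices. The sketch below is loosely modeled on the pattern used by ColossalAI's colossalai.testing.spawn helper and shows how the three arguments can be supplied. run_all_ranks and fake_worker are hypothetical names. The ranks run one after another in a single process only to keep the example self-contained. Running check_padded_tensor this way would hang, because a real launcher starts one process per rank so that they can rendezvous on the port.

```python
import functools
import socket


def free_port() -> int:
    """Ask the OS for an unused TCP port by binding to port 0 on localhost."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


def run_all_ranks(worker, world_size: int):
    """Hypothetical launcher: call worker(rank, world_size=..., port=...) once per rank.

    For illustration only, ranks run sequentially in this process. A real
    launcher starts one process per rank so they can rendezvous on `port`.
    """
    port = free_port()
    # Bind the arguments shared by every rank; only `rank` varies.
    bound = functools.partial(worker, world_size=world_size, port=port)
    return [bound(rank) for rank in range(world_size)]


def fake_worker(rank, world_size, port):
    # Hypothetical stand-in for check_padded_tensor: report what it received.
    return (rank, world_size, port > 0)


print(run_all_ranks(fake_worker, world_size=4))
# [(0, 4, True), (1, 4, True), (2, 4, True), (3, 4, True)]
```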


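import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor


def check_padded_tensor(rank, world_size, port):
    # Per-rank worker: set up the distributed environment for this rank.
    disable_existing_loggers()
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    original_tensor = torch.rand(32, 64).to("cuda")

    # 2x2 mesh over ranks 0-3; shard tensor dim 0 across mesh axis 0.
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
    d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)

    # Pad dim 0 to length 64; the distributed layout must be unchanged.
    padded_tensor = to_padded_tensor(d_tensor, current_length=64, padding_dim=0)
    assert padded_tensor.dist_layout == d_tensor.dist_layout

    # Copies made with clone() and detach() must stay padded and distributed.
    tensor_copy = padded_tensor.clone()
    assert is_padded_tensor(tensor_copy)
    assert is_distributed_tensor(tensor_copy)

    tensor_detached = padded_tensor.detach()
    assert is_padded_tensor(tensor_detached)
    assert is_distributed_tensor(tensor_detached)

    # Unpadding restores the shape of d_tensor and keeps it distributed.
    unpadded_tensor = to_unpadded_tensor(padded_tensor)
    assert unpadded_tensor.shape == d_tensor.shape
    assert is_distributed_tensor(unpadded_tensor)

    # Gathering the shards gives back the full 32x64 shape.
    global_tensor = to_global(unpadded_tensor)
    assert global_tensor.shape == original_tensor.shape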
39@rerun_if_address_is_in_use()

Callers

nothing calls this directly

Calls 13

disable_existing_loggers · Function · 0.90
launch · Function · 0.90
DeviceMesh · Class · 0.90
ShardingSpec · Class · 0.90
distribute_tensor · Function · 0.90
to_padded_tensor · Function · 0.90
is_padded_tensor · Function · 0.90
is_distributed_tensor · Function · 0.90
to_unpadded_tensor · Function · 0.90
to_global · Function · 0.90
detach · Method · 0.80
to · Method · 0.45
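Of these calls, to_padded_tensor, is_padded_tensor and to_unpadded_tensor form a pad, copy and unpad round trip. The pure-Python sketch below models that lifecycle under stated assumptions. Padding is taken to append zeros along one dimension (dim 0 here) and to record the original length so the padding can be stripped later. Copies must carry that record, which is the property the is_padded_tensor checks after clone() and detach() verify. PaddedRows, pad_to and unpad are hypothetical names, not ColossalAI's API, and no tensors or devices are involved.

```python
from dataclasses import dataclass
from typing import List, Optional

Row = List[float]


@dataclass
class PaddedRows:
    """Hypothetical stand-in for a tensor that may be padded along dim 0."""
    rows: List[Row]
    origin_length: Optional[int] = None  # None means "not padded"

    def is_padded(self) -> bool:
        return self.origin_length is not None

    def clone(self) -> "PaddedRows":
        # Copy the data *and* the padding metadata; dropping origin_length
        # would make the copy look unpadded.
        return PaddedRows([list(r) for r in self.rows], self.origin_length)


def pad_to(t: PaddedRows, current_length: int) -> PaddedRows:
    """Append zero rows until there are current_length rows."""
    n = len(t.rows)
    if current_length <= n:
        return t  # already long enough: nothing to pad
    width = len(t.rows[0]) if t.rows else 0
    zeros = [[0.0] * width for _ in range(current_length - n)]
    # Keep the first recorded origin length if t was already padded.
    origin = t.origin_length if t.is_padded() else n
    return PaddedRows(t.rows + zeros, origin_length=origin)


def unpad(t: PaddedRows) -> PaddedRows:
    """Drop the rows added by pad_to; unpadded input is returned unchanged."""
    if not t.is_padded():
        return t
    return PaddedRows(t.rows[: t.origin_length], origin_length=None)


t = PaddedRows([[1.0, 2.0], [3.0, 4.0]])
p = pad_to(t, current_length=4)
c = p.clone()
print(len(p.rows), p.is_padded(), c.is_padded())  # 4 True True
print(unpad(c).rows)  # [[1.0, 2.0], [3.0, 4.0]]
```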

Tested by

no test coverage detected
