(rank, world_size, port)
| 9 | |
| 10 | |
def check_padded_tensor(rank, world_size, port):
    """Exercise the padded-tensor helpers on a sharded distributed tensor.

    Launches a NCCL process group on ``localhost``, then:
    - distributes a random 32x64 CUDA tensor over a 2x2 device mesh;
    - pads it along dim 0 (``current_length=64``);
    - checks that ``clone``/``detach`` keep the padded and distributed
      markers;
    - checks that unpadding restores the sharded shape;
    - checks that gathering the result recovers the original global shape.

    Args:
        rank: Rank of this process, forwarded to ``launch``.
        world_size: Total number of processes, forwarded to ``launch``.
        port: Port on ``localhost`` used for the process-group rendezvous.
    """
    disable_existing_loggers()
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    full = torch.rand(32, 64).to("cuda")

    # NOTE(review): the mesh lists ranks 0-3 explicitly, so this presumably
    # expects world_size == 4 -- confirm against the launcher.
    mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    # Presumably maps tensor dim 0 -> mesh axis 0 -- confirm with ShardingSpec docs.
    spec = ShardingSpec(dim_size=full.dim(), dim_partition_dict={0: [0]})
    sharded = distribute_tensor(full, mesh, spec)

    padded = to_padded_tensor(sharded, current_length=64, padding_dim=0)
    # Padding must leave the distributed layout unchanged.
    assert padded.dist_layout == sharded.dist_layout

    # Both clone() and detach() must keep the padded + distributed markers.
    # Iterating over bound methods creates and checks each copy before the
    # next one is made.
    for derive in (padded.clone, padded.detach):
        derived = derive()
        assert is_padded_tensor(derived)
        assert is_distributed_tensor(derived)

    unpadded = to_unpadded_tensor(padded)
    assert unpadded.shape == sharded.shape
    assert is_distributed_tensor(unpadded)

    # Only the gathered shape is compared here; values are not checked.
    assert to_global(unpadded).shape == full.shape
| 37 | |
| 38 | |
| 39 | @rerun_if_address_is_in_use() |
nothing calls this directly
no test coverage detected
searching dependent graphs…