| 483 | return fut_combined2 |
| 484 | |
def decompress(fut):
    """Rebuild ``input_tensor`` from the gathered payload and its scales.

    ``fut.value().wait()`` is indexed as a two-element sequence: item 0
    carries the packed tensor data and item 1 carries the scales. Each item
    exposes ``value()``, whose first entry is the tensor that is used.

    Both tensors are split into ``world_size`` pieces along dim 0. Every
    packed piece is reinterpreted as ``fp8_type``, converted to
    ``input_type`` and multiplied by its matching scale piece. The pieces
    are then concatenated, cut to ``input_tensor.numel()`` leading entries,
    reshaped to ``input_tensor.shape`` and copied into ``input_tensor``.

    Args:
        fut: Future handed over by the preceding ``then`` stage.

    Returns:
        ``input_tensor``, overwritten in place with the restored values.
    """
    # Resolve the nested result once; both payloads come from the same wait.
    gathered = fut.value().wait()
    packed = gathered[0].value()[0]
    scales = gathered[1].value()[0]

    packed_parts = torch.chunk(packed, world_size, dim=0)
    scale_parts = scales.chunk(world_size, dim=0)

    # Indexing by rank (rather than zipping) keeps the IndexError that
    # surfaces if either split yields fewer than world_size pieces.
    restored = [
        packed_parts[rank].view(fp8_type).to(input_type) * scale_parts[rank]
        for rank in range(world_size)
    ]
    merged = torch.cat(restored, dim=0)

    # NOTE(review): presumably drops padding added before the exchange;
    # assumes `merged` is 1-D so the slice keeps exactly numel() elements —
    # confirm against the compression stage.
    trimmed = merged[: input_tensor.numel()]

    input_tensor.copy_(trimmed.view(input_tensor.shape).to(input_type))
    return input_tensor
| 502 | |
| 503 | return all_to_all_fut.then(sum_and_allgather).then(decompress) |
| 504 | |