Repository: github.com/hpcaitech/ColossalAI

Function decompress(fut)

colossalai/quantization/fp8.py, lines 485–501

Source from the content-addressed store, hash-verified

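        return fut_combined2

    def decompress(fut):
        # fut resolves to the fut_combined2 future returned by sum_and_allgather,
        # so its value must be waited on before indexing: [0] carries the FP8
        # payload (reinterpreted as fp8_type below), [1] one scale per rank.
        tensor_list_single = fut.value().wait()[0].value()[0]
        scale_list_single = fut.value().wait()[1].value()[0]

        # Split the gathered buffers back into one chunk per rank.
        tensor_list = list(torch.chunk(tensor_list_single, world_size, dim=0))
        scale_list = scale_list_single.chunk(world_size, dim=0)

        # Reinterpret each chunk as FP8, upcast it to the input dtype,
        # and multiply by the scale that belongs to the same rank.
        for i in range(world_size):
            tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
        out = torch.cat(tensor_list, dim=0)

        # Drop any trailing elements beyond the original element count
        # (padding), then restore the original shape.
        input_tensor_size = input_tensor.numel()
        input_shape = input_tensor.shape
        out = out[:input_tensor_size]

        # Write the result back into the original buffer in place.
        input_tensor.copy_(out.view(input_shape).to(input_type))
        return input_tensor

    return all_to_all_fut.then(sum_and_allgather).then(decompress)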
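The reassembly performed by decompress can be simulated without torch or a process group. The sketch below is stdlib-only and rests on stated assumptions: it swaps FP8 for symmetric integer quantization with a per-chunk scale, and the helpers quantize_chunk, chunk, and decompress_sim are hypothetical, not ColossalAI APIs. Gathered buffers are represented as flat Python lists. It follows the same order as the original: split payload and scales into world_size chunks, dequantize each chunk with its own rank's scale, concatenate, trim everything past the original element count, and restore the shape.

```python
from typing import List, Tuple

# Hypothetical stand-in for FP8: symmetric integer quantization to [-QMAX, QMAX].
# Real FP8 formats are floating point; this only mimics "payload + per-chunk scale".
QMAX = 127


def quantize_chunk(values: List[float]) -> Tuple[List[int], float]:
    """Quantize one rank's chunk with its own scale (hypothetical helper)."""
    amax = max((abs(v) for v in values), default=0.0)
    scale = amax / QMAX if amax > 0 else 1.0
    return [round(v / scale) for v in values], scale


def chunk(flat: list, n: int) -> List[list]:
    """Split a flat list into n equal chunks (length assumed divisible by n)."""
    size = len(flat) // n
    return [flat[i * size:(i + 1) * size] for i in range(n)]


def decompress_sim(payload_single: List[int], scales_single: List[float],
                   world_size: int, numel: int,
                   shape: Tuple[int, int]) -> List[List[float]]:
    # One payload chunk and one scale per rank, like torch.chunk(..., world_size).
    payload_chunks = chunk(payload_single, world_size)
    scales = chunk(scales_single, world_size)

    # Dequantize each chunk with the scale gathered from the same rank.
    out: List[float] = []
    for i in range(world_size):
        out.extend(q * scales[i][0] for q in payload_chunks[i])

    # Trim the padding that made the flat length divisible by world_size.
    out = out[:numel]

    # Restore the original 2-D shape (the original uses out.view(input_shape)).
    rows, cols = shape
    return [out[r * cols:(r + 1) * cols] for r in range(rows)]


# Usage: a 2x3 tensor split across 4 simulated ranks (6 elements padded to 8).
original = [[1.0, -2.0, 3.0], [0.5, 4.0, -1.5]]
flat = [v for row in original for v in row]
world_size = 4
padded = flat + [0.0] * ((-len(flat)) % world_size)

payload: List[int] = []
scale_values: List[float] = []
for part in chunk(padded, world_size):
    q, s = quantize_chunk(part)
    payload.extend(q)
    scale_values.append(s)

restored = decompress_sim(payload, scale_values, world_size, len(flat), (2, 3))
```

Because each chunk carries its own scale in this sketch, a rank whose values are small keeps finer resolution than it would under one global scale, at the cost of one extra scalar per rank in the scale buffer.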
504

Callers

nothing calls this directly; it is registered as a callback in all_to_all_fut.then(sum_and_allgather).then(decompress)
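decompress is never called by name: the last line of the excerpt hands it to .then(). The original's fut.value().wait() suggests that the value reaching decompress is itself a future (the fut_combined2 returned by sum_and_allgather), so it must be waited on before its results can be indexed. The sketch below illustrates that future-of-a-future shape with a hypothetical, synchronous SimpleFuture class and hypothetical stage_one/stage_two callbacks; it is not torch.futures.Future and does not model threads, CUDA streams, or error propagation.

```python
from typing import Any, Callable, List


class SimpleFuture:
    """Hypothetical synchronous future with a then() method.

    Mimics only one property relevant here: then() stores whatever the
    callback returns, even if that return value is another future.
    """

    def __init__(self) -> None:
        self._done = False
        self._value: Any = None
        self._callbacks: List[Callable[["SimpleFuture"], None]] = []

    def set_result(self, value: Any) -> None:
        self._value = value
        self._done = True
        for cb in self._callbacks:
            cb(self)
        self._callbacks.clear()

    def value(self) -> Any:
        assert self._done, "future not completed"
        return self._value

    def wait(self) -> Any:
        # Synchronous sketch: a completed future simply returns its value.
        return self.value()

    def then(self, callback: Callable[["SimpleFuture"], Any]) -> "SimpleFuture":
        nxt = SimpleFuture()

        def run(f: "SimpleFuture") -> None:
            nxt.set_result(callback(f))  # the result may itself be a future

        if self._done:
            run(self)
        else:
            self._callbacks.append(run)
        return nxt


def stage_one(fut: SimpleFuture) -> SimpleFuture:
    # Plays the role of sum_and_allgather: returns a new future, not a value.
    inner = SimpleFuture()
    inner.set_result([sum(fut.value())])
    return inner


def stage_two(fut: SimpleFuture) -> int:
    # Plays the role of decompress: the incoming value is a future, so unwrap it.
    return fut.value().wait()[0] * 10


start = SimpleFuture()
final = start.then(stage_one).then(stage_two)
start.set_result([1, 2, 3])
result = final.value()
```

Here stage_two receives a future whose value is stage_one's inner future, which is why it calls .value().wait() before indexing, mirroring the double unwrap in decompress.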

Calls (2)

wait (method) · 0.45
to (method) · 0.45

Tested by

no test coverage detected
