Repository: github.com/hpcaitech/ColossalAI

Function cast_to_fp8_pipeline

colossalai/quantization/fp8.py:285–324

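Casts the hidden_states tensor of the inp dict to FP8 before point-to-point (p2p) communication between pipeline stages. The quantized bytes are then reinterpreted (viewed) as float16 or bfloat16, so the tensor keeps a 16-bit floating-point dtype while its last dimension becomes half as long. Metadata such as fp8_scale is written into the inp dict so that it is communicated along with the activations. The function modifies inp in place and returns None; a None input or a bare torch.Tensor is left untouched.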
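The halving comes from reinterpretation, not conversion: each FP8 element occupies one byte, so viewing the storage as a 16-bit dtype groups adjacent bytes in pairs along the last dimension. This is also why the function asserts that size(-1) is even. The pure-Python sketch below imitates that reinterpretation with struct on a row of raw FP8 codes. It is not ColossalAI code: the helper names are hypothetical, the byte values are arbitrary, and it assumes a little-endian layout.

```python
import struct
from typing import List


def pack_fp8_bytes_as_16bit(fp8_codes: bytes) -> List[int]:
    """Hypothetical helper: reinterpret a row of 1-byte FP8 codes as 16-bit words.

    No value is converted; only the element width changes, so a row of N
    bytes becomes N // 2 sixteen-bit elements (assumes little-endian order).
    """
    if len(fp8_codes) % 2 != 0:
        # Mirrors the even-size(-1) requirement in cast_to_fp8_pipeline.
        raise ValueError("last dimension must be even to view 1-byte codes as 16-bit elements")
    count = len(fp8_codes) // 2
    return list(struct.unpack(f"<{count}H", fp8_codes))


def unpack_16bit_as_fp8_bytes(words: List[int]) -> bytes:
    """Hypothetical inverse: split each 16-bit word back into its two FP8 bytes."""
    return struct.pack(f"<{len(words)}H", *words)


row = bytes([0x38, 0x40, 0x44, 0x7E])   # four FP8 codes
words = pack_fp8_bytes_as_16bit(row)    # two 16-bit elements
assert unpack_16bit_as_fp8_bytes(words) == row
```

The 16-bit values produced this way are meaningless as float16 or bfloat16 numbers; the receiver only has to undo the view. In the source, the choice of container still carries information: per its TODO comment, float16 signals E5M2 and bfloat16 signals E4M3, as a workaround for a send path that only supports tensors.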

Signature: cast_to_fp8_pipeline(inp: Any) -> None


# Imports this excerpt relies on (defined at module level in fp8.py).
from typing import Any

import torch


def cast_to_fp8_pipeline(inp: Any) -> None:
    """
    Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
    The activations tensor is indexed by 'hidden_states' in the inp dict.
    After FP8 casting, the resulting tensor is saved as float16 or bfloat16 format but the size becomes halved.
    Metadata such as fp8_scale is saved into inp dict for communication.
    """
    if inp is None:
        return
    # In pipeline parallelism, when inp is torch.Tensor, it only contains one element, thus can be omitted.
    if type(inp) == torch.Tensor:
        return

    assert "hidden_states" in inp, "required by pipeline parallelism."
    assert (
        inp["hidden_states"].size(-1) % 2 == 0
    ), "tensor size(-1) must be divisible by 2 to view Float8_e4m3fn as BFloat16 or Float16"
    inp_tensor = inp["hidden_states"]
    inp_dtype = inp_tensor.dtype

    # Per-tensor absolute maximum: max(|min|, |max|).
    min_val, max_val = inp_tensor.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs())

    # Use E4M3 (more mantissa bits) unless amax exceeds the largest finite
    # E4M3 value; in that case fall back to E5M2 (wider exponent range).
    finfo = torch.finfo(torch.float8_e4m3fn)
    if amax > finfo.max:
        fp8_type = torch.float8_e5m2
        fp8_view_type = torch.float16
    else:
        fp8_type = torch.float8_e4m3fn
        fp8_view_type = torch.bfloat16

    # Scale so that amax maps onto the chosen format's largest finite value.
    finfo = torch.finfo(fp8_type)
    scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
    q_tensor = inp_tensor.data.float() * scale
    # Todo: Currently we use fp8_view_type <float16, bfloat16> to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'.
    # inp_tensor needs to be a float datatype to avoid error during gradient placement.
    inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)

    # fp8_scale = 1 / scale undoes the scaling; the zero scalar stored under
    # "dtype" only records the original dtype of hidden_states.
    inp["fp8_scale"] = scale.float().reciprocal()
    inp["dtype"] = torch.zeros_like(scale).to(inp_dtype)


# Next definition in fp8.py (body not shown on this page):
# def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
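The format choice and scaling step in cast_to_fp8_pipeline can be reproduced with plain floats. The sketch below is a simplified, hypothetical mirror of that amax/scale logic: plan_fp8_cast and Fp8Plan are illustrative names, not ColossalAI APIs. It hard-codes the largest finite values of float8_e4m3fn (448) and float8_e5m2 (57344), which is what torch.finfo reports for those dtypes, and it deliberately omits the actual rounding to FP8.

```python
from dataclasses import dataclass
from typing import Sequence

# Largest finite magnitudes, matching torch.finfo(torch.float8_e4m3fn).max
# and torch.finfo(torch.float8_e5m2).max.
E4M3FN_MAX = 448.0
E5M2_MAX = 57344.0


@dataclass
class Fp8Plan:
    """Hypothetical container for the decisions the real function makes."""
    fmt: str          # "e4m3fn" or "e5m2"
    view_dtype: str   # 16-bit container used on the wire
    scale: float      # activations are multiplied by this before the FP8 cast
    fp8_scale: float  # reciprocal of scale, stored as metadata


def plan_fp8_cast(values: Sequence[float]) -> Fp8Plan:
    # Per-tensor absolute maximum, like max(|min|, |max|) via aminmax().
    amax = max((abs(v) for v in values), default=0.0)
    if amax > E4M3FN_MAX:
        fmt, view_dtype, fmax = "e5m2", "float16", E5M2_MAX
    else:
        fmt, view_dtype, fmax = "e4m3fn", "bfloat16", E4M3FN_MAX
    # An all-zero input keeps scale 1.0 to avoid dividing by zero.
    scale = 1.0 if amax == 0.0 else fmax / amax
    return Fp8Plan(fmt, view_dtype, scale, 1.0 / scale)


plan = plan_fp8_cast([-3.0, 1.5, 2.0])
print(plan.fmt, plan.view_dtype, plan.scale, plan.fp8_scale)
```

With amax = 3.0 the sketch picks E4M3, scales by 448 / 3 so the largest magnitude lands on 448, and records fp8_scale = 3 / 448. Multiplying a quantized value by fp8_scale undoes the scaling, up to FP8 rounding error, which this float-only sketch does not model.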

Callers (9)

send_forward · method · 0.90
send_backward · method · 0.90
send_forward · method · 0.90
send_backward · method · 0.90
test_fp8_cast · function · 0.90

Calls (2)

size · method · 0.45
to · method · 0.45

Tested by (1)

test_fp8_cast · function · 0.72
