github.com/hpcaitech/ColossalAI

Function cast_from_fp8_pipeline

colossalai/quantization/fp8.py:327–353

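Casts the FP8-encoded hidden_states tensor back to its original dtype after point-to-point (p2p) communication in pipeline parallelism. The tensor is updated in place, and an input that is None or a bare torch.Tensor is left untouched. Passing del_metadata=False keeps the fp8_scale and dtype entries in the input dict, which is useful when this function is called before p2p communication.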

(inp: Any, del_metadata=True) -> None
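As a rough illustration of the contract described above, the sketch below reproduces the function's control flow in plain Python so it runs without PyTorch. It is a hypothetical stand-in, not the library code: restore_payload is an invented name, a list of numbers stands in for the hidden_states tensor, and the FP8 decode is replaced by a float conversion followed by the rescale.

```python
from typing import Any


def restore_payload(inp: Any, del_metadata: bool = True) -> None:
    """Hypothetical pure-Python stand-in for cast_from_fp8_pipeline's control flow.

    A list of numbers plays the role of the hidden_states tensor, and the FP8
    decode is replaced by a plain float conversion followed by the rescale.
    """
    if inp is None or isinstance(inp, list):
        # Mirrors the early returns for None and for a bare torch.Tensor.
        return

    assert "hidden_states" in inp, "required by pipeline parallelism."
    scale = inp["fp8_scale"]

    # Update the existing list in place, as the real function reassigns .data
    # on the tensor object already stored in the dict.
    inp["hidden_states"][:] = [float(x) * scale for x in inp["hidden_states"]]

    if del_metadata:
        # Drop the side-channel entries that only the decode step needed.
        del inp["fp8_scale"]
        del inp["dtype"]


payload = {"hidden_states": [2, -4], "fp8_scale": 0.5, "dtype": "bfloat16"}
restore_payload(payload, del_metadata=False)
print(payload)  # {'hidden_states': [1.0, -2.0], 'fp8_scale': 0.5, 'dtype': 'bfloat16'}
```

With the default del_metadata=True the same call would also remove fp8_scale and dtype, leaving only hidden_states in the dict.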

Source

def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
    """
    Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline.
    del_metadata = False is useful when this function is called before p2p communication.
    """
    if inp is None:
        return
    if type(inp) == torch.Tensor:
        return

    assert "hidden_states" in inp, "required by pipeline parallelism."
    inp_tensor = inp["hidden_states"]
    scale = inp["fp8_scale"]

    fp8_view_type = inp_tensor.dtype
    if fp8_view_type == torch.float16:
        fp8_type = torch.float8_e5m2
    elif fp8_view_type == torch.bfloat16:
        fp8_type = torch.float8_e4m3fn
    else:
        raise TypeError("Only float16, bfloat16 are implemented.")

    inp_tensor.data = inp_tensor.data.view(fp8_type).to(inp["dtype"]) * scale

    if del_metadata:
        del inp["fp8_scale"]
        del inp["dtype"]
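The decode step in the listing relies on a tagging trick: the 16-bit view dtype of hidden_states records which FP8 format the bytes hold (float16 for float8_e5m2, bfloat16 for float8_e4m3fn), and .view(fp8_type) reinterprets each 16-bit element as two FP8 values before the cast and the multiplication by fp8_scale. The pure-Python sketch below decodes those bytes by hand to show what that reinterpretation computes. The helper names (decode_e5m2, decode_e4m3fn, fp8_bytes_from_view, dequantize) are hypothetical; the sketch assumes a little-endian host and a single scalar scale, and it multiplies in Python floats rather than in the original dtype as the real function does.

```python
import math
import struct


def decode_e5m2(byte: int) -> float:
    """Decode one float8_e5m2 byte to a Python float."""
    # E5M2 has float16's sign/exponent layout (bias 15), so the FP8 byte is
    # exactly the high byte of a float16 whose low mantissa byte is zero.
    return struct.unpack("<e", bytes([0, byte]))[0]


def decode_e4m3fn(byte: int) -> float:
    """Decode one float8_e4m3fn byte to a Python float."""
    # 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits. The "fn" variant
    # has no infinities; only S.1111.111 encodes NaN.
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exponent == 0xF and mantissa == 0x7:
        return math.nan
    if exponent == 0:
        # Subnormal: no implicit leading 1, fixed exponent 1 - bias = -6.
        return sign * (mantissa / 8) * 2.0 ** -6
    return sign * (1 + mantissa / 8) * 2.0 ** (exponent - 7)


# Same tag convention as the listing: a float16 view carries E5M2 bytes and a
# bfloat16 view carries E4M3FN bytes.
DECODERS = {"float16": decode_e5m2, "bfloat16": decode_e4m3fn}


def fp8_bytes_from_view(words: list[int]) -> bytes:
    """Reinterpret 16-bit container words as the FP8 bytes they hold."""
    # n 16-bit words become 2n bytes, which is why viewing a 16-bit tensor as
    # FP8 doubles its last dimension. Little-endian: low byte comes first.
    return struct.pack(f"<{len(words)}H", *words)


def dequantize(words: list[int], view_dtype: str, scale: float) -> list[float]:
    """Hypothetical helper: decode the packed FP8 values and apply the scale."""
    if view_dtype not in DECODERS:
        raise TypeError("Only float16, bfloat16 are implemented.")
    decode = DECODERS[view_dtype]
    return [decode(b) * scale for b in fp8_bytes_from_view(words)]


# One bfloat16-tagged word 0x4038 holds E4M3FN bytes 0x38 (1.0) and 0x40 (2.0).
print(dequantize([0x4038], "bfloat16", 0.5))  # [0.5, 1.0]
```

Decoding E5M2 by zero-extending into a float16 works because the two formats share the same exponent width and bias; E4M3FN needs its own decoder because it uses bias 7 and gives up infinities to extend its range to ±448.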

Callers 13

send_forward · Method · 0.90
send_backward · Method · 0.90
run_forward_only · Method · 0.90
run_forward_backward · Method · 0.90
recv_forward · Method · 0.90
recv_backward · Method · 0.90
send_forward · Method · 0.90
send_backward · Method · 0.90

Calls 1

to · Method · 0.45

Tested by 1

test_fp8_cast · Function · 0.72
