Cast the FP8-encoded hidden_states tensor back to its original dtype after p2p communication in the pipeline. Setting del_metadata=False (which keeps the fp8_scale and dtype entries) is useful when this function is called before p2p communication.
(inp: Any, del_metadata=True)
| 325 | |
| 326 | |
| 327 | def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: |
| 328 | """ |
| 329 | Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline. |
| 330 | del_metadata = False is useful when this function is called before p2p communication. |
| 331 | """ |
| 332 | if inp is None: |
| 333 | return |
| 334 | if type(inp) == torch.Tensor: |
| 335 | return |
| 336 | |
| 337 | assert "hidden_states" in inp, "required by pipeline parallelism." |
| 338 | inp_tensor = inp["hidden_states"] |
| 339 | scale = inp["fp8_scale"] |
| 340 | |
| 341 | fp8_view_type = inp_tensor.dtype |
| 342 | if fp8_view_type == torch.float16: |
| 343 | fp8_type = torch.float8_e5m2 |
| 344 | elif fp8_view_type == torch.bfloat16: |
| 345 | fp8_type = torch.float8_e4m3fn |
| 346 | else: |
| 347 | raise TypeError("Only float16, bfloat16 are implemented.") |
| 348 | |
| 349 | inp_tensor.data = inp_tensor.data.view(fp8_type).to(inp["dtype"]) * scale |
| 350 | |
| 351 | if del_metadata: |
| 352 | del inp["fp8_scale"] |
| 353 | del inp["dtype"] |
| 354 | |
| 355 | |
| 356 | def _reduce_scatter_fp8( |
searching dependent graphs…