Cast the `hidden_states` tensor of the `inp` object to FP8 format before p2p communication in the pipeline. The activation tensor is indexed by `'hidden_states'` in the `inp` dict. After FP8 casting, the FP8 bytes are stored in a float16 or bfloat16 tensor whose last dimension is half the original size. Metadata such as `fp8_scale` is saved into the `inp` dict for communication.
(inp: Any)
| 283 | |
| 284 | |
| 285 | def cast_to_fp8_pipeline(inp: Any) -> None: |
| 286 | """ |
| 287 | Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline. |
| 288 | The activations tensor is indexed by 'hidden_states' in the inp dict. |
| 289 | After FP8 casting, the resulting tensor is saved as float16 or bfloat16 format but the size becomes halved. |
| 290 | Metadata such as fp8_scale is saved into inp dict for communication. |
| 291 | """ |
| 292 | if inp is None: |
| 293 | return |
| 294 | # In pipeline parallelism, when inp is torch.Tensor, it only contains one element, thus can be omitted. |
| 295 | if type(inp) == torch.Tensor: |
| 296 | return |
| 297 | |
| 298 | assert "hidden_states" in inp, "required by pipeline parallelism." |
| 299 | assert ( |
| 300 | inp["hidden_states"].size(-1) % 2 == 0 |
| 301 | ), "tensor size(-1) must be divisible by 2 to view Float8_e4m3fn as BFloat16 or Float16" |
| 302 | inp_tensor = inp["hidden_states"] |
| 303 | inp_dtype = inp_tensor.dtype |
| 304 | |
| 305 | min_val, max_val = inp_tensor.aminmax() |
| 306 | amax = torch.maximum(min_val.abs(), max_val.abs()) |
| 307 | |
| 308 | finfo = torch.finfo(torch.float8_e4m3fn) |
| 309 | if amax > finfo.max: |
| 310 | fp8_type = torch.float8_e5m2 |
| 311 | fp8_view_type = torch.float16 |
| 312 | else: |
| 313 | fp8_type = torch.float8_e4m3fn |
| 314 | fp8_view_type = torch.bfloat16 |
| 315 | |
| 316 | finfo = torch.finfo(fp8_type) |
| 317 | scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float() |
| 318 | q_tensor = inp_tensor.data.float() * scale |
| 319 | # Todo: Currently we use fp8_view_type <float16, bfloat16> to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'. |
| 320 | # inp_tensor needs to be a float datatype to avoid error during gradient placement. |
| 321 | inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type) |
| 322 | |
| 323 | inp["fp8_scale"] = scale.float().reciprocal() |
| 324 | inp["dtype"] = torch.zeros_like(scale).to(inp_dtype) |
| 325 | |
| 326 | |
| 327 | def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: |
searching dependent graphs…