```python
from pathlib import Path
from typing import Tuple

# ModelConfig is imported elsewhere in the original module (not shown in this excerpt).


def get_engine_name(model: str, dtype: str, tp_size: int, pp_size: int,
                    rank: int) -> str:
    """
    Get the serialized engine file name.

    Args:
        model (str):
            Model name, e.g., bloom, gpt.
        dtype (str):
            Data type, e.g., float32, float16, bfloat16.
        tp_size (int):
            The tensor-parallel size.
        pp_size (int):
            The pipeline-parallel size.
        rank (int):
            The rank ID.

    Returns:
        str: The serialized engine file name.
    """
    # Omit the pipeline-parallel field when pipeline parallelism is disabled.
    if pp_size == 1:
        return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
    return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size,
                                                   pp_size, rank)


def read_config(config_path: Path) -> Tuple[ModelConfig, dict]:
    ...  # Body not included in this excerpt.
```
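A minimal sketch of the naming scheme in use. It copies `get_engine_name` so the block runs on its own, and adds a hypothetical helper, `engine_names_for_world`, which assumes that the world size is `tp_size * pp_size` and that ranks are numbered contiguously from 0. Neither assumption is stated in the code above; both are labeled in the comments.

```python
from typing import List


def get_engine_name(model: str, dtype: str, tp_size: int, pp_size: int,
                    rank: int) -> str:
    # Same naming scheme as above: the "pp" field is omitted when pp_size == 1.
    if pp_size == 1:
        return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
    return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size,
                                                   pp_size, rank)


def engine_names_for_world(model: str, dtype: str, tp_size: int,
                           pp_size: int) -> List[str]:
    # Hypothetical helper (assumption): world_size = tp_size * pp_size,
    # with ranks numbered 0 .. world_size - 1.
    world_size = tp_size * pp_size
    return [
        get_engine_name(model, dtype, tp_size, pp_size, rank)
        for rank in range(world_size)
    ]


if __name__ == '__main__':
    # Pipeline parallelism disabled: no "pp" field in the name.
    print(get_engine_name('gpt', 'float16', 2, 1, 0))
    # -> gpt_float16_tp2_rank0.engine

    # Two-way tensor parallel x two-way pipeline parallel: one file per rank.
    for name in engine_names_for_world('bloom', 'bfloat16', 2, 2):
        print(name)
```

Under these assumptions every rank gets its own engine file, and the `pp` field appears only when `pp_size` is greater than 1, so engines built without pipeline parallelism keep the shorter name.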