Get the maximum inference batch size for LLM models according to the device type. Args: device_type (str): the type of device; one of 'cuda', 'ascend', 'maca' or 'camb'.
(device_type: str)
| 392 | |
| 393 | |
def get_max_batch_size(device_type: str):
    """Get the max inference batch size for LLM models according to the device
    type.

    Args:
        device_type (str): the type of device. One of ``'cuda'``,
            ``'ascend'``, ``'maca'`` or ``'camb'``.

    Returns:
        int: the max batch size. For ``'cuda'`` it is picked by matching the
        lower-cased name of CUDA device 0 against known GPU models, falling
        back to 128 for unlisted GPUs. Every other supported device type
        gets 256.

    Raises:
        ValueError: if ``device_type`` is not one of the supported types.
    """
    # Non-CUDA backends use a single fixed batch size.
    fixed_batch_sizes = {'ascend': 256, 'maca': 256, 'camb': 256}
    if device_type in fixed_batch_sizes:
        return fixed_batch_sizes[device_type]
    if device_type != 'cuda':
        # An explicit raise (rather than `assert`) still fires under
        # `python -O`; with `assert` the function would silently return
        # None for an unknown device type.
        raise ValueError(
            f'Unsupported device_type {device_type!r}; expected one of '
            "['cuda', 'ascend', 'maca', 'camb']")

    import torch
    max_batch_size_map = {
        'a100': 384,
        'a800': 384,
        'h100': 1024,
        'h800': 1024,
        'l20y': 1024,
        'h200': 1024,
    }
    device_name = torch.cuda.get_device_name(0).lower()
    # Substring match against the marketing name reported by the driver;
    # the first matching entry wins.
    for name, size in max_batch_size_map.items():
        if name in device_name:
            return size
    # for devices that are not in `max_batch_size_map`, set
    # the max_batch_size 128
    return 128
| 418 | |
| 419 | |
| 420 | def is_bf16_supported(device_type: str = 'cuda'): |
no test coverage detected