()
| 49 | |
| 50 | |
| 51 | def get_accelerator(): |
| 52 | global ds_accelerator |
| 53 | if ds_accelerator is not None: |
| 54 | return ds_accelerator |
| 55 | |
| 56 | accelerator_name = None |
| 57 | ds_set_method = None |
| 58 | # 1. Detect whether there is override of DeepSpeed accelerators from environment variable. |
| 59 | if "DS_ACCELERATOR" in os.environ.keys(): |
| 60 | accelerator_name = os.environ["DS_ACCELERATOR"] |
| 61 | if accelerator_name == "xpu": |
| 62 | try: |
| 63 | import torch |
| 64 | assert hasattr(torch, 'xpu') and torch.xpu.is_available(), \ |
| 65 | "XPU_Accelerator requires PyTorch with XPU support (torch.xpu)." |
| 66 | except (ImportError, AssertionError) as e: |
| 67 | raise ValueError(f"XPU_Accelerator requires PyTorch with XPU support: {e}") |
| 68 | elif accelerator_name == "cpu": |
| 69 | pass |
| 70 | elif accelerator_name == "npu": |
| 71 | try: |
| 72 | import torch_npu # noqa: F401 # type: ignore |
| 73 | except ImportError as e: |
| 74 | raise ValueError("NPU_Accelerator requires torch_npu, which is not installed on this system.") |
| 75 | pass |
| 76 | elif accelerator_name == "sdaa": |
| 77 | try: |
| 78 | import torch_sdaa # noqa: F401 # type: ignore |
| 79 | except ImportError as e: |
| 80 | raise ValueError("SDAA_Accelerator requires torch_sdaa, which is not installed on this system.") |
| 81 | pass |
| 82 | elif accelerator_name == "mps": |
| 83 | try: |
| 84 | import torch.mps |
| 85 | |
| 86 | # should use torch.mps.is_available() if it exists someday but this is used as proxy |
| 87 | torch.mps.current_allocated_memory() |
| 88 | except (RuntimeError, ImportError) as e: |
| 89 | raise ValueError("MPS_Accelerator requires torch.mps, which is not installed on this system.") |
| 90 | elif accelerator_name == "hpu": |
| 91 | try: |
| 92 | import habana_frameworks.torch.hpu # noqa: F401 |
| 93 | except ImportError as e: |
| 94 | raise ValueError( |
| 95 | "HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.") |
| 96 | elif accelerator_name == "mlu": |
| 97 | try: |
| 98 | import torch_mlu # noqa: F401 |
| 99 | except ImportError as e: |
| 100 | raise ValueError("MLU_Accelerator requires torch_mlu, which is not installed on this system.") |
| 101 | elif accelerator_name == "supa": |
| 102 | try: |
| 103 | import torch_supa # noqa: F401 # type: ignore |
| 104 | except ImportError as e: |
| 105 | raise ValueError("SUPA_Accelerator requires torch_supa, which is not installed on this system.") |
| 106 | elif accelerator_name not in SUPPORTED_ACCELERATOR_LIST: |
| 107 | raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. ' |
| 108 | f'Value "{accelerator_name}" is not supported') |
searching dependent graphs…